Skip to content

Commit

Permalink
Remove superfluous _gen_access_expr and simplify _gen_local_init
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonReinhard committed Nov 27, 2024
1 parent e76461f commit 68a7945
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 66 deletions.
7 changes: 1 addition & 6 deletions ext/devices/cuda/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ function ComputableDAGs.kernel(
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.output_symbol
),
)
expr = Meta.parse(
"function compute_$(function_id)(input_vector, output_vector, n::Int64)
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
Expand All @@ -23,7 +18,7 @@ function ComputableDAGs.kernel(
@inline input = input_vector[id]
$(assign_inputs)
$code
@inline output_vector[id] = $res_sym
@inline output_vector[id] = $(tape.output_symbol)
return nothing
end"
)
Expand Down
7 changes: 1 addition & 6 deletions ext/devices/rocm/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ function ComputableDAGs.kernel(
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.output_symbol
),
)
expr = Meta.parse(
"function compute_$(function_id)(input_vector, output_vector, n::Int64)
id = (workgroupIdx().x - 1) * workgroupDim().x + workgroupIdx().x
Expand All @@ -24,7 +19,7 @@ function ComputableDAGs.kernel(
@inline input = input_vector[id]
$(assign_inputs)
$code
@inline output_vector[id] = $res_sym
@inline output_vector[id] = $(tape.output_symbol)
return nothing
end"
)
Expand Down
7 changes: 5 additions & 2 deletions src/code_gen/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ in your top level.
## Keyword Arguments
`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the
compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
**Note** that the closure size actually used might differ from the one passed here, since the function automatically chooses a size that
is close to an n-th root of the total number of lines of code, based on the given size.
"""
function get_compute_function(
graph::DAG, instance, machine::Machine, context_module::Module; closures_size=0
Expand All @@ -28,7 +31,7 @@ function get_compute_function(
code = gen_function_body(tape, context_module; closures_size=closures_size)

function_id = to_var_name(UUIDs.uuid1(rng[1]))
res_sym = _gen_access_expr(entry_device(tape.machine), tape.output_symbol)
res_sym = tape.output_symbol
expr = #
Expr(
:function, # function definition
Expand Down
7 changes: 1 addition & 6 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function}
if length(fc) == 1
func_call = Expr(
:call,
fc.func,
fc.value_arguments[1]...,
_gen_access_expr.(Ref(fc.device), fc.arguments[1])...,
)
func_call = Expr(:call, fc.func, fc.value_arguments[1]..., fc.arguments[1]...)
else
# TBW; dispatch to device specific vectorization
throw("unimplemented")
Expand Down
27 changes: 21 additions & 6 deletions src/devices/impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ end
"""
gen_access_expr(fc::FunctionCall)
Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_access_expr`](@ref).
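Return the symbol or expression through which the return value(s) of the given [`FunctionCall`](@ref) are accessed.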
"""
function gen_access_expr(fc::FunctionCall{VAL_T}) where {VAL_T}
if length(fc.return_types) != 1
# general case
vec = Expr[]
for ret_symbols in fc.return_symbols
push!(vec, unroll_symbol_vector(_gen_access_expr.(Ref(fc.device), ret_symbols)))
push!(vec, unroll_symbol_vector(ret_symbols))
end
if length(vec) > 1
return unroll_symbol_vector(vec)
Expand All @@ -47,10 +47,10 @@ function gen_access_expr(fc::FunctionCall{VAL_T}) where {VAL_T}
end
end

# no vectorization case
# single return value per function
vec = Symbol[]
for ret_symbols in fc.return_symbols
push!(vec, _gen_access_expr.(Ref(fc.device), ret_symbols[1]))
push!(vec, ret_symbols[1])
end
if length(vec) > 1
return unroll_symbol_vector(vec)
Expand All @@ -62,15 +62,30 @@ end
"""
    gen_local_init(fc::FunctionCall)

Dispatch from the given [`FunctionCall`](@ref) to the lower-level function [`_gen_local_init`](@ref).

Return an `Expr` with head `:block` that contains one initialization expression per
return symbol of `fc`.

!!! note
    This is currently unused, but may become useful in the future again.
"""
function gen_local_init(fc::FunctionCall)
    # `_gen_local_init` takes only `(symbol, type)`, so no device argument is passed.
    # The flattened return symbols are paired element-wise with `fc.return_types`,
    # repeated `length(fc.return_symbols)` times.
    # NOTE: the two-argument form of `Iterators.cycle` requires Julia 1.11 or newer.
    return Expr(
        :block,
        _gen_local_init.(
            Iterators.flatten(fc.return_symbols),
            Iterators.cycle(fc.return_types, length(fc.return_symbols)),
        )...,
    )
end

"""
    _gen_local_init(symbol::Symbol, type::Type)

Return an `Expr` that initializes the symbol in the local scope.
The result looks like `local <symbol>::<type>`.

The type object itself is embedded in the returned expression (rather than a symbol
built from its printed name), so the declaration stays valid for parametric types and
does not depend on the type's name being resolvable where the expression is evaluated.

!!! note
    This is currently unused, but may become useful in the future again.
"""
function _gen_local_init(symbol::Symbol, type::Type)
    # A typed local declaration is a `:local` node wrapping a single `::` node,
    # i.e. the same shape as the AST of `:(local x::T)`.
    return Expr(:local, Expr(:(::), symbol, type))
end
17 changes: 0 additions & 17 deletions src/devices/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,6 @@ Interface function that must be implemented for every subtype of [`AbstractDevic
"""
function measure_device! end

"""
_gen_access_expr(device::AbstractDevice, symbol::Symbol)
Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref).
Return an `Expr` or `QuoteNode` accessing the variable identified by [`symbol`].
"""
function _gen_access_expr end

"""
_gen_local_init(device::AbstractDevice, symbol::Symbol, type::Type)
Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref).
Return an `Expr` or `QuoteNode` that initializes the access expression returned by [`_gen_access_expr`](@ref) in the local scope.
This expression may be empty. For local variables it should be `local <variable_name>::<Type>`.
"""
function _gen_local_init end

"""
kernel(gpu_type::Type{<:AbstractGPU}, graph::DAG, instance)
Expand Down
23 changes: 0 additions & 23 deletions src/devices/numa/impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,3 @@ function get_devices(deviceType::Type{T}; verbose::Bool=false) where {T<:NumaNod

return devices
end

"""
_gen_access_expr(device::NumaNode, symbol::Symbol)
Interface implementation, dispatched to from [`gen_access_expr`](@ref).
"""
function _gen_access_expr(::NumaNode, symbol::Symbol)
# TODO rewrite these with Expr instead of quote node
#=s = Symbol("data_$symbol")
quote_node = Meta.parse(":($s)")=#
return symbol
end

"""
_gen_local_init(device::NumaNode, symbol::Symbol, type::Type)
Interface implementation, dispatched to from [`gen_local_init`](@ref).
"""
function _gen_local_init(::NumaNode, symbol::Symbol, type::Type)
#s = Symbol("data_$(symbol)")
quote_node = Expr(:local, symbol, :(::), Symbol(type))
return quote_node
end

0 comments on commit 68a7945

Please sign in to comment.