Skip to content

Commit

Permalink
Try adding closure functionality in code gen
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonReinhard committed Oct 3, 2024
1 parent bb191d6 commit 41a1bcd
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 30 deletions.
7 changes: 5 additions & 2 deletions ext/devices/cuda/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ function ComputableDAGs.kernel(

init_caches = Expr(:block, tape.initCachesCode...)
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
# TODO: use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs.gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine),
ComputableDAGs.entry_device(tape.machine).cacheStrategy,
tape.outputSymbol,
),
)
expr = Meta.parse(
Expand Down
10 changes: 7 additions & 3 deletions ext/devices/rocm/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@ function ComputableDAGs.kernel(

init_caches = Expr(:block, tape.initCachesCode...)
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)

# TODO use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs.gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine),
ComputableDAGs.entry_device(tape.machine).cacheStrategy,
tape.outputSymbol,
),
)
expr = Meta.parse(
"function compute_$(function_id)(input_vector, output_vector, n::Int64)
id = (workgroupIdx().x - 1) * workgroupDim().x + workgroupIdx().x
if (id > n)
if (id > n)
return
end
@inline data_input = input_vector[id]
Expand Down
16 changes: 13 additions & 3 deletions src/code_gen/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,28 @@ using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)
```
in your top level.
## Keyword Arguments
`closures_size` (default=500): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
"""
function get_compute_function(
graph::DAG, instance, machine::Machine, context_module::Module
graph::DAG, instance, machine::Machine, context_module::Module; closures_size=500
)
tape = gen_tape(graph, instance, machine, context_module)

initCaches = Expr(:block, tape.initCachesCode...)
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
code = gen_function_body(tape.computeCode; closures_size=closures_size)

functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
resSym = eval(
_gen_access_expr(
entry_device(tape.machine),
entry_device(tape.machine).cacheStrategy,
tape.outputSymbol,
),
)
expr = #
Expr(
:function, # function definition
Expand Down
54 changes: 50 additions & 4 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ end

function expr_from_fc(fc::FunctionCall{VectorT,0}) where {VectorT}
func_call = Expr(
:call, fc.func, eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...
:call,
fc.func,
eval.(
_gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments)
)...,
)
access_expr = eval(gen_access_expr(fc.device, fc.return_symbol))
access_expr = eval(gen_access_expr(fc))

return Expr(:(=), access_expr, func_call)
end
Expand All @@ -69,9 +73,11 @@ function expr_from_fc(fc::FunctionCall{VectorT,M}) where {VectorT,M}
:call,
fc.func,
fc.value_arguments...,
eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...,
eval.(
_gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments)
)...,
)
access_expr = eval(gen_access_expr(fc.device, fc.return_symbol))
access_expr = eval(gen_access_expr(fc))

return Expr(:(=), access_expr, func_call)
end
Expand Down Expand Up @@ -133,6 +139,46 @@ function gen_input_assignment_code(
return assign_inputs
end

"""
gen_function_body(fc_vec::Vector{FunctionCall}; closures_size)
Generate the function body from the given `Vector` of [`FunctionCall`](@ref)s.
## Keyword Arguments
`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely.
"""
function gen_function_body(fc_vec::Vector{FunctionCall}; closures_size::Int)
if (closures_size <= 1)
return Expr(:block, expr_from_fc.(fc_vec)...)
end

closures = Vector{Expr}()
for i in 1:closures_size:length(fc_vec)
code_block = fc_vec[i:min(i + closures_size, length(fc_vec))]

# collect `local var` statements that need to exist before the closure starts
# since the return symbols are always unique, this has to happen for each fc and there will be no duplicates
local_inits = gen_local_init.(code_block)

closure = Expr( # call to the following closure (no arguments)
:call,
Expr( # create the closure: () -> code block; return nothing
:->,
:(),
Expr(# # actual function body of the closure
:block,
expr_from_fc.(code_block)...,
Expr(:return, :nothing),
),
),
)
# combine to one closure call, including all the local inits and the actual call to the closure
push!(closures, Expr(:block, local_inits..., closure))
end

return Expr(:block, closures...)
end

"""
gen_tape(
graph::DAG,
Expand Down
18 changes: 18 additions & 0 deletions src/devices/impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,21 @@ function cpu_st()
[NumaNode(0, 1, default_strategy(NumaNode), -1.0, UUIDs.uuid1())], [-1.0;;]
)
end

"""
gen_access_expr(fc::FunctionCall)
Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_access_expr`(@ref).
"""
function gen_access_expr(fc::FunctionCall)
return _gen_access_expr(fc.device, fc.device.cacheStrategy, fc.return_symbol)
end

"""
gen_local_init(fc::FunctionCall)
Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_local_init`(@ref).
"""
function gen_local_init(fc::FunctionCall)
return _gen_local_init(fc, fc.device, fc.device.cacheStrategy)
end
13 changes: 11 additions & 2 deletions src/devices/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,21 @@ The strategy is a symbol
function gen_cache_init_code end

"""
gen_access_expr(device::AbstractDevice, symbol::Symbol)
_gen_access_expr(device::AbstractDevice, cache_strategy::CacheStrategy, symbol::Symbol)
Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one [`CacheStrategy`](@ref).
Return an `Expr` or `QuoteNode` accessing the variable identified by [`symbol`].
"""
function gen_access_expr end
function _gen_access_expr end

"""
_gen_local_init(fc::FunctionCall, device::AbstractDevice, cache_strategy::CacheStrategy)
Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one [`CacheStrategy`](@ref).
Return an `Expr` or `QuoteNode` that initializes the access expression returned by [`_gen_access_expr`](@ref) in the local scope.
This expression may be empty. For local variables it should be `local <variable_name>::<Type>`.
"""
function _gen_local_init end

"""
kernel(gpu_type::Type{<:AbstractGPU}, graph::DAG, instance)
Expand Down
45 changes: 29 additions & 16 deletions src/devices/numa/impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,32 +64,45 @@ function gen_cache_init_code(device::NumaNode)
end

"""
gen_access_expr(device::NumaNode, symbol::Symbol)
_gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol)
Generate code to access the variable designated by `symbol` on a [`NumaNode`](@ref), using the [`CacheStrategy`](@ref) set in the device.
Interface implementation, dispatched to from [`gen_access_expr`](@ref).
"""
function gen_access_expr(device::NumaNode, symbol::Symbol)
return _gen_access_expr(device, device.cacheStrategy, symbol)
function _gen_access_expr(::NumaNode, ::LocalVariables, symbol::Symbol)
# TODO rewrite these with Expr instead of quote node
s = Symbol("data_$symbol")
quote_node = Meta.parse(":($s)")
return quote_node
end

"""
_gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol)
_gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol)
Internal function for dispatch, used in [`gen_access_expr`](@ref).
Interface implementation, dispatched to from [`gen_access_expr`](@ref).
"""
function _gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol)
s = Symbol("data_$symbol")
quoteNode = Meta.parse(":($s)")
return quoteNode
function _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol)
# TODO rewrite these with Expr instead of quote node
access_str = ":(cache_$(to_var_name(device.id))[:$symbol])"
quote_node = Meta.parse(access_str)
return quote_node
end

"""
_gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol)
_gen_local_init(fc::FunctionCall, device::NumaNode, cache_strategy::LocalVariables)
Internal function for dispatch, used in [`gen_access_expr`](@ref).
Interface implementation, dispatched to from [`gen_local_init`](@ref).
"""
function _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol)
accessStr = ":(cache_$(to_var_name(device.id))[:$symbol])"
quoteNode = Meta.parse(accessStr)
return quoteNode
function _gen_local_init(fc::FunctionCall, ::NumaNode, ::LocalVariables)
s = Symbol("data_$(fc.return_symbol)")
quote_node = Expr(:local, s) # TODO: figure out how to get type info for this local variable
return quote_node
end

"""
_gen_local_init(fc::FunctionCall, device::NumaNode, cache_strategy::Dictionary)
Interface implementation, dispatched to from [`gen_local_init`](@ref).
"""
function _gen_local_init(::FunctionCall, ::NumaNode, ::Dictionary)
return Exp()
end

0 comments on commit 41a1bcd

Please sign in to comment.