Skip to content

Commit

Permalink
Work on performance problems with closures
Browse files Browse the repository at this point in the history
(again...)
  • Loading branch information
AntonReinhard committed Nov 29, 2024
1 parent e6c8fc0 commit 9290549
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 27 deletions.
5 changes: 2 additions & 3 deletions src/code_gen/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ function get_compute_function(
)
tape = gen_tape(graph, instance, machine, context_module)

assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...)
code = gen_function_body(tape, context_module; closures_size=closures_size)
assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...)

function_id = to_var_name(UUIDs.uuid1(rng[1]))
res_sym = tape.output_symbol
expr = #
Expr(
expr = Expr(
:function, # function definition
Expr(
:call, Symbol("compute_$function_id"), Expr(:(::), :input, input_type(instance))
Expand Down
41 changes: 23 additions & 18 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,36 @@ end
function expr_from_fc(fc::FunctionCall{VAL_T,Expr}) where {VAL_T}
@assert length(fc) == 1 && isempty(fc.value_arguments[1]) "function call assigning an expression cannot be vectorized and cannot contain value arguments\n$fc"

fc_expr_in_let = Expr(
:let,
Expr(:block, fc.return_symbols[1]...),
fc.func, # anonymous function code block
fc_expr = Expr(
:block,
gen_local_init(fc),
fc.func, # anonymous function code block
Expr( # return the symbols
:return,
(
if length(fc.return_symbols[1]) == 1
fc.return_symbols[1][1]
else
Expr(:tuple, fc.return_symbols[1]...)
end
),
),
)

func_call = Expr(
:call, # call
#wrap_in_let_statement( # wrap in let statement to prevent boxing of local variables
Expr(
:->, # anonymous function
Expr(
:tuple, # anonymous function arguments
fc.arguments[1]...,
#fc.arguments[1]...,
),
fc_expr_in_let,
fc_expr, # anonymous function code block
),
fc.arguments[1]..., # runtime arguments passed to the anonymous function
#fc.arguments[1],
#),
#fc.arguments[1]..., # runtime arguments passed to the anonymous function call
)

access_expr = gen_access_expr(fc)
Expand All @@ -55,7 +68,7 @@ function gen_input_assignment_code(
device = entry_device(machine)

fc = FunctionCall(
input_expr(instance, name, :input),
Expr(:(=), symbol, input_expr(instance, name, :input)),
(),
Symbol[:input],
Symbol[symbol],
Expand Down Expand Up @@ -110,7 +123,7 @@ function _gen_function_body(
end

# iterate from end to beginning
# this helps because we can collect all undefined arguments to the closures that have to be returned somewhere earlier
# this helps because we can collect all undefined arguments to the closures that have to be defined somewhere earlier
undefined_argument_symbols = Set{Symbol}()
# the final return symbol is the return of the entire generated function, it always has to be returned
push!(undefined_argument_symbols, gen_access_expr(fc_vec[end]))
Expand Down Expand Up @@ -183,15 +196,7 @@ function _closure_fc(

fc_expr = Expr( # actual function body of the closure
:block,
expr_from_fc.(code_block)...,
Expr(
:return, # have to make sure to not return a tuple of length 1
if length(ret_symbols_t) == 1
ret_symbols_t[1]
else
Expr(:tuple, ret_symbols_t...)
end,
),
expr_from_fc.(code_block)..., # no return statement necessary, will be done via capture and local init
)

fc = FunctionCall(
Expand Down
1 change: 1 addition & 0 deletions src/code_gen/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ function infer_types!(tape::Tape, context_module::Module)

for fc in tape.input_assign_code
res_types = result_types(fc, known_result_types, context_module)
fc.return_types = res_types
for (s, t) in Iterators.zip(
Iterators.flatten(fc.return_symbols),
Iterators.cycle(res_types, length(fc.return_symbols)),
Expand Down
25 changes: 19 additions & 6 deletions src/devices/impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ end
gen_local_init(fc::FunctionCall)
Dispatch from the given [`FunctionCall`](@ref) to the lower-level function [`_gen_local_init`](@ref).
!!! note
This is currently unused, but may become useful in the future again.
"""
function gen_local_init(fc::FunctionCall)
return Expr(
Expand All @@ -82,10 +79,26 @@ end
Return an `Expr` that initializes the symbol in the local scope.
The result looks like `local <symbol>::<type>`.
!!! note
This is currently unused, but may become useful in the future again.
"""
function _gen_local_init(symbol::Symbol, type::Type)
return Expr(:local, symbol, :(::), Symbol(type))
end

"""
wrap_in_let_statement(expr, symbols)
For a given expression and a collection of symbols, generate a let statement that wraps the expression in a let statement with all the symbols, like
`let <symbol[1]>=<symbol[1]>, ..., <symbol[end]>=<symbol[end]> <expr> end`
"""
@inline function wrap_in_let_statement(expr, symbols)
return Expr(:let, Expr(:block, _gen_let_statement.(symbols)...), expr)
end

"""
_gen_let_statement(symbol::Symbol)
Return a let-`Expr` like `<symbol> = <symbol>`.
"""
function _gen_let_statement(symbol::Symbol)
return Expr(:(=), symbol, symbol)
end

0 comments on commit 9290549

Please sign in to comment.