From 9290549572f8971ece3b84f57a0a5666fb2839dc Mon Sep 17 00:00:00 2001 From: AntonReinhard Date: Fri, 29 Nov 2024 15:21:53 +0100 Subject: [PATCH] Work on performance problems with closures (again...) --- src/code_gen/function.jl | 5 ++--- src/code_gen/tape_machine.jl | 41 ++++++++++++++++++++---------------- src/code_gen/utils.jl | 1 + src/devices/impl.jl | 25 ++++++++++++++++------ 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index c73c99c..6a957e8 100644 --- a/src/code_gen/function.jl +++ b/src/code_gen/function.jl @@ -27,13 +27,12 @@ function get_compute_function( ) tape = gen_tape(graph, instance, machine, context_module) - assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...) code = gen_function_body(tape, context_module; closures_size=closures_size) + assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...) function_id = to_var_name(UUIDs.uuid1(rng[1])) res_sym = tape.output_symbol - expr = # - Expr( + expr = Expr( :function, # function definition Expr( :call, Symbol("compute_$function_id"), Expr(:(::), :input, input_type(instance)) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index 767dc9b..4609dfb 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -12,23 +12,36 @@ end function expr_from_fc(fc::FunctionCall{VAL_T,Expr}) where {VAL_T} @assert length(fc) == 1 && isempty(fc.value_arguments[1]) "function call assigning an expression cannot be vectorized and cannot contain value arguments\n$fc" - fc_expr_in_let = Expr( - :let, - Expr(:block, fc.return_symbols[1]...), - fc.func, # anonymous function code block + fc_expr = Expr( + :block, + gen_local_init(fc), + fc.func, # anonymous function code block + Expr( # return the symbols + :return, + ( + if length(fc.return_symbols[1]) == 1 + fc.return_symbols[1][1] + else + Expr(:tuple, fc.return_symbols[1]...) 
+ end + ), + ), ) func_call = Expr( :call, # call + #wrap_in_let_statement( # wrap in let statement to prevent boxing of local variables Expr( :->, # anonymous function Expr( :tuple, # anonymous function arguments - fc.arguments[1]..., + #fc.arguments[1]..., ), - fc_expr_in_let, + fc_expr, # anonymous function code block ), - fc.arguments[1]..., # runtime arguments passed to the anonymous function + #fc.arguments[1], + #), + #fc.arguments[1]..., # runtime arguments passed to the anonymous function call ) access_expr = gen_access_expr(fc) @@ -55,7 +68,7 @@ function gen_input_assignment_code( device = entry_device(machine) fc = FunctionCall( - input_expr(instance, name, :input), + Expr(:(=), symbol, input_expr(instance, name, :input)), (), Symbol[:input], Symbol[symbol], @@ -110,7 +123,7 @@ function _gen_function_body( end # iterate from end to beginning - # this helps because we can collect all undefined arguments to the closures that have to be returned somewhere earlier + # this helps because we can collect all undefined arguments to the closures that have to be defined somewhere earlier undefined_argument_symbols = Set{Symbol}() # the final return symbol is the return of the entire generated function, it always has to be returned push!(undefined_argument_symbols, gen_access_expr(fc_vec[end])) @@ -183,15 +196,7 @@ function _closure_fc( fc_expr = Expr( # actual function body of the closure :block, - expr_from_fc.(code_block)..., - Expr( - :return, # have to make sure to not return a tuple of length 1 - if length(ret_symbols_t) == 1 - ret_symbols_t[1] - else - Expr(:tuple, ret_symbols_t...) 
- end, - ), + expr_from_fc.(code_block)..., # no return statement necessary, will be done via capture and local init ) fc = FunctionCall( diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl index 8207406..e477c68 100644 --- a/src/code_gen/utils.jl +++ b/src/code_gen/utils.jl @@ -17,6 +17,7 @@ function infer_types!(tape::Tape, context_module::Module) for fc in tape.input_assign_code res_types = result_types(fc, known_result_types, context_module) + fc.return_types = res_types for (s, t) in Iterators.zip( Iterators.flatten(fc.return_symbols), Iterators.cycle(res_types, length(fc.return_symbols)), diff --git a/src/devices/impl.jl b/src/devices/impl.jl index 172160f..7392074 100644 --- a/src/devices/impl.jl +++ b/src/devices/impl.jl @@ -63,9 +63,6 @@ end gen_local_init(fc::FunctionCall) Dispatch from the given [`FunctionCall`](@ref) to the lower-level function [`_gen_local_init`](@ref). - -!!! note - This is currently unused, but may become useful in the future again. """ function gen_local_init(fc::FunctionCall) return Expr( @@ -82,10 +79,26 @@ end Return an `Expr` that initializes the symbol in the local scope. The result looks like `local <symbol>::<type>`. - -!!! note - This is currently unused, but may become useful in the future again. """ function _gen_local_init(symbol::Symbol, type::Type) return Expr(:local, symbol, :(::), Symbol(type)) end + +""" + wrap_in_let_statement(expr, symbols) + +For a given expression and a collection of symbols, generate a let statement that wraps the expression in a let statement with all the symbols, like +`let <symbol[1]>=<symbol[1]>, ..., <symbol[n]>=<symbol[n]> <expr> end` +""" +@inline function wrap_in_let_statement(expr, symbols) + return Expr(:let, Expr(:block, _gen_let_statement.(symbols)...), expr) +end + +""" + _gen_let_statement(symbol::Symbol) + +Return a let-`Expr` like `<symbol> = <symbol>`. +""" +function _gen_let_statement(symbol::Symbol) + return Expr(:(=), symbol, symbol) +end