From b15488061b9a7f9d77b68dcc74db8bd34bdb7d99 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Wed, 13 Nov 2024 18:38:55 +0100 Subject: [PATCH 01/12] Restructure FunctionCall --- ext/devices/cuda/function.jl | 4 +- ext/devices/rocm/function.jl | 4 +- src/ComputableDAGs.jl | 1 - src/code_gen/function.jl | 10 +-- src/code_gen/tape_machine.jl | 58 ++++++++--------- src/code_gen/type.jl | 8 +-- src/code_gen/utils.jl | 26 +++++--- src/devices/impl.jl | 37 +++++++++-- src/devices/interface.jl | 2 +- src/devices/numa/impl.jl | 10 +-- src/scheduler/type.jl | 38 ++++++++--- src/task/compute.jl | 118 ++++++++++++++++++++--------------- src/utils.jl | 23 +------ test/strassen_test.jl | 2 + 14 files changed, 194 insertions(+), 147 deletions(-) diff --git a/ext/devices/cuda/function.jl b/ext/devices/cuda/function.jl index 7ccdf94..e990d38 100644 --- a/ext/devices/cuda/function.jl +++ b/ext/devices/cuda/function.jl @@ -4,14 +4,14 @@ function ComputableDAGs.kernel( machine = cpu_st() tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module) - assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...) + assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.input_assign_code)...) # TODO: use gen_function_body here code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...) 
function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) res_sym = eval( ComputableDAGs._gen_access_expr( - ComputableDAGs.entry_device(tape.machine), tape.outputSymbol + ComputableDAGs.entry_device(tape.machine), tape.output_symbol ), ) expr = Meta.parse( diff --git a/ext/devices/rocm/function.jl b/ext/devices/rocm/function.jl index dc617e2..4b41723 100644 --- a/ext/devices/rocm/function.jl +++ b/ext/devices/rocm/function.jl @@ -4,7 +4,7 @@ function ComputableDAGs.kernel( machine = cpu_st() tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module) - assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...) + assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.input_assign_code)...) # TODO use gen_function_body here code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...) @@ -12,7 +12,7 @@ function ComputableDAGs.kernel( function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) res_sym = eval( ComputableDAGs._gen_access_expr( - ComputableDAGs.entry_device(tape.machine), tape.outputSymbol + ComputableDAGs.entry_device(tape.machine), tape.output_symbol ), ) expr = Meta.parse( diff --git a/src/ComputableDAGs.jl b/src/ComputableDAGs.jl index a852e15..095b8ca 100644 --- a/src/ComputableDAGs.jl +++ b/src/ComputableDAGs.jl @@ -37,7 +37,6 @@ export get_operations export execute export get_compute_function export gen_tape, execute_tape -export unpack_identity # estimator export cost_type, graph_cost, operation_effect diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index f8dbce0..412d265 100644 --- a/src/code_gen/function.jl +++ b/src/code_gen/function.jl @@ -24,20 +24,20 @@ function get_compute_function( ) tape = gen_tape(graph, instance, machine, context_module) - assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...) + assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...) 
code = gen_function_body(tape; closures_size=closures_size) - functionId = to_var_name(UUIDs.uuid1(rng[1])) - resSym = eval(_gen_access_expr(entry_device(tape.machine), tape.outputSymbol)) + function_id = to_var_name(UUIDs.uuid1(rng[1])) + res_sym = _gen_access_expr(entry_device(tape.machine), tape.output_symbol) expr = # Expr( :function, # function definition Expr( :call, - Symbol("compute_$functionId"), + Symbol("compute_$function_id"), Expr(:(::), :data_input, input_type(instance)), ), # function name and parameters - Expr(:block, assignInputs, code, Expr(:return, resSym)), # function body + Expr(:block, assign_inputs, code, Expr(:return, res_sym)), # function body ) return RuntimeGeneratedFunction(@__MODULE__, context_module, expr) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index 7bdc584..be386b9 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -50,29 +50,21 @@ function call_fc(fc::FunctionCall{VectorT,M}, cache::Dict{Symbol,Any}) where {Ve return nothing end -function expr_from_fc(fc::FunctionCall{VectorT,0}) where {VectorT} - func_call = Expr( - :call, fc.func, eval.(_gen_access_expr.(Ref(fc.device), fc.arguments))... - ) - access_expr = eval(gen_access_expr(fc)) - - return Expr(:(=), access_expr, func_call) -end - -""" - expr_from_fc(fc::FunctionCall) - -For a given function call, return an expression evaluating it. 
-""" -function expr_from_fc(fc::FunctionCall{VectorT,M}) where {VectorT,M} - func_call = Expr( - :call, - fc.func, - fc.value_arguments..., - eval.(_gen_access_expr.(Ref(fc.device), fc.arguments))..., - ) - access_expr = eval(gen_access_expr(fc)) - +function expr_from_fc(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_RET} + if length(fc) == 1 + func_call = Expr( + :call, + fc.func, + ( + fc.value_arguments[1]..., + _gen_access_expr.(Ref(fc.device), fc.arguments[1])..., + )..., + ) + else + # TBW; dispatch to device specific vectorization + throw("unimplemented") + end + access_expr = gen_access_expr(fc) return Expr(:(=), access_expr, func_call) end @@ -101,10 +93,10 @@ function gen_input_assignment_code( fc = FunctionCall( context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))), - SVector{0,Any}(), - SVector{1,Symbol}(:input), - symbol, - Nothing, + (), + (:input,), + (symbol,), + (Nothing,), device, ) @@ -140,7 +132,7 @@ function gen_function_body(tape::Tape; closures_size::Int) # this helps because we can collect all undefined arguments to the closures that have to be returned somewhere earlier undefined_argument_symbols = Set{Symbol}() # the final return symbol is the return of the entire generated function, it always has to be returned - push!(undefined_argument_symbols, eval(gen_access_expr(fc_vec[end]))) + push!(undefined_argument_symbols, gen_access_expr(fc_vec[end])) for i in length(fc_vec):(-closures_size):1 e = i @@ -150,12 +142,12 @@ function gen_function_body(tape::Tape; closures_size::Int) # collect `local var` statements that need to exist before the closure starts local_inits = gen_local_init.(code_block) - return_symbols = eval.(gen_access_expr.(code_block)) + return_symbols = gen_access_expr.(code_block) ret_symbols_set = Set(return_symbols) for fc in code_block for arg in fc.arguments - symbol = eval(_gen_access_expr(fc.device, arg)) + symbol = _gen_access_expr(fc.device, arg) # symbol won't be defined if it is first 
calculated in the closure # so don't add it to the arguments in this case @@ -237,7 +229,7 @@ function gen_tape( assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module) return Tape{input_type(instance)}( - assign_inputs, function_body, input_syms, outSym, instance, machine + assign_inputs, function_body, outSym, instance, machine ) end @@ -257,12 +249,12 @@ function execute_tape(tape::Tape, input) compute_code = tape.schedule - for function_call in tape.inputAssignCode + for function_call in tape.input_assign_code call_fc(function_call, cache) end for function_call in compute_code call_fc(function_call, cache) end - return cache[tape.outputSymbol] + return cache[tape.output_symbol] end diff --git a/src/code_gen/type.jl b/src/code_gen/type.jl index d08bf46..4e2f004 100644 --- a/src/code_gen/type.jl +++ b/src/code_gen/type.jl @@ -6,14 +6,12 @@ TODO: update docs - `INPUT` the input type of the problem instance - `code::Vector{Expr}`: The julia expression containing the code for the whole graph. -- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on. 
-- `outputSymbol::Symbol`: The symbol of the final calculated value +- `output_symbol::Symbol`: The symbol of the final calculated value """ struct Tape{INPUT} - inputAssignCode::Vector{FunctionCall} + input_assign_code::Vector{FunctionCall} schedule::Vector{FunctionCall} - inputSymbols::Dict{String,Vector{Symbol}} - outputSymbol::Symbol + output_symbol::Symbol instance::Any machine::Machine end diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl index 8379b16..ddcd2ff 100644 --- a/src/code_gen/utils.jl +++ b/src/code_gen/utils.jl @@ -10,16 +10,26 @@ function infer_types!(tape::Tape) # the only initially known type known_result_types[:input] = input_type(tape.instance) - for fc in tape.inputAssignCode - res_type = result_type(fc, known_result_types) - fc.return_type = res_type - known_result_types[fc.return_symbol] = res_type + for fc in tape.input_assign_code + res_types = result_types(fc, known_result_types) + fc.return_types = res_types + for (s, t) in Iterators.zip( + Iterators.flatten(fc.return_symbols), + Iterators.cycle(res_types, length(fc.return_symbols)), + ) + known_result_types[s] = t + end end for fc in tape.schedule - res_type = result_type(fc, known_result_types) - fc.return_type = res_type - known_result_types[fc.return_symbol] = res_type + res_types = result_types(fc, known_result_types) + fc.return_types = res_types + for (s, t) in Iterators.zip( + Iterators.flatten(fc.return_symbols), + Iterators.cycle(res_types, length(fc.return_symbols)), + ) + known_result_types[s] = t + end end return nothing @@ -37,7 +47,7 @@ function lower(schedule::Vector{Node}, machine::Machine) if (node isa DataTaskNode && length(children(node)) == 0) push!(calls, get_init_function_call(node, entry_device(machine))) else - push!(calls, get_function_call(node)...) 
+ push!(calls, get_function_call(node)) end end diff --git a/src/devices/impl.jl b/src/devices/impl.jl index 05c0049..1c31f46 100644 --- a/src/devices/impl.jl +++ b/src/devices/impl.jl @@ -31,17 +31,44 @@ end """ gen_access_expr(fc::FunctionCall) -Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_access_expr`(@ref). +Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_access_expr`](@ref). """ -function gen_access_expr(fc::FunctionCall) - return _gen_access_expr(fc.device, fc.return_symbol) +function gen_access_expr(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_RET} + vec = Expr[] + for ret_symbols in fc.return_symbols + push!(vec, unroll_symbol_vector(_gen_access_expr.(Ref(fc.device), ret_symbols))) + end + if length(vec) > 1 + return unroll_symbol_vector(vec) + else + return vec[1] + end +end + +function gen_access_expr(fc::FunctionCall{VAL_T,N_ARG,1}) where {VAL_T,N_ARG} + vec = Symbol[] + for ret_symbols in fc.return_symbols + push!(vec, _gen_access_expr.(Ref(fc.device), ret_symbols[1])) + end + if length(vec) > 1 + return unroll_symbol_vector(vec) + else + return vec[1] + end end """ gen_local_init(fc::FunctionCall) -Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_local_init`(@ref). +Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_local_init`](@ref). """ function gen_local_init(fc::FunctionCall) - return _gen_local_init(fc, fc.device) + return Expr( + :block, + _gen_local_init.( + Ref(fc.device), + Iterators.flatten(fc.return_symbols), + Iterators.cycle(fc.return_types, length(fc.return_symbols)), + )..., + ) end diff --git a/src/devices/interface.jl b/src/devices/interface.jl index a4a08be..95acd28 100644 --- a/src/devices/interface.jl +++ b/src/devices/interface.jl @@ -55,7 +55,7 @@ Return an `Expr` or `QuoteNode` accessing the variable identified by [`symbol`]. 
function _gen_access_expr end """ - _gen_local_init(fc::FunctionCall, device::AbstractDevice) + _gen_local_init(device::AbstractDevice, symbol::Symbol, type::Type) Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref). Return an `Expr` or `QuoteNode` that initializes the access expression returned by [`_gen_access_expr`](@ref) in the local scope. diff --git a/src/devices/numa/impl.jl b/src/devices/numa/impl.jl index e3a2ec8..e8b9ad9 100644 --- a/src/devices/numa/impl.jl +++ b/src/devices/numa/impl.jl @@ -48,16 +48,16 @@ function _gen_access_expr(::NumaNode, symbol::Symbol) # TODO rewrite these with Expr instead of quote node s = Symbol("data_$symbol") quote_node = Meta.parse(":($s)") - return quote_node + return eval(quote_node) end """ - _gen_local_init(fc::FunctionCall, device::NumaNode) + _gen_local_init(device::NumaNode, symbol::Symbol, type::Type) Interface implementation, dispatched to from [`gen_local_init`](@ref). """ -function _gen_local_init(fc::FunctionCall, ::NumaNode) - s = Symbol("data_$(fc.return_symbol)") - quote_node = Expr(:local, s, :(::), Symbol(fc.return_type)) # TODO: figure out how to get type info for this local variable +function _gen_local_init(::NumaNode, symbol::Symbol, type::Type) + s = Symbol("data_$(symbol)") + quote_node = Expr(:local, s, :(::), Symbol(type)) return quote_node end diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl index 008f677..4a441c6 100644 --- a/src/scheduler/type.jl +++ b/src/scheduler/type.jl @@ -1,16 +1,34 @@ -using StaticArrays - """ - FunctionCall{N} + FunctionCall{VAL_TYPES} + +Type representing a function call. Contains the function to call, argument symbols, the return symbol and the device to execute on. -Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on. 
+TODO: extend docs """ -mutable struct FunctionCall{VectorType<:AbstractVector,N} +mutable struct FunctionCall{VAL_T<:Tuple,N_ARG,N_RET} func::Function - # TODO: this should be a tuple - value_arguments::SVector{N,Any} # value arguments for the function call, will be prepended to the other arguments - arguments::VectorType # symbols of the inputs to the function call - return_symbol::Symbol - return_type::Type + value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments + arguments::Vector{NTuple{N_ARG,Symbol}} # symbols of the inputs to the function call + return_symbols::Vector{NTuple{N_RET,Symbol}} # the return symbols + return_types::NTuple{N_RET,Type} # the return type of the function call(s); there can only be one return type since we require type stability device::AbstractDevice end + +function FunctionCall( + func::Function, + value_arguments::VAL_T, + arguments::NTuple{N_ARG,Symbol}, + return_symbol::NTuple{N_RET,Symbol}, + return_types::NTuple{N_RET,Type}, + device::AbstractDevice, +) where {VAL_T<:Tuple,N_ARG,N_RET} + # convenience constructor for function calls that do not use vectorization, which is most of the use cases + return FunctionCall( + func, [value_arguments], [arguments], [return_symbol], return_types, device + ) +end + +function Base.length(fc::FunctionCall) + @assert length(fc.value_arguments) == length(fc.arguments) == length(fc.return_symbols) "function call length is undefined, got $(length(fc.value_arguments)) tuples of value arguments, $(length(fc.arguments)) tuples of arguments, and $(length(return_symbols)) return symbols" + return length(fc.value_arguments) +end diff --git a/src/task/compute.jl b/src/task/compute.jl index 4355142..0e6ba4a 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -2,16 +2,17 @@ using StaticArrays """ get_function_call(n::Node) - get_function_call(t::AbstractTask, device::AbstractDevice, in_symbols::AbstractVector, out_symbol::Symbol) 
+ get_function_call(t::AbstractTask, device::AbstractDevice, in_symbols::NTuple{}, out_symbol::Symbol) -For a node or a task together with necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task. - -For ordinary compute or data tasks the vector will contain exactly one element. +For a node or a task together with necessary information, a [`FunctionCall`](@ref)s for the computation of the node or task. """ function get_function_call( - t::CompTask, device::AbstractDevice, in_symbols::AbstractVector, out_symbol::Symbol -) where {CompTask<:AbstractComputeTask} - return [FunctionCall(compute, SVector{1,Any}(t), in_symbols, out_symbol, Any, device)] + t::AbstractComputeTask, + device::AbstractDevice, + in_symbols::NTuple{N,Symbol}, + out_symbol::Symbol, +) where {N} + return FunctionCall(compute, (t,), in_symbols, (out_symbol,), (Any,), device) end function get_function_call(node::ComputeTaskNode) @@ -21,74 +22,93 @@ function get_function_call(node::ComputeTaskNode) # make sure the node is sorted so the arguments keep their order sort_node!(node) - if (length(node.children) <= 800) - #only use an SVector when there are few children - return get_function_call( - node.task, - node.device, - SVector{length(node.children),Symbol}( - Symbol.(to_var_name.(getfield.(getindex.(children(node), 1), :id)))... 
- ), - Symbol(to_var_name(node.id)), - ) - else - return get_function_call( - node.task, - node.device, - Symbol.(to_var_name.(getfield.(getindex.(children(node), 1), :id))), - Symbol(to_var_name(node.id)), - ) - end + return get_function_call( + node.task, + node.device, + (Symbol.(to_var_name.(getfield.(getindex.(children(node), 1), :id)))...,), + Symbol(to_var_name(node.id)), + ) end function get_function_call(node::DataTaskNode) @assert length(children(node)) == 1 "trying to call get_function_call on a data task node that has $(length(node.children)) children instead of 1\nchildren: $(node.children)" # TODO: dispatch to device implementations generating the copy commands - return [ - FunctionCall( - unpack_identity, - SVector{0,Any}(), - SVector{1,Symbol}(Symbol(to_var_name(first(children(node))[1].id))), - Symbol(to_var_name(node.id)), - Any, - first(children(node))[1].device, - ), - ] + return FunctionCall( + identity, + (), + (Symbol(to_var_name(first(children(node))[1].id)),), + (Symbol(to_var_name(node.id)),), + (Any,), + first(children(node))[1].device, + ) end function get_init_function_call(node::DataTaskNode, device::AbstractDevice) @assert isempty(children(node)) "trying to call get_init_function_call on a data task node that is not an entry node." return FunctionCall( - unpack_identity, - SVector{0,Any}(), - SVector{1,Symbol}(Symbol("$(to_var_name(node.id))_in")), - Symbol(to_var_name(node.id)), - Any, + identity, + (), + (Symbol("$(to_var_name(node.id))_in"),), + (Symbol(to_var_name(node.id)),), + (Any,), device, ) end -function result_type(fc::FunctionCall, known_res_types::Dict{Symbol,Type}) - argument_types = ( - typeof.(fc.value_arguments)..., getindex.(Ref(known_res_types), fc.arguments)... 
- ) - types = Base.return_types(fc.func, argument_types) +_value_argument_types(fc::FunctionCall) = typeof.(fc.value_arguments[1]) +function _argument_types(known_res_types::Dict{Symbol,Type}, fc::FunctionCall) + return getindex.(Ref(known_res_types), fc.arguments[1]) +end + +function result_types( + fc::FunctionCall{VAL_T,N_ARG,1}, known_res_types::Dict{Symbol,Type} +) where {VAL_T,N_ARG} + arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) + types = Base.return_types(fc.func, arg_types) if length(types) > 1 throw( - "failure during type inference: function call $fc with argument types $(argument_types) is type unstable, possible return types: $types", + "failure during type inference: function call $fc with argument types $(arg_types) is type unstable, possible return types: $types", ) end if isempty(types) throw( - "failure during type inference: function call $fc with argument types $(argument_types) has no return types, this is likely because no method matches the arguments", + "failure during type inference: function call $fc with argument types $(arg_types) has no return types, this is likely because no method matches the arguments", ) end if types[1] == Any - @warn "inferred return type 'Any' in task $fc with argument types $(argument_types)" + @warn "inferred return type 'Any' in task $fc with argument types $(arg_types)" + end + + return (types[1],) +end + +function result_types( + fc::FunctionCall{VAL_T,N_ARG,N_RET}, known_res_types::Dict{Symbol,Type} +) where {VAL_T,N_ARG,N_RET} + arg_types = (_value_argument_types(fc)..., _argument_types(fc)...) 
+ types = Base.return_types(fc.func, arg_types) + + if length(types) > 1 + throw( + "failure during type inference: function call $(fc.func) with argument types $(arg_types) is type unstable, possible return types: $types", + ) + end + if isempty(types) + throw( + "failure during type inference: function call $(fc.func) with argument types $(arg_types) has no return types, this is likely because no method matches the arguments", + ) + end + if types[1] == Any + @warn "inferred return type 'Any' in task $fc with argument types $(arg_types)" + end + if !(types[1] isa Tuple) || length(types[1].parameters) != N_RET + throw( + "failure durng type inference: function call $(fc.func) was expected to return a Tuple with $N_RET elements, but returns $(types[1])", + ) end - return types[1] + return (types[1].parameters...,) end diff --git a/src/utils.jl b/src/utils.jl index 862eda9..e3754f6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,14 +5,6 @@ Function with no arguments, returns nothing, does nothing. Useful for noop [`Fun """ @inline noop() = nothing -""" - unpack_identity(x::SVector) - -Function taking an `SVector`, returning it unpacked. -""" -@inline unpack_identity(x::SVector{1,<:Any}) = x[1] -@inline unpack_identity(x) = x - """ bytes_to_human_readable(bytes) @@ -118,17 +110,6 @@ end Return the given vector as single String without quotation marks or brackets. """ -function unroll_symbol_vector(vec::Vector) - result = "" - for s in vec - if (result != "") - result *= ", " - end - result *= "$s" - end - return result -end - -function unroll_symbol_vector(vec::SVector) - return unroll_symbol_vector(Vector(vec)) +function unroll_symbol_vector(vec::VEC) where {VEC<:Union{AbstractVector,Tuple}} + return Expr(:tuple, vec...) 
end diff --git a/test/strassen_test.jl b/test/strassen_test.jl index e67491e..e5df096 100644 --- a/test/strassen_test.jl +++ b/test/strassen_test.jl @@ -53,10 +53,12 @@ EDGE_NUMBERS = (3, 96, 747, 5304) #, 37203 @test isapprox(f(input), input[1] * input[2]) end + #= TODO: reenable when closures work again @testset "Execution with closures" begin f_closures = get_compute_function(g, mm, cpu_st(), @__MODULE__; closures_size=100) @test Base.return_types(f_closures, (typeof(input),))[1] == typeof(input[1]) @test isapprox(f_closures(input), input[1] * input[2]) end + =# end From edcdb31799ce8443ee8a5c950774dd0ea0dd75cc Mon Sep 17 00:00:00 2001 From: AntonReinhard Date: Thu, 14 Nov 2024 19:10:04 +0100 Subject: [PATCH 02/12] Make closures work again --- ext/devices/cuda/function.jl | 2 +- ext/devices/rocm/function.jl | 2 +- src/code_gen/function.jl | 6 +- src/code_gen/tape_machine.jl | 131 +++++++++++++++++++++++------------ src/code_gen/utils.jl | 4 +- src/devices/numa/impl.jl | 10 +-- src/task/compute.jl | 2 +- test/strassen_test.jl | 2 - 8 files changed, 99 insertions(+), 60 deletions(-) diff --git a/ext/devices/cuda/function.jl b/ext/devices/cuda/function.jl index e990d38..5b41056 100644 --- a/ext/devices/cuda/function.jl +++ b/ext/devices/cuda/function.jl @@ -20,7 +20,7 @@ function ComputableDAGs.kernel( if (id > n) return end - @inline data_input = input_vector[id] + @inline input = input_vector[id] $(assign_inputs) $code @inline output_vector[id] = $res_sym diff --git a/ext/devices/rocm/function.jl b/ext/devices/rocm/function.jl index 4b41723..7a2da6d 100644 --- a/ext/devices/rocm/function.jl +++ b/ext/devices/rocm/function.jl @@ -21,7 +21,7 @@ function ComputableDAGs.kernel( if (id > n) return end - @inline data_input = input_vector[id] + @inline input = input_vector[id] $(assign_inputs) $code @inline output_vector[id] = $res_sym diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index 412d265..16a4d3c 100644 --- a/src/code_gen/function.jl +++ 
b/src/code_gen/function.jl @@ -25,7 +25,7 @@ function get_compute_function( tape = gen_tape(graph, instance, machine, context_module) assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...) - code = gen_function_body(tape; closures_size=closures_size) + code = gen_function_body(tape, context_module; closures_size=closures_size) function_id = to_var_name(UUIDs.uuid1(rng[1])) res_sym = _gen_access_expr(entry_device(tape.machine), tape.output_symbol) @@ -33,9 +33,7 @@ function get_compute_function( Expr( :function, # function definition Expr( - :call, - Symbol("compute_$function_id"), - Expr(:(::), :data_input, input_type(instance)), + :call, Symbol("compute_$function_id"), Expr(:(::), :input, input_type(instance)) ), # function name and parameters Expr(:block, assign_inputs, code, Expr(:return, res_sym)), # function body ) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index be386b9..2e76b72 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -108,86 +108,127 @@ function gen_input_assignment_code( end """ - gen_function_body(tape::Tape; closures_size) + gen_function_body(tape::Tape, context_module::Module; closures_size) Generate the function body from the given [`Tape`](@ref). ## Keyword Arguments `closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely. 
""" -function gen_function_body(tape::Tape; closures_size::Int) - if closures_size > 1 - # only need to annotate types later when using closures - infer_types!(tape) - end +function gen_function_body(tape::Tape, context_module::Module; closures_size::Int) + # only need to annotate types later when using closures + types = infer_types!(tape) + + # TODO calculate closures size better - fc_vec = tape.schedule + return _gen_function_body( + tape.schedule, types, tape.machine, context_module; closures_size=closures_size + ) +end - if (closures_size <= 1) +function _gen_function_body( + fc_vec::AbstractVector{FunctionCall}, + type_dict::Dict{Symbol,Type}, + machine::Machine, + context_module::Module; + closures_size=0, +) + if closures_size <= 1 || closures_size >= length(fc_vec) return Expr(:block, expr_from_fc.(fc_vec)...) end - closures = Vector{Expr}() # iterate from end to beginning # this helps because we can collect all undefined arguments to the closures that have to be returned somewhere earlier undefined_argument_symbols = Set{Symbol}() # the final return symbol is the return of the entire generated function, it always has to be returned push!(undefined_argument_symbols, gen_access_expr(fc_vec[end])) + closured_fc_vec = FunctionCall[] for i in length(fc_vec):(-closures_size):1 e = i b = max(i - closures_size, 1) code_block = fc_vec[b:e] - # collect `local var` statements that need to exist before the closure starts - local_inits = gen_local_init.(code_block) + pushfirst!( + closured_fc_vec, + _closure_fc( + code_block, type_dict, machine, undefined_argument_symbols, context_module + ), + ) + end + + return _gen_function_body( + closured_fc_vec, type_dict, machine, context_module; closures_size=closures_size + ) +end + +""" + _closure_fc() - return_symbols = gen_access_expr.(code_block) +From the given function calls, make and return a new function call representing all of them together. 
+The undefined_argument_symbols is the set of all Symbols that need to be returned if available inside the code_block. They get updated inside this function. +""" +function _closure_fc( + code_block::AbstractVector{FunctionCall}, + types::Dict{Symbol,Type}, + machine::Machine, + undefined_argument_symbols::Set{Symbol}, + context_module::Module, +) + return_symbols = Symbol[] + for s in + Iterators.flatten(Iterators.flatten(getfield.(code_block, Ref(:return_symbols)))) + push!(return_symbols, s) + end - ret_symbols_set = Set(return_symbols) - for fc in code_block - for arg in fc.arguments - symbol = _gen_access_expr(fc.device, arg) + ret_symbols_set = Set(return_symbols) + arg_symbols_set = Set{Symbol}() + for fc in code_block + for symbol in Iterators.flatten(fc.arguments) + # symbol won't be defined if it is first calculated in the closure + # so don't add it to the arguments in this case + if !(symbol in ret_symbols_set) + push!(undefined_argument_symbols, symbol) - # symbol won't be defined if it is first calculated in the closure - # so don't add it to the arguments in this case - if !(symbol in ret_symbols_set) - push!(undefined_argument_symbols, symbol) - end + push!(arg_symbols_set, symbol) end end + end - intersect!(ret_symbols_set, undefined_argument_symbols) - return_symbols = Symbol[ret_symbols_set...] 
+ setdiff!(arg_symbols_set, ret_symbols_set) + intersect!(ret_symbols_set, undefined_argument_symbols) - closure = Expr( - :block, - Expr( - :(=), - Expr(:tuple, return_symbols...), + arg_symbols_t = (arg_symbols_set...,) + ret_symbols_t = (ret_symbols_set...,) + + closure = context_module.eval( + Expr( # create the closure: () -> code block; return (locals) + :->, + Expr(:tuple, arg_symbols_t...), # closure arguments + Expr( # actual function body of the closure + :block, + expr_from_fc.(code_block)..., Expr( - :call, # call to the following closure (no arguments) - Expr( # create the closure: () -> code block; return (locals) - :->, - :(), # closure arguments (none) - Expr( # actual function body of the closure - :block, - local_inits..., # declare local variables with type information inside the closure - expr_from_fc.(code_block)..., - Expr(:return, Expr(:tuple, return_symbols...)), - ), - ), + :return, # have to make sure to not return a tuple of length 1 + if length(ret_symbols_t) == 1 + ret_symbols_t[1] + else + Expr(:tuple, ret_symbols_t...) + end, ), ), - ) + ), + ) - setdiff!(undefined_argument_symbols, ret_symbols_set) + ret_types = (getindex.(Ref(types), ret_symbols_t)) - # combine to one closure call, including all the local inits and the actual call to the closure - pushfirst!(closures, closure) - end + fc = FunctionCall( + closure, (), arg_symbols_t, ret_symbols_t, ret_types, entry_device(machine) + ) + + setdiff!(undefined_argument_symbols, ret_symbols_set) - return Expr(:block, closures...) + return fc end """ diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl index ddcd2ff..d9b4750 100644 --- a/src/code_gen/utils.jl +++ b/src/code_gen/utils.jl @@ -3,6 +3,8 @@ Infer the result type of each function call in the given schedule. Returns a dictionary with the result type for each [`Node`](@ref). This assumes that each node has only one statically inferrable return type and will throw an exceptin otherwise. 
This also assumes that the given `Vector` contains a topological ordering of its nodes, such as returned by a call to [`schedule_dag`](@ref). + +Also returns the inferred types as a `Dict{Symbol, Type}`. """ function infer_types!(tape::Tape) known_result_types = Dict{Symbol,Type}() @@ -32,7 +34,7 @@ function infer_types!(tape::Tape) end end - return nothing + return known_result_types end """ diff --git a/src/devices/numa/impl.jl b/src/devices/numa/impl.jl index e8b9ad9..5cb80a3 100644 --- a/src/devices/numa/impl.jl +++ b/src/devices/numa/impl.jl @@ -46,9 +46,9 @@ Interface implementation, dispatched to from [`gen_access_expr`](@ref). """ function _gen_access_expr(::NumaNode, symbol::Symbol) # TODO rewrite these with Expr instead of quote node - s = Symbol("data_$symbol") - quote_node = Meta.parse(":($s)") - return eval(quote_node) + #=s = Symbol("data_$symbol") + quote_node = Meta.parse(":($s)")=# + return symbol end """ @@ -57,7 +57,7 @@ end Interface implementation, dispatched to from [`gen_local_init`](@ref). """ function _gen_local_init(::NumaNode, symbol::Symbol, type::Type) - s = Symbol("data_$(symbol)") - quote_node = Expr(:local, s, :(::), Symbol(type)) + #s = Symbol("data_$(symbol)") + quote_node = Expr(:local, symbol, :(::), Symbol(type)) return quote_node end diff --git a/src/task/compute.jl b/src/task/compute.jl index 0e6ba4a..39c2092 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -88,7 +88,7 @@ end function result_types( fc::FunctionCall{VAL_T,N_ARG,N_RET}, known_res_types::Dict{Symbol,Type} ) where {VAL_T,N_ARG,N_RET} - arg_types = (_value_argument_types(fc)..., _argument_types(fc)...) + arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) 
types = Base.return_types(fc.func, arg_types) if length(types) > 1 diff --git a/test/strassen_test.jl b/test/strassen_test.jl index e5df096..e67491e 100644 --- a/test/strassen_test.jl +++ b/test/strassen_test.jl @@ -53,12 +53,10 @@ EDGE_NUMBERS = (3, 96, 747, 5304) #, 37203 @test isapprox(f(input), input[1] * input[2]) end - #= TODO: reenable when closures work again @testset "Execution with closures" begin f_closures = get_compute_function(g, mm, cpu_st(), @__MODULE__; closures_size=100) @test Base.return_types(f_closures, (typeof(input),))[1] == typeof(input[1]) @test isapprox(f_closures(input), input[1] * input[2]) end - =# end From 03cdf5871a6533d650d2b420e39058eeb0a2b4e7 Mon Sep 17 00:00:00 2001 From: AntonReinhard Date: Sun, 17 Nov 2024 03:45:21 +0100 Subject: [PATCH 03/12] Improve scheduler --- src/code_gen/tape_machine.jl | 16 +++++++++++----- src/scheduler/greedy.jl | 26 ++++++++++++++++---------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index 2e76b72..f1391b1 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -55,10 +55,8 @@ function expr_from_fc(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_ func_call = Expr( :call, fc.func, - ( - fc.value_arguments[1]..., - _gen_access_expr.(Ref(fc.device), fc.arguments[1])..., - )..., + fc.value_arguments[1]..., + _gen_access_expr.(Ref(fc.device), fc.arguments[1])..., ) else # TBW; dispatch to device specific vectorization @@ -119,7 +117,13 @@ function gen_function_body(tape::Tape, context_module::Module; closures_size::In # only need to annotate types later when using closures types = infer_types!(tape) - # TODO calculate closures size better + if closures_size >= 1 + s = log(closures_size, length(tape.schedule)) + closures_depth = ceil(Int, s) # tend towards more levels/smaller closures + closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth)) + end + + @info "generating function 
body with closure size $closures_size" return _gen_function_body( tape.schedule, types, tape.machine, context_module; closures_size=closures_size @@ -133,6 +137,7 @@ function _gen_function_body( context_module::Module; closures_size=0, ) + @info "generating function body from $(length(fc_vec)) function calls with closure size $closures_size" if closures_size <= 1 || closures_size >= length(fc_vec) return Expr(:block, expr_from_fc.(fc_vec)...) end @@ -251,6 +256,7 @@ function gen_tape( context_module::Module, scheduler::AbstractScheduler=GreedyScheduler(), ) + @debug "generating tape" schedule = schedule_dag(scheduler, graph, machine) function_body = lower(schedule, machine) diff --git a/src/scheduler/greedy.jl b/src/scheduler/greedy.jl index 63795f1..61077b1 100644 --- a/src/scheduler/greedy.jl +++ b/src/scheduler/greedy.jl @@ -7,14 +7,16 @@ A greedy implementation of a scheduler, creating a topological ordering of nodes struct GreedyScheduler <: AbstractScheduler end function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine) - node_queue = PriorityQueue{Node,Int}() + node_dict = Dict{Node,Int}() # dictionary of nodes with the number of not-yet-scheduled children + node_stack = Stack{Node}() # stack of currently schedulable nodes, i.e., nodes with all of their children already scheduled + # the stack makes sure that closely related nodes will be scheduled one after another # use a priority equal to the number of unseen children -> 0 are nodes that can be added for node in get_entry_nodes(graph) - enqueue!(node_queue, node => 0) + push!(node_stack, node) end - schedule = Vector{Node}() + schedule = Node[] sizehint!(schedule, length(graph.nodes)) # keep an accumulated cost of things scheduled to this device so far @@ -24,9 +26,8 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine) end local node - while !isempty(node_queue) - @assert peek(node_queue)[2] == 0 - node = dequeue!(node_queue) + while !isempty(node_stack) + node = 
pop!(node_stack) # assign the device with lowest accumulated cost to the node (if it's a compute node) if (isa(node, ComputeTaskNode)) @@ -37,15 +38,20 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine) push!(schedule, node) + # find all parent's priority, reduce by one if in the node_dict + # if it reaches zero, push onto node_stack for parent in parents(node) - # reduce the priority of all parents by one - if (!haskey(node_queue, parent)) - enqueue!(node_queue, parent => length(children(parent)) - 1) + parents_prio = get(node_dict, parent, length(children(parent))) - 1 + if parents_prio == 0 + delete!(node_dict, parent) + push!(node_stack, parent) else - node_queue[parent] = node_queue[parent] - 1 + node_dict[parent] = parents_prio end end end + @assert isempty(node_dict) "found unschedulable nodes, this most likely means the graph has a cycle" + return schedule end From 0171a8768debebf12120d5a2b1204fde829cb429 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Mon, 18 Nov 2024 17:21:31 +0100 Subject: [PATCH 04/12] Change Function Calls to mostly use Vectors instead of Tuples This makes small function calls slower, but large function calls much much faster to compile --- src/code_gen/tape_machine.jl | 14 ++++++----- src/devices/impl.jl | 24 ++++++++++-------- src/scheduler/type.jl | 19 +++++++------- src/task/compute.jl | 49 ++++++++++++------------------------ 4 files changed, 47 insertions(+), 59 deletions(-) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index f1391b1..94b9f54 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -1,3 +1,4 @@ +#= # TODO: do this with macros function call_fc( fc::FunctionCall{VectorT,0}, cache::Dict{Symbol,Any} @@ -49,8 +50,9 @@ function call_fc(fc::FunctionCall{VectorT,M}, cache::Dict{Symbol,Any}) where {Ve ) return nothing end +=# -function expr_from_fc(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_RET} +function 
expr_from_fc(fc::FunctionCall{VAL_T}) where {VAL_T} if length(fc) == 1 func_call = Expr( :call, @@ -92,9 +94,9 @@ function gen_input_assignment_code( fc = FunctionCall( context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))), (), - (:input,), - (symbol,), - (Nothing,), + [:input], + [symbol], + [Nothing], device, ) @@ -203,8 +205,8 @@ function _closure_fc( setdiff!(arg_symbols_set, ret_symbols_set) intersect!(ret_symbols_set, undefined_argument_symbols) - arg_symbols_t = (arg_symbols_set...,) - ret_symbols_t = (ret_symbols_set...,) + arg_symbols_t = [arg_symbols_set...] + ret_symbols_t = [ret_symbols_set...] closure = context_module.eval( Expr( # create the closure: () -> code block; return (locals) diff --git a/src/devices/impl.jl b/src/devices/impl.jl index 1c31f46..35236c2 100644 --- a/src/devices/impl.jl +++ b/src/devices/impl.jl @@ -33,19 +33,21 @@ end Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_access_expr`](@ref). """ -function gen_access_expr(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_RET} - vec = Expr[] - for ret_symbols in fc.return_symbols - push!(vec, unroll_symbol_vector(_gen_access_expr.(Ref(fc.device), ret_symbols))) +function gen_access_expr(fc::FunctionCall{VAL_T}) where {VAL_T} + if length(fc.return_types) != 1 + # general case + vec = Expr[] + for ret_symbols in fc.return_symbols + push!(vec, unroll_symbol_vector(_gen_access_expr.(Ref(fc.device), ret_symbols))) + end + if length(vec) > 1 + return unroll_symbol_vector(vec) + else + return vec[1] + end end - if length(vec) > 1 - return unroll_symbol_vector(vec) - else - return vec[1] - end -end -function gen_access_expr(fc::FunctionCall{VAL_T,N_ARG,1}) where {VAL_T,N_ARG} + # no vectorization case vec = Symbol[] for ret_symbols in fc.return_symbols push!(vec, _gen_access_expr.(Ref(fc.device), ret_symbols[1])) diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl index 4a441c6..03fc918 100644 --- a/src/scheduler/type.jl +++ 
b/src/scheduler/type.jl @@ -5,24 +5,25 @@ Type representing a function call. Contains the function to call, argument symbo TODO: extend docs """ -mutable struct FunctionCall{VAL_T<:Tuple,N_ARG,N_RET} +mutable struct FunctionCall{VAL_T<:Tuple} func::Function - value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments - arguments::Vector{NTuple{N_ARG,Symbol}} # symbols of the inputs to the function call - return_symbols::Vector{NTuple{N_RET,Symbol}} # the return symbols - return_types::NTuple{N_RET,Type} # the return type of the function call(s); there can only be one return type since we require type stability + value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments + arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call + return_symbols::Vector{Vector{Symbol}} # the return symbols + return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability device::AbstractDevice end function FunctionCall( func::Function, value_arguments::VAL_T, - arguments::NTuple{N_ARG,Symbol}, - return_symbol::NTuple{N_RET,Symbol}, - return_types::NTuple{N_RET,Type}, + arguments::Vector{Symbol}, + return_symbol::Vector{Symbol}, + return_types::Vector{<:Type}, device::AbstractDevice, -) where {VAL_T<:Tuple,N_ARG,N_RET} +) where {VAL_T<:Tuple} # convenience constructor for function calls that do not use vectorization, which is most of the use cases + @assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types $(length(return_types)) does not match the number of return symbols $(length(return_symbol))" return FunctionCall( func, [value_arguments], [arguments], [return_symbol], return_types, device ) diff --git a/src/task/compute.jl b/src/task/compute.jl index 39c2092..07559d9 100644 --- a/src/task/compute.jl +++ 
b/src/task/compute.jl @@ -12,7 +12,7 @@ function get_function_call( in_symbols::NTuple{N,Symbol}, out_symbol::Symbol, ) where {N} - return FunctionCall(compute, (t,), in_symbols, (out_symbol,), (Any,), device) + return FunctionCall(compute, (t,), [in_symbols...], [out_symbol], [Any], device) end function get_function_call(node::ComputeTaskNode) @@ -37,9 +37,9 @@ function get_function_call(node::DataTaskNode) return FunctionCall( identity, (), - (Symbol(to_var_name(first(children(node))[1].id)),), - (Symbol(to_var_name(node.id)),), - (Any,), + [Symbol(to_var_name(first(children(node))[1].id))], + [Symbol(to_var_name(node.id))], + [Any], first(children(node))[1].device, ) end @@ -50,9 +50,9 @@ function get_init_function_call(node::DataTaskNode, device::AbstractDevice) return FunctionCall( identity, (), - (Symbol("$(to_var_name(node.id))_in"),), - (Symbol(to_var_name(node.id)),), - (Any,), + [Symbol("$(to_var_name(node.id))_in")], + [Symbol(to_var_name(node.id))], + [Any], device, ) end @@ -63,33 +63,12 @@ function _argument_types(known_res_types::Dict{Symbol,Type}, fc::FunctionCall) end function result_types( - fc::FunctionCall{VAL_T,N_ARG,1}, known_res_types::Dict{Symbol,Type} -) where {VAL_T,N_ARG} + fc::FunctionCall{VAL_T}, known_res_types::Dict{Symbol,Type} +) where {VAL_T} arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) 
types = Base.return_types(fc.func, arg_types) - if length(types) > 1 - throw( - "failure during type inference: function call $fc with argument types $(arg_types) is type unstable, possible return types: $types", - ) - end - if isempty(types) - throw( - "failure during type inference: function call $fc with argument types $(arg_types) has no return types, this is likely because no method matches the arguments", - ) - end - if types[1] == Any - @warn "inferred return type 'Any' in task $fc with argument types $(arg_types)" - end - - return (types[1],) -end - -function result_types( - fc::FunctionCall{VAL_T,N_ARG,N_RET}, known_res_types::Dict{Symbol,Type} -) where {VAL_T,N_ARG,N_RET} - arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) - types = Base.return_types(fc.func, arg_types) + N_RET = length(fc.return_types) if length(types) > 1 throw( @@ -104,11 +83,15 @@ function result_types( if types[1] == Any @warn "inferred return type 'Any' in task $fc with argument types $(arg_types)" end + + if (N_RET == 1) + return [types[1]] + end + if !(types[1] isa Tuple) || length(types[1].parameters) != N_RET throw( "failure durng type inference: function call $(fc.func) was expected to return a Tuple with $N_RET elements, but returns $(types[1])", ) end - - return (types[1].parameters...,) + return [types[1].parameters...] 
end From 4003edc9730140261eeb21237804a25b04c1b3fa Mon Sep 17 00:00:00 2001 From: AntonReinhard Date: Tue, 26 Nov 2024 23:48:30 +0100 Subject: [PATCH 05/12] WIP fix input assignment code world age problem --- src/code_gen/tape_machine.jl | 48 +++++++++++++++++++++++++----------- src/code_gen/utils.jl | 7 ++++-- src/scheduler/type.jl | 8 +++--- src/task/compute.jl | 12 +++++++-- src/utils.jl | 4 +++ 5 files changed, 57 insertions(+), 22 deletions(-) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index 94b9f54..1065714 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -52,7 +52,7 @@ function call_fc(fc::FunctionCall{VectorT,M}, cache::Dict{Symbol,Any}) where {Ve end =# -function expr_from_fc(fc::FunctionCall{VAL_T}) where {VAL_T} +function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function} if length(fc) == 1 func_call = Expr( :call, @@ -68,11 +68,17 @@ function expr_from_fc(fc::FunctionCall{VAL_T}) where {VAL_T} return Expr(:(=), access_expr, func_call) end +function expr_from_fc(fc::FunctionCall{VAL_T,Expr}) where {VAL_T} + @assert length(fc) == 1 && isempty(fc.arguments[1]) && isempty(fc.value_arguments[1]) "function call assigning an expression has an unallowed combination of arguments, which is not allowed\n$fc" + return Expr(:(=), gen_access_expr(fc), fc.func) +end + """ gen_input_assignment_code( input_symbols::Dict{String, Vector{Symbol}}, instance::AbstractProblemInstance, machine::Machine, + input_type::Type, context_module::Module ) @@ -82,24 +88,37 @@ function gen_input_assignment_code( input_symbols::Dict{String,Vector{Symbol}}, instance, machine::Machine, + input_type::Type, context_module::Module, ) assign_inputs = Vector{FunctionCall}() for (name, symbols) in input_symbols - # make a function for this, since we can't use anonymous functions in the FunctionCall - for symbol in symbols device = entry_device(machine) - fc = FunctionCall( - context_module.eval(Expr(:->, :x, 
input_expr(instance, name, :x))), + f_id = Symbol(to_var_name(UUIDs.uuid1(rng[threadid()]))) + + fc_setup = FunctionCall( + Expr(:->, :x, input_expr(instance, name, :x)), (), - [:input], - [symbol], - [Nothing], + Symbol[], + Symbol[f_id], + Type[Nothing], device, ) + fc = FunctionCall( + _call, (), Symbol[f_id, :input], Symbol[symbol], Type[Nothing], device + ) + + ret_expr = Expr( + :call, Base.return_types, fc_setup.func, Expr(:tuple, input_type) + ) + ret_type = context_module.eval(ret_expr) + @assert length(ret_type) == 1 + fc.return_types = [ret_type[1]] + + push!(assign_inputs, fc_setup) push!(assign_inputs, fc) end end @@ -125,7 +144,7 @@ function gen_function_body(tape::Tape, context_module::Module; closures_size::In closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth)) end - @info "generating function body with closure size $closures_size" + @debug "generating function body with closure size $closures_size" return _gen_function_body( tape.schedule, types, tape.machine, context_module; closures_size=closures_size @@ -139,7 +158,7 @@ function _gen_function_body( context_module::Module; closures_size=0, ) - @info "generating function body from $(length(fc_vec)) function calls with closure size $closures_size" + @debug "generating function body from $(length(fc_vec)) function calls with closure size $closures_size" if closures_size <= 1 || closures_size >= length(fc_vec) return Expr(:block, expr_from_fc.(fc_vec)...) 
end @@ -275,11 +294,12 @@ function gen_tape( # get outSymbol outSym = Symbol(to_var_name(get_exit_node(graph).id)) - assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module) - - return Tape{input_type(instance)}( - assign_inputs, function_body, outSym, instance, machine + INPUT_T = input_type(instance) + assign_inputs = gen_input_assignment_code( + input_syms, instance, machine, INPUT_T, context_module ) + + return Tape{INPUT_T}(assign_inputs, function_body, outSym, instance, machine) end """ diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl index d9b4750..19eb56c 100644 --- a/src/code_gen/utils.jl +++ b/src/code_gen/utils.jl @@ -13,8 +13,11 @@ function infer_types!(tape::Tape) known_result_types[:input] = input_type(tape.instance) for fc in tape.input_assign_code - res_types = result_types(fc, known_result_types) - fc.return_types = res_types + if typeof(fc.func) isa Expr + continue + end + # for input assign code, the return types are set on construction + res_types = fc.return_types for (s, t) in Iterators.zip( Iterators.flatten(fc.return_symbols), Iterators.cycle(res_types, length(fc.return_symbols)), diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl index 03fc918..28452a8 100644 --- a/src/scheduler/type.jl +++ b/src/scheduler/type.jl @@ -5,17 +5,17 @@ Type representing a function call. 
Contains the function to call, argument symbo TODO: extend docs """ -mutable struct FunctionCall{VAL_T<:Tuple} - func::Function +mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}} + func::FUNC_T value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call return_symbols::Vector{Vector{Symbol}} # the return symbols - return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability + return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability device::AbstractDevice end function FunctionCall( - func::Function, + func::Union{Function,Expr}, value_arguments::VAL_T, arguments::Vector{Symbol}, return_symbol::Vector{Symbol}, diff --git a/src/task/compute.jl b/src/task/compute.jl index 07559d9..c2e9fdd 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -63,8 +63,8 @@ function _argument_types(known_res_types::Dict{Symbol,Type}, fc::FunctionCall) end function result_types( - fc::FunctionCall{VAL_T}, known_res_types::Dict{Symbol,Type} -) where {VAL_T} + fc::FunctionCall{VAL_T,F_T}, known_res_types::Dict{Symbol,Type} +) where {VAL_T,F_T<:Function} arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) types = Base.return_types(fc.func, arg_types) @@ -95,3 +95,11 @@ function result_types( end return [types[1].parameters...] 
end + +function result_types( + fc::FunctionCall{VAL_T,Expr}, known_res_types::Dict{Symbol,Type} +) where {VAL_T} + # assume that the return type is already set + @assert length(fc.return_types) == 1 + return [fc.return_types[1]] +end diff --git a/src/utils.jl b/src/utils.jl index e3754f6..d320ea5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -113,3 +113,7 @@ Return the given vector as single String without quotation marks or brackets. function unroll_symbol_vector(vec::VEC) where {VEC<:Union{AbstractVector,Tuple}} return Expr(:tuple, vec...) end + +@inline function _call(f, args::Vararg) + return f(args...) +end From 428ce7b0f2ee56847624699efab3b51750adef40 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Wed, 27 Nov 2024 12:59:54 +0100 Subject: [PATCH 06/12] Remove trie workaround --- src/trie.jl | 32 ++------------------------------ 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/src/trie.jl b/src/trie.jl index 23cf7ee..e949124 100644 --- a/src/trie.jl +++ b/src/trie.jl @@ -46,24 +46,7 @@ Insert the given node into the trie. The depth is used to iterate through the tr """ function insert_helper!( trie::NodeIdTrie{NodeType}, node::NodeType, depth::Int -) where {TaskType<:AbstractDataTask,NodeType<:DataTaskNode{TaskType}} - if (length(children(node)) == depth) - push!(trie.value, node) - return nothing - end - - depth = depth + 1 - id = node.children[depth][1].id - - if (!haskey(trie.children, id)) - trie.children[id] = NodeIdTrie{NodeType}() - end - return insert_helper!(trie.children[id], node, depth) -end -# TODO: Remove this workaround once https://github.com/JuliaLang/julia/issues/54404 is fixed in julia 1.10+ -function insert_helper!( - trie::NodeIdTrie{NodeType}, node::NodeType, depth::Int -) where {TaskType<:AbstractComputeTask,NodeType<:ComputeTaskNode{TaskType}} +) where {NodeType<:Node} if (length(children(node)) == depth) push!(trie.value, node) return nothing @@ -83,18 +66,7 @@ end Insert the given node into the trie. 
It's sorted by its type in the first layer, then by its children in the following layers. """ -function Base.insert!( - trie::NodeTrie, node::NodeType -) where {TaskType<:AbstractDataTask,NodeType<:DataTaskNode{TaskType}} - if (!haskey(trie.children, NodeType)) - trie.children[NodeType] = NodeIdTrie{NodeType}() - end - return insert_helper!(trie.children[NodeType], node, 0) -end -# TODO: Remove this workaround once https://github.com/JuliaLang/julia/issues/54404 is fixed in julia 1.10+ -function Base.insert!( - trie::NodeTrie, node::NodeType -) where {TaskType<:AbstractComputeTask,NodeType<:ComputeTaskNode{TaskType}} +function Base.insert!(trie::NodeTrie, node::NodeType) where {NodeType<:Node} if (!haskey(trie.children, NodeType)) trie.children[NodeType] = NodeIdTrie{NodeType}() end From e76461fddfc44b7db2b4eaee554fb565f9600442 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Wed, 27 Nov 2024 13:10:20 +0100 Subject: [PATCH 07/12] Remove call_fc and execute functions --- src/ComputableDAGs.jl | 3 +- src/code_gen/function.jl | 22 ---------- src/code_gen/tape_machine.jl | 82 ------------------------------------ 3 files changed, 1 insertion(+), 106 deletions(-) diff --git a/src/ComputableDAGs.jl b/src/ComputableDAGs.jl index 095b8ca..63bc257 100644 --- a/src/ComputableDAGs.jl +++ b/src/ComputableDAGs.jl @@ -34,9 +34,8 @@ export reset_graph! 
export get_operations # code generation related -export execute export get_compute_function -export gen_tape, execute_tape +export gen_tape # estimator export cost_type, graph_cost, operation_effect diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index 16a4d3c..e6a6c53 100644 --- a/src/code_gen/function.jl +++ b/src/code_gen/function.jl @@ -40,25 +40,3 @@ function get_compute_function( return RuntimeGeneratedFunction(@__MODULE__, context_module, expr) end - -""" - execute( - graph::DAG, - instance, - machine::Machine, - input, - context_module::Module - ) - -Execute the code of the given `graph` on the given input values. - -This is essentially shorthand for -```julia -tape = gen_tape(graph, instance, machine, context_module) -return execute_tape(tape, input) -``` -""" -function execute(graph::DAG, instance, machine::Machine, input, context_module::Module) - tape = gen_tape(graph, instance, machine, context_module) - return execute_tape(tape, input) -end diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index 1065714..f70ee61 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -1,57 +1,3 @@ -#= -# TODO: do this with macros -function call_fc( - fc::FunctionCall{VectorT,0}, cache::Dict{Symbol,Any} -) where {VectorT<:SVector{1}} - cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]]) - return nothing -end - -function call_fc( - fc::FunctionCall{VectorT,1}, cache::Dict{Symbol,Any} -) where {VectorT<:SVector{1}} - cache[fc.return_symbol] = fc.func(fc.value_arguments[1], cache[fc.arguments[1]]) - return nothing -end - -function call_fc( - fc::FunctionCall{VectorT,0}, cache::Dict{Symbol,Any} -) where {VectorT<:SVector{2}} - cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]], cache[fc.arguments[2]]) - return nothing -end - -function call_fc( - fc::FunctionCall{VectorT,1}, cache::Dict{Symbol,Any} -) where {VectorT<:SVector{2}} - cache[fc.return_symbol] = fc.func( - fc.value_arguments[1], 
cache[fc.arguments[1]], cache[fc.arguments[2]] - ) - return nothing -end - -function call_fc(fc::FunctionCall{VectorT,1}, cache::Dict{Symbol,Any}) where {VectorT} - cache[fc.return_symbol] = fc.func( - fc.value_arguments[1], getindex.(Ref(cache), fc.arguments)... - ) - return nothing -end - -""" - call_fc(fc::FunctionCall, cache::Dict{Symbol, Any}) - -Execute the given [`FunctionCall`](@ref) on the dictionary. - -Several more specialized versions of this function exist to reduce vector unrolling work for common cases. -""" -function call_fc(fc::FunctionCall{VectorT,M}, cache::Dict{Symbol,Any}) where {VectorT,M} - cache[fc.return_symbol] = fc.func( - fc.value_arguments..., getindex.(Ref(cache), fc.arguments)... - ) - return nothing -end -=# - function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function} if length(fc) == 1 func_call = Expr( @@ -267,8 +213,6 @@ end ) Generate the code for a given graph. The return value is a [`Tape`](@ref). - -See also: [`execute`](@ref), [`execute_tape`](@ref) """ function gen_tape( graph::DAG, @@ -301,29 +245,3 @@ function gen_tape( return Tape{INPUT_T}(assign_inputs, function_body, outSym, instance, machine) end - -""" - execute_tape(tape::Tape, input::Input) where {Input} - -Execute the given tape with the given input. - -!!! warning - This is very slow and might not work. This is to be majorly revamped. 
-""" -function execute_tape(tape::Tape, input) - cache = Dict{Symbol,Any}() - cache[:input] = input - # simply execute all the code snippets here - @assert typeof(input) <: input_type(tape.instance) "expected tape input type to fit $(input_type(tape.instance)) but got $(typeof(input))" - - compute_code = tape.schedule - - for function_call in tape.input_assign_code - call_fc(function_call, cache) - end - for function_call in compute_code - call_fc(function_call, cache) - end - - return cache[tape.output_symbol] -end From 68a79454e06ae3140a1f3cf37b0f9f9311b422e3 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Wed, 27 Nov 2024 13:36:06 +0100 Subject: [PATCH 08/12] Remove superfluous _gen_access_expr and simplify _gen_local_init --- ext/devices/cuda/function.jl | 7 +------ ext/devices/rocm/function.jl | 7 +------ src/code_gen/function.jl | 7 +++++-- src/code_gen/tape_machine.jl | 7 +------ src/devices/impl.jl | 27 +++++++++++++++++++++------ src/devices/interface.jl | 17 ----------------- src/devices/numa/impl.jl | 23 ----------------------- 7 files changed, 29 insertions(+), 66 deletions(-) diff --git a/ext/devices/cuda/function.jl b/ext/devices/cuda/function.jl index 5b41056..4c9caf8 100644 --- a/ext/devices/cuda/function.jl +++ b/ext/devices/cuda/function.jl @@ -9,11 +9,6 @@ function ComputableDAGs.kernel( code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...) 
function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) - res_sym = eval( - ComputableDAGs._gen_access_expr( - ComputableDAGs.entry_device(tape.machine), tape.output_symbol - ), - ) expr = Meta.parse( "function compute_$(function_id)(input_vector, output_vector, n::Int64) id = (blockIdx().x - 1) * blockDim().x + threadIdx().x @@ -23,7 +18,7 @@ function ComputableDAGs.kernel( @inline input = input_vector[id] $(assign_inputs) $code - @inline output_vector[id] = $res_sym + @inline output_vector[id] = $(tape.output_symbol) return nothing end" ) diff --git a/ext/devices/rocm/function.jl b/ext/devices/rocm/function.jl index 7a2da6d..1567ce3 100644 --- a/ext/devices/rocm/function.jl +++ b/ext/devices/rocm/function.jl @@ -10,11 +10,6 @@ function ComputableDAGs.kernel( code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...) function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) - res_sym = eval( - ComputableDAGs._gen_access_expr( - ComputableDAGs.entry_device(tape.machine), tape.output_symbol - ), - ) expr = Meta.parse( "function compute_$(function_id)(input_vector, output_vector, n::Int64) id = (workgroupIdx().x - 1) * workgroupDim().x + workgroupIdx().x @@ -24,7 +19,7 @@ function ComputableDAGs.kernel( @inline input = input_vector[id] $(assign_inputs) $code - @inline output_vector[id] = $res_sym + @inline output_vector[id] = $(tape.output_symbol) return nothing end" ) diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index e6a6c53..c73c99c 100644 --- a/src/code_gen/function.jl +++ b/src/code_gen/function.jl @@ -17,7 +17,10 @@ in your top level. ## Keyword Arguments -`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time. 
+`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the + compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time. + **Note** that the closure size actually used might differ from the one passed here, since the function automatically chooses a size that + is close to an n-th root of the total number of function calls in the schedule, based on the given size. """ function get_compute_function( graph::DAG, instance, machine::Machine, context_module::Module; closures_size=0 @@ -28,7 +31,7 @@ function get_compute_function( code = gen_function_body(tape, context_module; closures_size=closures_size) function_id = to_var_name(UUIDs.uuid1(rng[1])) - res_sym = _gen_access_expr(entry_device(tape.machine), tape.output_symbol) + res_sym = tape.output_symbol expr = # Expr( :function, # function definition diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index f70ee61..dc7c795 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -1,11 +1,6 @@ function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function} if length(fc) == 1 - func_call = Expr( - :call, - fc.func, - fc.value_arguments[1]..., - _gen_access_expr.(Ref(fc.device), fc.arguments[1])..., - ) + func_call = Expr(:call, fc.func, fc.value_arguments[1]..., fc.arguments[1]...) else # TBW; dispatch to device specific vectorization throw("unimplemented") diff --git a/src/devices/impl.jl b/src/devices/impl.jl index 35236c2..172160f 100644 --- a/src/devices/impl.jl +++ b/src/devices/impl.jl @@ -31,14 +31,14 @@ end """ gen_access_expr(fc::FunctionCall) -Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_access_expr`](@ref).
+Return the expression that the result of the given [`FunctionCall`](@ref) is assigned to: the return symbol itself, or a tuple expression of the return symbols if there are several. """ function gen_access_expr(fc::FunctionCall{VAL_T}) where {VAL_T} if length(fc.return_types) != 1 # general case vec = Expr[] for ret_symbols in fc.return_symbols - push!(vec, unroll_symbol_vector(_gen_access_expr.(Ref(fc.device), ret_symbols))) + push!(vec, unroll_symbol_vector(ret_symbols)) end if length(vec) > 1 return unroll_symbol_vector(vec) @@ -47,10 +47,10 @@ function gen_access_expr(fc::FunctionCall{VAL_T}) where {VAL_T} end end - # no vectorization case + # single return value per function vec = Symbol[] for ret_symbols in fc.return_symbols - push!(vec, _gen_access_expr.(Ref(fc.device), ret_symbols[1])) + push!(vec, ret_symbols[1]) end if length(vec) > 1 return unroll_symbol_vector(vec) @@ -62,15 +62,30 @@ end """ gen_local_init(fc::FunctionCall) -Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_local_init`](@ref). +Dispatch from the given [`FunctionCall`](@ref) to the lower-level function [`_gen_local_init`](@ref). + +!!! note + This is currently unused, but may become useful in the future again. """ function gen_local_init(fc::FunctionCall) return Expr( :block, _gen_local_init.( - Ref(fc.device), Iterators.flatten(fc.return_symbols), Iterators.cycle(fc.return_types, length(fc.return_symbols)), )..., ) end + +""" + _gen_local_init(symbol::Symbol, type::Type) + +Return an `Expr` that initializes the symbol in the local scope. +The result looks like `local <symbol>::<type>`. + +!!! note + This is currently unused, but may become useful in the future again. +""" +function _gen_local_init(symbol::Symbol, type::Type) + return Expr(:local, symbol, :(::), Symbol(type)) +end diff --git a/src/devices/interface.jl b/src/devices/interface.jl index 95acd28..04ab6e0 100644 --- a/src/devices/interface.jl +++ b/src/devices/interface.jl @@ -46,23 +46,6 @@ Interface function that must be implemented for every subtype of [`AbstractDevic """ function measure_device!
end -""" - _gen_access_expr(device::AbstractDevice, symbol::Symbol) - -Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref). -Return an `Expr` or `QuoteNode` accessing the variable identified by [`symbol`]. -""" -function _gen_access_expr end - -""" - _gen_local_init(device::AbstractDevice, symbol::Symbol, type::Type) - -Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref). -Return an `Expr` or `QuoteNode` that initializes the access expression returned by [`_gen_access_expr`](@ref) in the local scope. -This expression may be empty. For local variables it should be `local <symbol>::<type>`. -""" -function _gen_local_init end - """ kernel(gpu_type::Type{<:AbstractGPU}, graph::DAG, instance) diff --git a/src/devices/numa/impl.jl b/src/devices/numa/impl.jl index 5cb80a3..9e65828 100644 --- a/src/devices/numa/impl.jl +++ b/src/devices/numa/impl.jl @@ -38,26 +38,3 @@ function get_devices(deviceType::Type{T}; verbose::Bool=false) where {T<:NumaNod return devices end - -""" - _gen_access_expr(device::NumaNode, symbol::Symbol) - -Interface implementation, dispatched to from [`gen_access_expr`](@ref). -""" -function _gen_access_expr(::NumaNode, symbol::Symbol) - # TODO rewrite these with Expr instead of quote node - #=s = Symbol("data_$symbol") - quote_node = Meta.parse(":($s)")=# - return symbol -end - -""" - _gen_local_init(device::NumaNode, symbol::Symbol, type::Type) - -Interface implementation, dispatched to from [`gen_local_init`](@ref). 
-""" -function _gen_local_init(::NumaNode, symbol::Symbol, type::Type) - #s = Symbol("data_$(symbol)") - quote_node = Expr(:local, symbol, :(::), Symbol(type)) - return quote_node -end From ee213a30b9721279338b865604d2d732d8b90d47 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Wed, 27 Nov 2024 14:20:16 +0100 Subject: [PATCH 09/12] Renaming of member variables --- src/estimator/global_metric.jl | 32 +++++++++++++------------- src/graph/interface.jl | 13 ++++++----- src/graph/mute.jl | 16 ++++++------- src/graph/print.jl | 10 ++++----- src/graph/properties.jl | 4 ++-- src/graph/type.jl | 12 +++++----- src/graph/validate.jl | 8 +++---- src/operation/apply.jl | 8 +++---- src/operation/clean.jl | 6 ++--- src/operation/find.jl | 40 ++++++++++++++++----------------- src/operation/get.jl | 8 +++---- src/operation/iterate.jl | 10 ++++----- src/operation/print.jl | 8 +++---- src/operation/utility.jl | 10 ++++----- src/optimization/random_walk.jl | 10 ++++----- src/optimization/reduce.jl | 4 ++-- src/optimization/split.jl | 4 ++-- src/properties/create.jl | 22 ++++++++++-------- src/properties/type.jl | 10 ++++----- src/properties/utility.jl | 30 ++++++++++++------------- src/utils.jl | 12 +++++----- test/strassen_test.jl | 4 ++-- 22 files changed, 143 insertions(+), 138 deletions(-) diff --git a/src/estimator/global_metric.jl b/src/estimator/global_metric.jl index b9e817f..ac81d2d 100644 --- a/src/estimator/global_metric.jl +++ b/src/estimator/global_metric.jl @@ -5,40 +5,40 @@ Representation of a [`DAG`](@ref)'s cost as estimated by the [`GlobalMetricEstim # Fields: `.data`: The total data transfer.\\ -`.computeEffort`: The total compute effort.\\ -`.computeIntensity`: The compute intensity, will always equal `.computeEffort / .data`. +`.compute_effort`: The total compute effort.\\ +`.compute_intensity`: The compute intensity, will always equal `.compute_effort / .data`. !!! 
note - Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs. + Note that the `compute_intensity` doesn't necessarily make sense in the context of only operation costs. It will still work as intended when adding/subtracting to/from a `graph_cost` estimate. """ const CDCost = NamedTuple{ - (:data, :computeEffort, :computeIntensity),Tuple{Float64,Float64,Float64} + (:data, :compute_effort, :compute_intensity),Tuple{Float64,Float64,Float64} } function Base.:+(cost1::CDCost, cost2::CDCost)::CDCost d = cost1.data + cost2.data - ce = computeEffort = cost1.computeEffort + cost2.computeEffort - return (data=d, computeEffort=ce, computeIntensity=ce / d)::CDCost + ce = compute_effort = cost1.compute_effort + cost2.compute_effort + return (data=d, compute_effort=ce, compute_intensity=ce / d)::CDCost end function Base.:-(cost1::CDCost, cost2::CDCost)::CDCost d = cost1.data - cost2.data - ce = computeEffort = cost1.computeEffort - cost2.computeEffort - return (data=d, computeEffort=ce, computeIntensity=ce / d)::CDCost + ce = compute_effort = cost1.compute_effort - cost2.compute_effort + return (data=d, compute_effort=ce, compute_intensity=ce / d)::CDCost end function Base.isless(cost1::CDCost, cost2::CDCost)::Bool - return cost1.data + cost1.computeEffort < cost2.data + cost2.computeEffort + return cost1.data + cost1.compute_effort < cost2.data + cost2.compute_effort end function Base.zero(type::Type{CDCost}) - return (data=0.0, computeEffort=0.0, computeIntensity=0.0)::CDCost + return (data=0.0, compute_effort=0.0, compute_intensity=0.0)::CDCost end function Base.typemax(type::Type{CDCost}) - return (data=Inf, computeEffort=Inf, computeIntensity=0.0)::CDCost + return (data=Inf, compute_effort=Inf, compute_intensity=0.0)::CDCost end """ @@ -56,8 +56,8 @@ function graph_cost(estimator::GlobalMetricEstimator, graph::DAG) properties = get_properties(graph) return ( data=properties.data, - computeEffort=properties.computeEffort, - 
computeIntensity=properties.computeIntensity, + compute_effort=properties.compute_effort, + compute_intensity=properties.compute_intensity, )::CDCost end @@ -67,8 +67,8 @@ function operation_effect( s = length(operation.input) - 1 return ( data=s * -data(task(operation.input[1])), - computeEffort=s * -compute_effort(task(operation.input[1])), - computeIntensity=typeof(operation.input) <: DataTaskNode ? 0.0 : Inf, + compute_effort=s * -compute_effort(task(operation.input[1])), + compute_intensity=typeof(operation.input) <: DataTaskNode ? 0.0 : Inf, )::CDCost end @@ -78,7 +78,7 @@ function operation_effect( s::Float64 = length(parents(operation.input)) - 1 d::Float64 = s * data(task(operation.input)) ce::Float64 = s * compute_effort(task(operation.input)) - return (data=d, computeEffort=ce, computeIntensity=ce / d)::CDCost + return (data=d, compute_effort=ce, compute_intensity=ce / d)::CDCost end function String(::GlobalMetricEstimator) diff --git a/src/graph/interface.jl b/src/graph/interface.jl index 0fa74cc..97d092d 100644 --- a/src/graph/interface.jl +++ b/src/graph/interface.jl @@ -7,7 +7,7 @@ See also: [`DAG`](@ref), [`pop_operation!`](@ref) """ function push_operation!(graph::DAG, operation::Operation) # 1.: Add the operation to the DAG - push!(graph.operationsToApply, operation) + push!(graph.operations_to_apply, operation) return nothing end @@ -21,10 +21,10 @@ See also: [`DAG`](@ref), [`push_operation!`](@ref) """ function pop_operation!(graph::DAG) # 1.: Remove the operation from the appliedChain of the DAG - if !isempty(graph.operationsToApply) - pop!(graph.operationsToApply) - elseif !isempty(graph.appliedOperations) - appliedOp = pop!(graph.appliedOperations) + if !isempty(graph.operations_to_apply) + pop!(graph.operations_to_apply) + elseif !isempty(graph.applied_operations) + appliedOp = pop!(graph.applied_operations) revert_operation!(graph, appliedOp) else error("No more operations to pop!") @@ -38,7 +38,8 @@ end Return `true` if 
[`pop_operation!`](@ref) is possible, `false` otherwise. """ -can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations) +can_pop(graph::DAG) = + !isempty(graph.operations_to_apply) || !isempty(graph.applied_operations) """ reset_graph!(graph::DAG) diff --git a/src/graph/mute.jl b/src/graph/mute.jl index f927967..4dbe93d 100644 --- a/src/graph/mute.jl +++ b/src/graph/mute.jl @@ -56,7 +56,7 @@ function _insert_node!(graph::DAG, node::Node; track=true, invalidate_cache=true if (!invalidate_cache) return node end - push!(graph.dirtyNodes, node) + push!(graph.dirty_nodes, node) return node end @@ -110,8 +110,8 @@ function _insert_edge!( invalidate_operation_caches!(graph, node1) invalidate_operation_caches!(graph, node2) - push!(graph.dirtyNodes, node1) - push!(graph.dirtyNodes, node2) + push!(graph.dirty_nodes, node1) + push!(graph.dirty_nodes, node2) return nothing end @@ -145,7 +145,7 @@ function _remove_node!(graph::DAG, node::Node; track=true, invalidate_cache=true end invalidate_operation_caches!(graph, node) - delete!(graph.dirtyNodes, node) + delete!(graph.dirty_nodes, node) return nothing end @@ -207,10 +207,10 @@ function _remove_edge!( invalidate_operation_caches!(graph, node1) invalidate_operation_caches!(graph, node2) if (node1 in graph) - push!(graph.dirtyNodes, node1) + push!(graph.dirty_nodes, node1) end if (node2 in graph) - push!(graph.dirtyNodes, node2) + push!(graph.dirty_nodes, node2) end return removed_node_index @@ -235,7 +235,7 @@ Invalidate the operation caches for a given [`NodeReduction`](@ref). This deletes the operation from the graph's possible operations and from the involved nodes' own operation caches. """ function invalidate_caches!(graph::DAG, operation::NodeReduction) - delete!(graph.possibleOperations, operation) + delete!(graph.possible_operations, operation) for node in operation.input node.nodeReduction = missing @@ -252,7 +252,7 @@ Invalidate the operation caches for a given [`NodeSplit`](@ref). 
This deletes the operation from the graph's possible operations and from the involved nodes' own operation caches. """ function invalidate_caches!(graph::DAG, operation::NodeSplit) - delete!(graph.possibleOperations, operation) + delete!(graph.possible_operations, operation) # delete the operation from all caches of nodes involved in the operation # for node split there is only one node diff --git a/src/graph/print.jl b/src/graph/print.jl index 118679c..d1fc5af 100644 --- a/src/graph/print.jl +++ b/src/graph/print.jl @@ -28,14 +28,14 @@ function Base.show(io::IO, graph::DAG) print(io, " Nodes: ") nodeDict = Dict{Type,Int64}() - noEdges = 0 + number_of_edges = 0 for node in graph.nodes if haskey(nodeDict, typeof(task(node))) nodeDict[typeof(task(node))] = nodeDict[typeof(task(node))] + 1 else nodeDict[typeof(task(node))] = 1 end - noEdges += length(parents(node)) + number_of_edges += length(parents(node)) end if length(graph.nodes) <= 20 @@ -58,9 +58,9 @@ function Base.show(io::IO, graph::DAG) end end println(io) - println(io, " Edges: ", noEdges) + println(io, " Edges: ", number_of_edges) properties = get_properties(graph) - println(io, " Total Compute Effort: ", properties.computeEffort) + println(io, " Total Compute Effort: ", properties.compute_effort) println(io, " Total Data Transfer: ", properties.data) - return println(io, " Total Compute Intensity: ", properties.computeIntensity) + return println(io, " Total Compute Intensity: ", properties.compute_intensity) end diff --git a/src/graph/properties.jl b/src/graph/properties.jl index 9515820..0f52b12 100644 --- a/src/graph/properties.jl +++ b/src/graph/properties.jl @@ -8,7 +8,7 @@ function get_properties(graph::DAG) apply_all!(graph) # TODO: tests stop working without the if condition, which means there is probably a bug in the lazy evaluation and in the tests - if (graph.properties.computeEffort <= 0.0) + if (graph.properties.compute_effort <= 0.0) graph.properties = GraphProperties(graph) end @@ -51,5 +51,5 
@@ end Return the number of operations applied to the graph. """ function operation_stack_length(graph::DAG) - return length(graph.appliedOperations) + length(graph.operationsToApply) + return length(graph.applied_operations) + length(graph.operations_to_apply) end diff --git a/src/graph/type.jl b/src/graph/type.jl index b5ceeb4..a5f8715 100644 --- a/src/graph/type.jl +++ b/src/graph/type.jl @@ -7,8 +7,8 @@ A struct storing all possible operations on a [`DAG`](@ref). To get the [`PossibleOperations`](@ref) on a [`DAG`](@ref), use [`get_operations`](@ref). """ mutable struct PossibleOperations - nodeReductions::Set{NodeReduction} - nodeSplits::Set{NodeSplit} + node_reductions::Set{NodeReduction} + node_splits::Set{NodeSplit} end """ @@ -24,16 +24,16 @@ mutable struct DAG nodes::Set{Union{DataTaskNode,ComputeTaskNode}} # The operations currently applied to the set of nodes - appliedOperations::Stack{AppliedOperation} + applied_operations::Stack{AppliedOperation} # The operations not currently applied but part of the current state of the DAG - operationsToApply::Deque{Operation} + operations_to_apply::Deque{Operation} # The possible operations at the current state of the DAG - possibleOperations::PossibleOperations + possible_operations::PossibleOperations # The set of nodes whose possible operations need to be reevaluated - dirtyNodes::Set{Union{DataTaskNode,ComputeTaskNode}} + dirty_nodes::Set{Union{DataTaskNode,ComputeTaskNode}} # "snapshot" system: keep track of added/removed nodes/edges since last snapshot # these are muted in insert_node! etc. 
diff --git a/src/graph/validate.jl b/src/graph/validate.jl index ad95d19..9caef4b 100644 --- a/src/graph/validate.jl +++ b/src/graph/validate.jl @@ -30,18 +30,18 @@ function is_valid(graph::DAG) @assert is_valid(graph, node) end - for op in graph.operationsToApply + for op in graph.operations_to_apply @assert is_valid(graph, op) end - for nr in graph.possibleOperations.nodeReductions + for nr in graph.possible_operations.node_reductions @assert is_valid(graph, nr) end - for ns in graph.possibleOperations.nodeSplits + for ns in graph.possible_operations.node_splits @assert is_valid(graph, ns) end - for node in graph.dirtyNodes + for node in graph.dirty_nodes @assert node in graph "Dirty Node is not part of the graph!" @assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!" @assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!" diff --git a/src/operation/apply.jl b/src/operation/apply.jl index 0de5eea..99dce26 100644 --- a/src/operation/apply.jl +++ b/src/operation/apply.jl @@ -4,15 +4,15 @@ Apply all unapplied operations in the DAG. Is automatically called in all functions that require the latest state of the [`DAG`](@ref). 
""" function apply_all!(graph::DAG) - while !isempty(graph.operationsToApply) + while !isempty(graph.operations_to_apply) # get next operation to apply from front of the deque - op = popfirst!(graph.operationsToApply) + op = popfirst!(graph.operations_to_apply) # apply it appliedOp = apply_operation!(graph, op) - # push to the end of the appliedOperations deque - push!(graph.appliedOperations, appliedOp) + # push to the end of the applied_operations deque + push!(graph.applied_operations, appliedOp) end return nothing end diff --git a/src/operation/clean.jl b/src/operation/clean.jl index eec9d3f..9e0b997 100644 --- a/src/operation/clean.jl +++ b/src/operation/clean.jl @@ -30,11 +30,11 @@ function find_reductions!(graph::DAG, node::Node) if reductionVector !== nothing nr = NodeReduction(reductionVector) - push!(graph.possibleOperations.nodeReductions, nr) + push!(graph.possible_operations.node_reductions, nr) for node in reductionVector if !ismissing(node.nodeReduction) # it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now - # this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations + # this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possible_operations invalidate_caches!(graph, node.nodeReduction) end node.nodeReduction = nr @@ -56,7 +56,7 @@ function find_splits!(graph::DAG, node::Node) if (can_split(node)) ns = NodeSplit(node) - push!(graph.possibleOperations.nodeSplits, ns) + push!(graph.possible_operations.node_splits, ns) node.nodeSplit = ns end diff --git a/src/operation/find.jl b/src/operation/find.jl index 6179cec..20e9a2c 100644 --- a/src/operation/find.jl +++ b/src/operation/find.jl @@ -25,25 +25,25 @@ function insert_operation!(ns::NodeSplit) end """ - nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}) + 
nr_insertion!(operations::PossibleOperations, node_reductions::Vector{Vector{NodeReduction}}) Insert the node reductions into the graph and the nodes' caches. Employs multithreading for speedup. """ function nr_insertion!( - operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}} + operations::PossibleOperations, node_reductions::Vector{Vector{NodeReduction}} ) total_len = 0 - for vec in nodeReductions + for vec in node_reductions total_len += length(vec) end - sizehint!(operations.nodeReductions, total_len) + sizehint!(operations.node_reductions, total_len) - t = @task for vec in nodeReductions - union!(operations.nodeReductions, Set(vec)) + t = @task for vec in node_reductions + union!(operations.node_reductions, Set(vec)) end schedule(t) - @threads for vec in nodeReductions + @threads for vec in node_reductions for op in vec insert_operation!(op) end @@ -55,25 +55,25 @@ function nr_insertion!( end """ - ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplits}}) + ns_insertion!(operations::PossibleOperations, node_splits::Vector{Vector{NodeSplits}}) Insert the node splits into the graph and the nodes' caches. Employs multithreading for speedup. 
""" function ns_insertion!( - operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}} + operations::PossibleOperations, node_splits::Vector{Vector{NodeSplit}} ) total_len = 0 - for vec in nodeSplits + for vec in node_splits total_len += length(vec) end - sizehint!(operations.nodeSplits, total_len) + sizehint!(operations.node_splits, total_len) - t = @task for vec in nodeSplits - union!(operations.nodeSplits, Set(vec)) + t = @task for vec in node_splits + union!(operations.node_splits, Set(vec)) end schedule(t) - @threads for vec in nodeSplits + @threads for vec in node_splits for op in vec insert_operation!(op) end @@ -124,11 +124,11 @@ function generate_operations(graph::DAG) insert!(trie, candidate) end - nodeReductions = collect(trie) + node_reductions = collect(trie) - for nrVec in nodeReductions + for nrVec in node_reductions # parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is uniquely identifiable by its first element - # this prevents duplicate nodeReductions being generated + # this prevents duplicate node_reductions being generated lock(checkedNodesLock) if (nrVec[1] in checkedNodes) unlock(checkedNodesLock) @@ -144,7 +144,7 @@ function generate_operations(graph::DAG) # launch thread for node reduction insertion # remove duplicates - nr_task = @spawn nr_insertion!(graph.possibleOperations, generatedReductions) + nr_task = @spawn nr_insertion!(graph.possible_operations, generatedReductions) # find possible node splits @threads for node in nodeArray @@ -154,9 +154,9 @@ function generate_operations(graph::DAG) end # launch thread for node split insertion - ns_task = @spawn ns_insertion!(graph.possibleOperations, generatedSplits) + ns_task = @spawn ns_insertion!(graph.possible_operations, generatedSplits) - empty!(graph.dirtyNodes) + empty!(graph.dirty_nodes) wait(nr_task) wait(ns_task) diff --git a/src/operation/get.jl b/src/operation/get.jl index 3294459..ff13398 100644 --- a/src/operation/get.jl 
+++ b/src/operation/get.jl @@ -10,12 +10,12 @@ Return the [`PossibleOperations`](@ref) of the graph at the current state. function get_operations(graph::DAG) apply_all!(graph) - if isempty(graph.possibleOperations) + if isempty(graph.possible_operations) generate_operations(graph) end - clean_node!.(Ref(graph), graph.dirtyNodes) - empty!(graph.dirtyNodes) + clean_node!.(Ref(graph), graph.dirty_nodes) + empty!(graph.dirty_nodes) - return graph.possibleOperations + return graph.possible_operations end diff --git a/src/operation/iterate.jl b/src/operation/iterate.jl index 1a9ece4..4eb6727 100644 --- a/src/operation/iterate.jl +++ b/src/operation/iterate.jl @@ -5,10 +5,10 @@ _POIteratorStateType = NamedTuple{ } @inline function Base.iterate( - possibleOperations::PossibleOperations + possible_operations::PossibleOperations )::Union{Nothing,_POIteratorStateType} for fieldname in _POSSIBLE_OPERATIONS_FIELDS - iterator = iterate(getfield(possibleOperations, fieldname)) + iterator = iterate(getfield(possible_operations, fieldname)) if (!isnothing(iterator)) return (result=iterator[1], state=(fieldname, iterator[2])) end @@ -18,10 +18,10 @@ _POIteratorStateType = NamedTuple{ end @inline function Base.iterate( - possibleOperations::PossibleOperations, state + possible_operations::PossibleOperations, state )::Union{Nothing,_POIteratorStateType} newStateSym = state[1] - newStateIt = iterate(getfield(possibleOperations, newStateSym), state[2]) + newStateIt = iterate(getfield(possible_operations, newStateSym), state[2]) if !isnothing(newStateIt) return (result=newStateIt[1], state=(newStateSym, newStateIt[2])) end @@ -31,7 +31,7 @@ end while index <= length(_POSSIBLE_OPERATIONS_FIELDS) newStateSym = _POSSIBLE_OPERATIONS_FIELDS[index] - newStateIt = iterate(getfield(possibleOperations, newStateSym)) + newStateIt = iterate(getfield(possible_operations, newStateSym)) if !isnothing(newStateIt) return (result=newStateIt[1], state=(newStateSym, newStateIt[2])) end diff --git 
a/src/operation/print.jl b/src/operation/print.jl index 274531a..400e32d 100644 --- a/src/operation/print.jl +++ b/src/operation/print.jl @@ -4,14 +4,14 @@ Print a string representation of the set of possible operations to io. """ function Base.show(io::IO, ops::PossibleOperations) - print(io, length(ops.nodeReductions)) + print(io, length(ops.node_reductions)) println(io, " Node Reductions: ") - for nr in ops.nodeReductions + for nr in ops.node_reductions println(io, " - ", nr) end - print(io, length(ops.nodeSplits)) + print(io, length(ops.node_splits)) println(io, " Node Splits: ") - for ns in ops.nodeSplits + for ns in ops.node_splits println(io, " - ", ns) end end diff --git a/src/operation/utility.jl b/src/operation/utility.jl index 0424798..263c926 100644 --- a/src/operation/utility.jl +++ b/src/operation/utility.jl @@ -4,7 +4,7 @@ Return whether `operations` is empty, i.e. all of its fields are empty. """ function Base.isempty(operations::PossibleOperations) - return isempty(operations.nodeReductions) && isempty(operations.nodeSplits) + return isempty(operations.node_reductions) && isempty(operations.node_splits) end """ @@ -14,8 +14,8 @@ Return a named tuple with the number of each of the operation types as a named t """ function Base.length(operations::PossibleOperations) return ( - nodeReductions=length(operations.nodeReductions), - nodeSplits=length(operations.nodeSplits), + node_reductions=length(operations.node_reductions), + node_splits=length(operations.node_splits), ) end @@ -25,7 +25,7 @@ end Delete the given node reduction from the possible operations. """ function Base.delete!(operations::PossibleOperations, op::NodeReduction) - delete!(operations.nodeReductions, op) + delete!(operations.node_reductions, op) return operations end @@ -35,7 +35,7 @@ end Delete the given node split from the possible operations. 
""" function Base.delete!(operations::PossibleOperations, op::NodeSplit) - delete!(operations.nodeSplits, op) + delete!(operations.node_splits, op) return operations end diff --git a/src/optimization/random_walk.jl b/src/optimization/random_walk.jl index 62d02c2..ab381a7 100644 --- a/src/optimization/random_walk.jl +++ b/src/optimization/random_walk.jl @@ -15,7 +15,7 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG) operations = get_operations(graph) if sum(length(operations)) == 0 && - length(graph.appliedOperations) + length(graph.operationsToApply) == 0 + length(graph.applied_operations) + length(graph.operations_to_apply) == 0 # in case there are zero operations possible at all on the graph return false end @@ -29,11 +29,11 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG) # choose one of split/reduce option = rand(r, 1:2) - if option == 1 && !isempty(operations.nodeReductions) - push_operation!(graph, rand(r, collect(operations.nodeReductions))) + if option == 1 && !isempty(operations.node_reductions) + push_operation!(graph, rand(r, collect(operations.node_reductions))) return true - elseif option == 2 && !isempty(operations.nodeSplits) - push_operation!(graph, rand(r, collect(operations.nodeSplits))) + elseif option == 2 && !isempty(operations.node_splits) + push_operation!(graph, rand(r, collect(operations.node_splits))) return true end else diff --git a/src/optimization/reduce.jl b/src/optimization/reduce.jl index edd49d2..36bdb0e 100644 --- a/src/optimization/reduce.jl +++ b/src/optimization/reduce.jl @@ -14,14 +14,14 @@ function optimize_step!(optimizer::ReductionOptimizer, graph::DAG) return false end - push_operation!(graph, first(operations.nodeReductions)) + push_operation!(graph, first(operations.node_reductions)) return true end function fixpoint_reached(optimizer::ReductionOptimizer, graph::DAG) operations = get_operations(graph) - return isempty(operations.nodeReductions) + return 
isempty(operations.node_reductions) end function optimize_to_fixpoint!(optimizer::ReductionOptimizer, graph::DAG) diff --git a/src/optimization/split.jl b/src/optimization/split.jl index 0d38dae..b610494 100644 --- a/src/optimization/split.jl +++ b/src/optimization/split.jl @@ -14,14 +14,14 @@ function optimize_step!(optimizer::SplitOptimizer, graph::DAG) return false end - push_operation!(graph, first(operations.nodeSplits)) + push_operation!(graph, first(operations.node_splits)) return true end function fixpoint_reached(optimizer::SplitOptimizer, graph::DAG) operations = get_operations(graph) - return isempty(operations.nodeSplits) + return isempty(operations.node_splits) end function optimize_to_fixpoint!(optimizer::SplitOptimizer, graph::DAG) diff --git a/src/properties/create.jl b/src/properties/create.jl index 72f7b9a..5d0de61 100644 --- a/src/properties/create.jl +++ b/src/properties/create.jl @@ -5,7 +5,11 @@ Create an empty [`GraphProperties`](@ref) object. """ function GraphProperties() return ( - data=0.0, computeEffort=0.0, computeIntensity=0.0, noNodes=0, noEdges=0 + data=0.0, + compute_effort=0.0, + compute_intensity=0.0, + number_of_nodes=0, + number_of_edges=0, )::GraphProperties end @@ -41,10 +45,10 @@ function GraphProperties(graph::DAG) return ( data=d, - computeEffort=ce, - computeIntensity=(d == 0) ? 0.0 : ce / d, - noNodes=length(graph.nodes), - noEdges=ed, + compute_effort=ce, + compute_intensity=(d == 0) ? 0.0 : ce / d, + number_of_nodes=length(graph.nodes), + number_of_edges=ed, )::GraphProperties end @@ -66,9 +70,9 @@ function GraphProperties(diff::Diff) return ( data=d, - computeEffort=ce, - computeIntensity=(d == 0) ? 0.0 : ce / d, - noNodes=length(diff.addedNodes) - length(diff.removedNodes), - noEdges=length(diff.addedEdges) - length(diff.removedEdges), + compute_effort=ce, + compute_intensity=(d == 0) ? 
0.0 : ce / d, + number_of_nodes=length(diff.addedNodes) - length(diff.removedNodes), + number_of_edges=length(diff.addedEdges) - length(diff.removedEdges), )::GraphProperties end diff --git a/src/properties/type.jl b/src/properties/type.jl index 9fea9c6..7f7e001 100644 --- a/src/properties/type.jl +++ b/src/properties/type.jl @@ -5,12 +5,12 @@ Representation of a [`DAG`](@ref)'s properties. # Fields: `.data`: The total data transfer.\\ -`.computeEffort`: The total compute effort.\\ -`.computeIntensity`: The compute intensity, will always equal `.computeEffort / .data`.\\ -`.noNodes`: Number of [`Node`](@ref)s.\\ -`.noEdges`: Number of [`Edge`](@ref)s. +`.compute_effort`: The total compute effort.\\ +`.compute_intensity`: The compute intensity, will always equal `.compute_effort / .data`.\\ +`.number_of_nodes`: Number of [`Node`](@ref)s.\\ +`.number_of_edges`: Number of [`Edge`](@ref)s. """ const GraphProperties = NamedTuple{ - (:data, :computeEffort, :computeIntensity, :noNodes, :noEdges), + (:data, :compute_effort, :compute_intensity, :number_of_nodes, :number_of_edges), Tuple{Float64,Float64,Float64,Int,Int}, } diff --git a/src/properties/utility.jl b/src/properties/utility.jl index 1760d2a..ceddc76 100644 --- a/src/properties/utility.jl +++ b/src/properties/utility.jl @@ -7,14 +7,14 @@ Also take care to keep consistent compute intensity. 
function Base.:-(prop1::GraphProperties, prop2::GraphProperties) return ( data=prop1.data - prop2.data, - computeEffort=prop1.computeEffort - prop2.computeEffort, - computeIntensity=if (prop1.data - prop2.data == 0) + compute_effort=prop1.compute_effort - prop2.compute_effort, + compute_intensity=if (prop1.data - prop2.data == 0) 0.0 else - (prop1.computeEffort - prop2.computeEffort) / (prop1.data - prop2.data) + (prop1.compute_effort - prop2.compute_effort) / (prop1.data - prop2.data) end, - noNodes=prop1.noNodes - prop2.noNodes, - noEdges=prop1.noEdges - prop2.noEdges, + number_of_nodes=prop1.number_of_nodes - prop2.number_of_nodes, + number_of_edges=prop1.number_of_edges - prop2.number_of_edges, )::GraphProperties end @@ -27,28 +27,28 @@ Also take care to keep consistent compute intensity. function Base.:+(prop1::GraphProperties, prop2::GraphProperties) return ( data=prop1.data + prop2.data, - computeEffort=prop1.computeEffort + prop2.computeEffort, - computeIntensity=if (prop1.data + prop2.data == 0) + compute_effort=prop1.compute_effort + prop2.compute_effort, + compute_intensity=if (prop1.data + prop2.data == 0) 0.0 else - (prop1.computeEffort + prop2.computeEffort) / (prop1.data + prop2.data) + (prop1.compute_effort + prop2.compute_effort) / (prop1.data + prop2.data) end, - noNodes=prop1.noNodes + prop2.noNodes, - noEdges=prop1.noEdges + prop2.noEdges, + number_of_nodes=prop1.number_of_nodes + prop2.number_of_nodes, + number_of_edges=prop1.number_of_edges + prop2.number_of_edges, )::GraphProperties end """ -(prop::GraphProperties) -Unary negation of the graph properties. `.computeIntensity` will not be negated because `.data` and `.computeEffort` both are. +Unary negation of the graph properties. `.compute_intensity` will not be negated because `.data` and `.compute_effort` both are. """ function Base.:-(prop::GraphProperties) return ( data=-prop.data, - computeEffort=-prop.computeEffort, - computeIntensity=prop.computeIntensity, # no negation here! 
- noNodes=-prop.noNodes, - noEdges=-prop.noEdges, + compute_effort=-prop.compute_effort, + compute_intensity=prop.compute_intensity, # no negation here! + number_of_nodes=-prop.number_of_nodes, + number_of_edges=-prop.number_of_edges, )::GraphProperties end diff --git a/src/utils.jl b/src/utils.jl index d320ea5..b8593e0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -72,18 +72,18 @@ function mem(graph::DAG) size += mem(n) end - size += sizeof(graph.appliedOperations) - size += sizeof(graph.operationsToApply) + size += sizeof(graph.applied_operations) + size += sizeof(graph.operations_to_apply) - size += sizeof(graph.possibleOperations) - for op in graph.possibleOperations.nodeReductions + size += sizeof(graph.possible_operations) + for op in graph.possible_operations.node_reductions size += mem(op) end - for op in graph.possibleOperations.nodeSplits + for op in graph.possible_operations.node_splits size += mem(op) end - size += Base.summarysize(graph.dirtyNodes; exclude=Union{Node}) + size += Base.summarysize(graph.dirty_nodes; exclude=Union{Node}) return size += sizeof(diff) end diff --git a/test/strassen_test.jl b/test/strassen_test.jl index e67491e..61c927d 100644 --- a/test/strassen_test.jl +++ b/test/strassen_test.jl @@ -38,8 +38,8 @@ EDGE_NUMBERS = (3, 96, 747, 5304) #, 37203 @test get_exit_node(g) isa DataTaskNode props = get_properties(g) - @test NODE_NUM_EXPECTED == props.noNodes - @test EDGE_NUM_EXPECTED == props.noEdges + @test NODE_NUM_EXPECTED == props.number_of_nodes + @test EDGE_NUM_EXPECTED == props.number_of_edges end f = get_compute_function(g, mm, cpu_st(), @__MODULE__) From e6c8fc00179f65eaf182d33b233faa8f358bab48 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Wed, 27 Nov 2024 16:49:58 +0100 Subject: [PATCH 10/12] renaming/docs/removing unused interfaces fix type stability in closures --- src/ComputableDAGs.jl | 6 +- src/code_gen/tape_machine.jl | 132 +++++++++++++++++++---------------- src/code_gen/type.jl | 72 +++++++++++++++++-- 
src/code_gen/utils.jl | 23 +++--- src/models/interface.jl | 40 +++-------- src/properties/type.jl | 12 ++-- src/scheduler/interface.jl | 8 --- src/scheduler/type.jl | 35 +--------- src/task/compute.jl | 59 ++++++++++++---- 9 files changed, 217 insertions(+), 170 deletions(-) diff --git a/src/ComputableDAGs.jl b/src/ComputableDAGs.jl index 63bc257..c0b8b7c 100644 --- a/src/ComputableDAGs.jl +++ b/src/ComputableDAGs.jl @@ -35,7 +35,6 @@ export get_operations # code generation related export get_compute_function -export gen_tape # estimator export cost_type, graph_cost, operation_effect @@ -48,8 +47,7 @@ export optimize_step!, optimize! export fixpoint_reached, optimize_to_fixpoint! # models -export AbstractModel, AbstractProblemInstance -export problem_instance, input_type, graph, input_expr +export graph, input_type, input_expr # machine info export Machine @@ -67,6 +65,7 @@ include("properties/type.jl") include("operation/type.jl") include("graph/type.jl") include("scheduler/type.jl") +include("code_gen/type.jl") include("trie.jl") include("utils.jl") @@ -125,7 +124,6 @@ include("devices/ext.jl") include("scheduler/interface.jl") include("scheduler/greedy.jl") -include("code_gen/type.jl") include("code_gen/utils.jl") include("code_gen/tape_machine.jl") include("code_gen/function.jl") diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index dc7c795..767dc9b 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -1,4 +1,4 @@ -function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function} +function expr_from_fc(fc::FunctionCall{VAL_T,<:Function}) where {VAL_T} if length(fc) == 1 func_call = Expr(:call, fc.func, fc.value_arguments[1]..., fc.arguments[1]...) 
else @@ -10,14 +10,35 @@ function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function} end function expr_from_fc(fc::FunctionCall{VAL_T,Expr}) where {VAL_T} - @assert length(fc) == 1 && isempty(fc.arguments[1]) && isempty(fc.value_arguments[1]) "function call assigning an expression has an unallowed combination of arguments, which is not allowed\n$fc" - return Expr(:(=), gen_access_expr(fc), fc.func) + @assert length(fc) == 1 && isempty(fc.value_arguments[1]) "function call assigning an expression cannot be vectorized and cannot contain value arguments\n$fc" + + fc_expr_in_let = Expr( + :let, + Expr(:block, fc.return_symbols[1]...), + fc.func, # anonymous function code block + ) + + func_call = Expr( + :call, # call + Expr( + :->, # anonymous function + Expr( + :tuple, # anonymous function arguments + fc.arguments[1]..., + ), + fc_expr_in_let, + ), + fc.arguments[1]..., # runtime arguments passed to the anonymous function + ) + + access_expr = gen_access_expr(fc) + return Expr(:(=), access_expr, func_call) end """ gen_input_assignment_code( input_symbols::Dict{String, Vector{Symbol}}, - instance::AbstractProblemInstance, + instance::Any, machine::Machine, input_type::Type, context_module::Module @@ -26,40 +47,22 @@ end Return a `Vector{Expr}` doing the input assignments from the given `problem_input` onto the `input_symbols`. 
""" function gen_input_assignment_code( - input_symbols::Dict{String,Vector{Symbol}}, - instance, - machine::Machine, - input_type::Type, - context_module::Module, + input_symbols::Dict{String,Vector{Symbol}}, instance, machine::Machine ) assign_inputs = Vector{FunctionCall}() for (name, symbols) in input_symbols for symbol in symbols device = entry_device(machine) - f_id = Symbol(to_var_name(UUIDs.uuid1(rng[threadid()]))) - - fc_setup = FunctionCall( - Expr(:->, :x, input_expr(instance, name, :x)), + fc = FunctionCall( + input_expr(instance, name, :input), (), - Symbol[], - Symbol[f_id], + Symbol[:input], + Symbol[symbol], Type[Nothing], device, ) - fc = FunctionCall( - _call, (), Symbol[f_id, :input], Symbol[symbol], Type[Nothing], device - ) - - ret_expr = Expr( - :call, Base.return_types, fc_setup.func, Expr(:tuple, input_type) - ) - ret_type = context_module.eval(ret_expr) - @assert length(ret_type) == 1 - fc.return_types = [ret_type[1]] - - push!(assign_inputs, fc_setup) push!(assign_inputs, fc) end end @@ -73,13 +76,15 @@ end Generate the function body from the given [`Tape`](@ref). ## Keyword Arguments -`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely. +`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers + in the function body, preventing some optimizations by the compiler and therefore greatly reducing + compile time. A value of 1 or less will disable the use of closures entirely. 
""" function gen_function_body(tape::Tape, context_module::Module; closures_size::Int) # only need to annotate types later when using closures - types = infer_types!(tape) + types = infer_types!(tape, context_module) - if closures_size >= 1 + if closures_size > 1 s = log(closures_size, length(tape.schedule)) closures_depth = ceil(Int, s) # tend towards more levels/smaller closures closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth)) @@ -113,15 +118,14 @@ function _gen_function_body( closured_fc_vec = FunctionCall[] for i in length(fc_vec):(-closures_size):1 e = i - b = max(i - closures_size, 1) + b = max(i - closures_size + 1, 1) code_block = fc_vec[b:e] - pushfirst!( - closured_fc_vec, - _closure_fc( - code_block, type_dict, machine, undefined_argument_symbols, context_module - ), + closure_fc = _closure_fc( + code_block, type_dict, machine, undefined_argument_symbols, context_module ) + + pushfirst!(closured_fc_vec, closure_fc) end return _gen_function_body( @@ -130,9 +134,16 @@ function _gen_function_body( end """ - _closure_fc() + _closure_fc( + code_block::AbstractVector{FunctionCall}, + types::Dict{Symbol,Type}, + machine::Machine, + undefined_argument_symbols::Set{Symbol}, + context_module::Module, + ) -From the given function calls, make and return a new function call representing all of them together. +From the given function calls, make and return 2 function calls representing all of them together. 2 function calls are necessary, one for setting up the anonymous +function and the second for calling it. The undefined_argument_symbols is the set of all Symbols that need to be returned if available inside the code_block. They get updated inside this function. """ function _closure_fc( @@ -168,29 +179,28 @@ function _closure_fc( arg_symbols_t = [arg_symbols_set...] ret_symbols_t = [ret_symbols_set...] 
- closure = context_module.eval( - Expr( # create the closure: () -> code block; return (locals) - :->, - Expr(:tuple, arg_symbols_t...), # closure arguments - Expr( # actual function body of the closure - :block, - expr_from_fc.(code_block)..., - Expr( - :return, # have to make sure to not return a tuple of length 1 - if length(ret_symbols_t) == 1 - ret_symbols_t[1] - else - Expr(:tuple, ret_symbols_t...) - end, - ), - ), + ret_types = (getindex.(Ref(types), ret_symbols_t)) + + fc_expr = Expr( # actual function body of the closure + :block, + expr_from_fc.(code_block)..., + Expr( + :return, # have to make sure to not return a tuple of length 1 + if length(ret_symbols_t) == 1 + ret_symbols_t[1] + else + Expr(:tuple, ret_symbols_t...) + end, ), ) - ret_types = (getindex.(Ref(types), ret_symbols_t)) - fc = FunctionCall( - closure, (), arg_symbols_t, ret_symbols_t, ret_types, entry_device(machine) + fc_expr, + (), + Symbol[arg_symbols_t...], + ret_symbols_t, + ret_types, + entry_device(machine), ) setdiff!(undefined_argument_symbols, ret_symbols_set) @@ -201,7 +211,7 @@ end """ gen_tape( graph::DAG, - instance::AbstractProblemInstance, + instance::Any, machine::Machine, context_module::Module, scheduler::AbstractScheduler = GreedyScheduler() @@ -233,10 +243,8 @@ function gen_tape( # get outSymbol outSym = Symbol(to_var_name(get_exit_node(graph).id)) - INPUT_T = input_type(instance) - assign_inputs = gen_input_assignment_code( - input_syms, instance, machine, INPUT_T, context_module - ) + assign_inputs = gen_input_assignment_code(input_syms, instance, machine) + INPUT_T = input_type(instance) return Tape{INPUT_T}(assign_inputs, function_body, outSym, instance, machine) end diff --git a/src/code_gen/type.jl b/src/code_gen/type.jl index 4e2f004..45e653a 100644 --- a/src/code_gen/type.jl +++ b/src/code_gen/type.jl @@ -1,12 +1,76 @@ +""" + FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}} + +Representation of a function call. 
Contains the function to call (or an expression of a value to assign), +value arguments of type `VAL_T`, argument symbols, the return symbol(s) and type(s) and the device to execute on. + +To support vectorization, i.e., calling the same function on multiple inputs (SIMD), the value arguments, arguments, +and return symbols are each vectors of the actual inputs. In the non-vectorized case, these `Vector`s simply always +have length 1. For this common case, a special constructor exists which automatically wraps each of these arguments +in a `Vector`. + +## Type Arguments +- `VAL_T<:Tuple`: The type of the tuple of value arguments that are passed to the function when it's called. +- `FUNC_T<:Union{Function, Expr}`: The type of the function. `Function` is the default, but in some cases, an `Expr` + of a value can be necessary to assign to the return symbol. In this case, no value arguments are allowed. + +## Fields +- `func::FUNC_T`: The function to be called, or an expression containing a value to assign to the return_symbol. +- `value_arguments::Vector{VAL_T}`: The value arguments for the function call. These are passed *first* to the + function, in the order given here. The `Vector` contains the tuple of value arguments for each vectorization + member. +- `arguments::Vector{Vector{Symbol}}`: The first vector represents the vectorization, the second layer represents the + symbols that will be passed as arguments to the function call. +- `return_symbols::Vector{Vector{Symbol}}`: As with the arguments, the first vector level represents the vectorization, + the second represents the symbols that the results of the function call are assigned to. For most function calls, + there is only one return symbol. When using closures when generating a function body for a [`Tape`](@ref), the + option to have multiple return symbols is necessary. +- `return_types::Vector{<:Type}`: The return types of the function call with the arguments provided. 
This field only contains + one level of Vector, because it is required that a `FunctionCall` is type stable, and therefore, the types of the + return symbols have to be equal for all members of a vectorization. The return type is initially set to `Nothing` + and later inferred and assigned by [`infer_types!`](@ref). +- `device::AbstractDevice`: The device that this function call is scheduled on. +""" +mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}} + func::FUNC_T + value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments + arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call + return_symbols::Vector{Vector{Symbol}} # the return symbols + return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability + device::AbstractDevice +end +function FunctionCall( + func::Union{Function,Expr}, + value_arguments::VAL_T, + arguments::Vector{Symbol}, + return_symbol::Vector{Symbol}, + return_types::Vector{<:Type}, + device::AbstractDevice, +) where {VAL_T<:Tuple} + # convenience constructor for function calls that do not use vectorization, which is most of the use cases + @assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types '$(length(return_types))' does not match the number of return symbols '$(length(return_symbol))'" + @assert func isa Function || length(value_arguments) == 0 "no value arguments are allowed for a an Expr FunctionCall, but got '$value_arguments'" + return FunctionCall( + func, [value_arguments], [arguments], [return_symbol], return_types, device + ) +end """ Tape{INPUT} -TODO: update docs -- `INPUT` the input type of the problem instance +Lowered representation of a computation, generated from a [`DAG`](@ref) through [`gen_tape`](@ref). 
+ +- `INPUT` the input type of the problem instance, see also the interface function [`input_type`](@ref) -- `code::Vector{Expr}`: The julia expression containing the code for the whole graph. -- `output_symbol::Symbol`: The symbol of the final calculated value +## Fields +- `input_assign_code::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the input assignments, + mapping part of the input of the computation to each DAG entry node. These functions are generated using + the interface function [`input_expr`](@ref). +- `schedule::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the function body of the computation. + There is one function call for each node in the [`DAG`](@ref). +- `output_symbol::Symbol`: The symbol of the final calculated value, which is returned. +- `instance::Any`: The instance that this tape is generated for. +- `machine::Machine`: The [`Machine`](@ref) that this tape is generated for. """ struct Tape{INPUT} input_assign_code::Vector{FunctionCall} diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl index 19eb56c..8207406 100644 --- a/src/code_gen/utils.jl +++ b/src/code_gen/utils.jl @@ -1,23 +1,22 @@ -""" - infer_types!(schedule::Vector{FunctionCall}) +function Base.length(fc::FunctionCall) + @assert length(fc.value_arguments) == length(fc.arguments) == length(fc.return_symbols) "function call length is undefined, got '$(length(fc.value_arguments))' tuples of value arguments, '$(length(fc.arguments))' tuples of arguments, and '$(length(return_symbols))' return symbols" + return length(fc.value_arguments) +end -Infer the result type of each function call in the given schedule. Returns a dictionary with the result type for each [`Node`](@ref). This assumes that each node has only one statically inferrable return type and will throw an exceptin otherwise. -This also assumes that the given `Vector` contains a topological ordering of its nodes, such as returned by a call to [`schedule_dag`](@ref). 
+""" + infer_types!(tape::Tape, context_module::Module) -Also returns the inferred types as a `Dict{Symbol, Type}`. +Infer the result type of each function call in the given tape. Returns a dictionary with the result type for each symbol and sets each function call's return_types. +This function assumes that each [`FunctionCall`](@ref) has only one statically inferrable return type and will throw an exception otherwise. """ -function infer_types!(tape::Tape) +function infer_types!(tape::Tape, context_module::Module) known_result_types = Dict{Symbol,Type}() # the only initially known type known_result_types[:input] = input_type(tape.instance) for fc in tape.input_assign_code - if typeof(fc.func) isa Expr - continue - end - # for input assign code, the return types are set on construction - res_types = fc.return_types + res_types = result_types(fc, known_result_types, context_module) for (s, t) in Iterators.zip( Iterators.flatten(fc.return_symbols), Iterators.cycle(res_types, length(fc.return_symbols)), @@ -27,7 +26,7 @@ function infer_types!(tape::Tape) end for fc in tape.schedule - res_types = result_types(fc, known_result_types) + res_types = result_types(fc, known_result_types, context_module) fc.return_types = res_types for (s, t) in Iterators.zip( Iterators.flatten(fc.return_symbols), diff --git a/src/models/interface.jl b/src/models/interface.jl index 1672e40..81bf46c 100644 --- a/src/models/interface.jl +++ b/src/models/interface.jl @@ -1,44 +1,26 @@ - -""" - AbstractModel - -Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed. - -See also: [`problem_instance`](@ref) """ -abstract type AbstractModel end + input_type(problem_instance) -""" - problem_instance(::AbstractModel, ::Vararg) - -Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters. 
-""" -function problem_instance end +Return the input type for a specific `problem_instance`. This can be a specific type or a supertype for which all child types are expected to be implemented. -""" - AbstractProblemInstance - -Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model. -""" -abstract type AbstractProblemInstance end - -""" - input_type(problem::AbstractProblemInstance) - -Return the input type for a specific [`AbstractProblemInstance`](@ref). This can be a specific type or a supertype for which all child types are expected to work. +For more details on the `problem_instance`, please refer to the documentation. """ function input_type end """ - graph(::AbstractProblemInstance) + graph(problem_instance) + +Generate the [`DAG`](@ref) for the given `problem_instance`. Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names. -Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names. +For more details on the `problem_instance`, please refer to the documentation. """ function graph end """ - input_expr(instance::AbstractProblemInstance, name::String, input_symbol::Symbol) + input_expr(problem_instance, name::String, input_symbol::Symbol) + +For the given `problem_instance`, the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol. 
-For the given [`AbstractProblemInstance`](@ref), the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol. +For more details on the `problem_instance`, please refer to the documentation. """ function input_expr end diff --git a/src/properties/type.jl b/src/properties/type.jl index 7f7e001..afa9f19 100644 --- a/src/properties/type.jl +++ b/src/properties/type.jl @@ -3,12 +3,12 @@ Representation of a [`DAG`](@ref)'s properties. -# Fields: -`.data`: The total data transfer.\\ -`.compute_effort`: The total compute effort.\\ -`.compute_intensity`: The compute intensity, will always equal `.compute_effort / .data`.\\ -`.number_of_nodes`: Number of [`Node`](@ref)s.\\ -`.number_of_edges`: Number of [`Edge`](@ref)s. +## Fields: +- `data::Float64`: The total data transfer. +- `compute_effort::Float64`: The total compute effort. +- `compute_intensity::Float64`: The compute intensity, will always equal `compute_effort / data`. +- `number_of_nodes::Int`: Number of [`Node`](@ref)s. +- `number_of_edges::Int`: Number of [`Edge`](@ref)s. """ const GraphProperties = NamedTuple{ (:data, :compute_effort, :compute_intensity, :number_of_nodes, :number_of_edges), diff --git a/src/scheduler/interface.jl b/src/scheduler/interface.jl index 1dcbfbf..b452a36 100644 --- a/src/scheduler/interface.jl +++ b/src/scheduler/interface.jl @@ -1,11 +1,3 @@ - -""" - AbstractScheduler - -Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks. -""" -abstract type AbstractScheduler end - """ schedule_dag(::Scheduler, ::DAG, ::Machine) diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl index 28452a8..7eb4062 100644 --- a/src/scheduler/type.jl +++ b/src/scheduler/type.jl @@ -1,35 +1,6 @@ """ - FunctionCall{VAL_TYPES} + AbstractScheduler -Type representing a function call. 
Contains the function to call, argument symbols, the return symbol and the device to execute on. - -TODO: extend docs +Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks. """ -mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}} - func::FUNC_T - value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments - arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call - return_symbols::Vector{Vector{Symbol}} # the return symbols - return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability - device::AbstractDevice -end - -function FunctionCall( - func::Union{Function,Expr}, - value_arguments::VAL_T, - arguments::Vector{Symbol}, - return_symbol::Vector{Symbol}, - return_types::Vector{<:Type}, - device::AbstractDevice, -) where {VAL_T<:Tuple} - # convenience constructor for function calls that do not use vectorization, which is most of the use cases - @assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types $(length(return_types)) does not match the number of return symbols $(length(return_symbol))" - return FunctionCall( - func, [value_arguments], [arguments], [return_symbol], return_types, device - ) -end - -function Base.length(fc::FunctionCall) - @assert length(fc.value_arguments) == length(fc.arguments) == length(fc.return_symbols) "function call length is undefined, got $(length(fc.value_arguments)) tuples of value arguments, $(length(fc.arguments)) tuples of arguments, and $(length(return_symbols)) return symbols" - return length(fc.value_arguments) -end +abstract type AbstractScheduler end diff --git a/src/task/compute.jl b/src/task/compute.jl index c2e9fdd..b76e238 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -62,14 +62,8 @@ 
function _argument_types(known_res_types::Dict{Symbol,Type}, fc::FunctionCall) return getindex.(Ref(known_res_types), fc.arguments[1]) end -function result_types( - fc::FunctionCall{VAL_T,F_T}, known_res_types::Dict{Symbol,Type} -) where {VAL_T,F_T<:Function} - arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) - types = Base.return_types(fc.func, arg_types) - +function _validate_result_types(fc::FunctionCall, types, arg_types) N_RET = length(fc.return_types) - if length(types) > 1 throw( "failure during type inference: function call $(fc.func) with argument types $(arg_types) is type unstable, possible return types: $types", @@ -85,21 +79,60 @@ function result_types( end if (N_RET == 1) - return [types[1]] + return nothing end if !(types[1] isa Tuple) || length(types[1].parameters) != N_RET throw( - "failure durng type inference: function call $(fc.func) was expected to return a Tuple with $N_RET elements, but returns $(types[1])", + "failure during type inference: function call $(fc.func) was expected to return a Tuple with $N_RET elements, but returns $(types[1])", ) end + return nothing +end + +function result_types( + fc::FunctionCall{VAL_T,F_T}, known_res_types::Dict{Symbol,Type}, context_module::Module +) where {VAL_T,F_T<:Function} + arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) + types = Base.return_types(fc.func, arg_types) + + _validate_result_types(fc, types, arg_types) + + N_RET = length(fc.return_types) + if (N_RET == 1) + return [types[1]] + end + return [types[1].parameters...] 
end function result_types( - fc::FunctionCall{VAL_T,Expr}, known_res_types::Dict{Symbol,Type} + fc::FunctionCall{VAL_T,Expr}, known_res_types::Dict{Symbol,Type}, context_module::Module ) where {VAL_T} - # assume that the return type is already set - @assert length(fc.return_types) == 1 - return [fc.return_types[1]] + arg_types = _argument_types(known_res_types, fc) + ret_expr = Expr( + :call, + Base.return_types, # return types call + Expr( # function argument to return_types + :->, # anonymous function + Expr( + :tuple, # anonymous function arguments + fc.arguments[1]..., + ), + fc.func, # anonymous function code block + ), + Expr(:tuple, arg_types...), # types arguments to return_types + ) + types = context_module.eval(ret_expr) + + #@info "evaluation of expression\n$ret_expr\ngives\n$types" + + _validate_result_types(fc, types, arg_types) + + N_RET = length(fc.return_types) + if (N_RET == 1) + return [types[1]] + end + + return [types[1].parameters...] end From 9290549572f8971ece3b84f57a0a5666fb2839dc Mon Sep 17 00:00:00 2001 From: AntonReinhard Date: Fri, 29 Nov 2024 15:21:53 +0100 Subject: [PATCH 11/12] Work on performance problems with closures (again...) --- src/code_gen/function.jl | 5 ++--- src/code_gen/tape_machine.jl | 41 ++++++++++++++++++++---------------- src/code_gen/utils.jl | 1 + src/devices/impl.jl | 25 ++++++++++++++++------ 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index c73c99c..6a957e8 100644 --- a/src/code_gen/function.jl +++ b/src/code_gen/function.jl @@ -27,13 +27,12 @@ function get_compute_function( ) tape = gen_tape(graph, instance, machine, context_module) - assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...) code = gen_function_body(tape, context_module; closures_size=closures_size) + assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...) 
function_id = to_var_name(UUIDs.uuid1(rng[1])) res_sym = tape.output_symbol - expr = # - Expr( + expr = Expr( :function, # function definition Expr( :call, Symbol("compute_$function_id"), Expr(:(::), :input, input_type(instance)) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index 767dc9b..4609dfb 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -12,23 +12,36 @@ end function expr_from_fc(fc::FunctionCall{VAL_T,Expr}) where {VAL_T} @assert length(fc) == 1 && isempty(fc.value_arguments[1]) "function call assigning an expression cannot be vectorized and cannot contain value arguments\n$fc" - fc_expr_in_let = Expr( - :let, - Expr(:block, fc.return_symbols[1]...), - fc.func, # anonymous function code block + fc_expr = Expr( + :block, + gen_local_init(fc), + fc.func, # anonymous function code block + Expr( # return the symbols + :return, + ( + if length(fc.return_symbols[1]) == 1 + fc.return_symbols[1][1] + else + Expr(:tuple, fc.return_symbols[1]...) 
+ end + ), + ), ) func_call = Expr( :call, # call + #wrap_in_let_statement( # wrap in let statement to prevent boxing of local variables Expr( :->, # anonymous function Expr( :tuple, # anonymous function arguments - fc.arguments[1]..., + #fc.arguments[1]..., ), - fc_expr_in_let, + fc_expr, # anonymous function code block ), - fc.arguments[1]..., # runtime arguments passed to the anonymous function + #fc.arguments[1], + #), + #fc.arguments[1]..., # runtime arguments passed to the anonymous function call ) access_expr = gen_access_expr(fc) @@ -55,7 +68,7 @@ function gen_input_assignment_code( device = entry_device(machine) fc = FunctionCall( - input_expr(instance, name, :input), + Expr(:(=), symbol, input_expr(instance, name, :input)), (), Symbol[:input], Symbol[symbol], @@ -110,7 +123,7 @@ function _gen_function_body( end # iterate from end to beginning - # this helps because we can collect all undefined arguments to the closures that have to be returned somewhere earlier + # this helps because we can collect all undefined arguments to the closures that have to be defined somewhere earlier undefined_argument_symbols = Set{Symbol}() # the final return symbol is the return of the entire generated function, it always has to be returned push!(undefined_argument_symbols, gen_access_expr(fc_vec[end])) @@ -183,15 +196,7 @@ function _closure_fc( fc_expr = Expr( # actual function body of the closure :block, - expr_from_fc.(code_block)..., - Expr( - :return, # have to make sure to not return a tuple of length 1 - if length(ret_symbols_t) == 1 - ret_symbols_t[1] - else - Expr(:tuple, ret_symbols_t...) 
- end, - ), + expr_from_fc.(code_block)..., # no return statement necessary, will be done via capture and local init ) fc = FunctionCall( diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl index 8207406..e477c68 100644 --- a/src/code_gen/utils.jl +++ b/src/code_gen/utils.jl @@ -17,6 +17,7 @@ function infer_types!(tape::Tape, context_module::Module) for fc in tape.input_assign_code res_types = result_types(fc, known_result_types, context_module) + fc.return_types = res_types for (s, t) in Iterators.zip( Iterators.flatten(fc.return_symbols), Iterators.cycle(res_types, length(fc.return_symbols)), diff --git a/src/devices/impl.jl b/src/devices/impl.jl index 172160f..7392074 100644 --- a/src/devices/impl.jl +++ b/src/devices/impl.jl @@ -63,9 +63,6 @@ end gen_local_init(fc::FunctionCall) Dispatch from the given [`FunctionCall`](@ref) to the lower-level function [`_gen_local_init`](@ref). - -!!! note - This is currently unused, but may become useful in the future again. """ function gen_local_init(fc::FunctionCall) return Expr( @@ -82,10 +79,26 @@ end Return an `Expr` that initializes the symbol in the local scope. The result looks like `local ::`. - -!!! note - This is currently unused, but may become useful in the future again. """ function _gen_local_init(symbol::Symbol, type::Type) return Expr(:local, symbol, :(::), Symbol(type)) end + +""" + wrap_in_let_statement(expr, symbols) + +For a given expression and a collection of symbols, generate a let statement that wraps the expression in a let statement with all the symbols, like +`let =, ..., = end` +""" +@inline function wrap_in_let_statement(expr, symbols) + return Expr(:let, Expr(:block, _gen_let_statement.(symbols)...), expr) +end + +""" + _gen_let_statement(symbol::Symbol) + +Return a let-`Expr` like ` = `. 
+""" +function _gen_let_statement(symbol::Symbol) + return Expr(:(=), symbol, symbol) +end From f895f8d869ff1774171ec7843f7a4b2b2d6e6e4d Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Wed, 4 Dec 2024 15:47:28 +0100 Subject: [PATCH 12/12] Fix NodeSplit --- src/graph/mute.jl | 20 ++++++++++---------- src/graph/validate.jl | 4 ++-- src/node/type.jl | 16 ++++++++-------- src/node/validate.jl | 8 ++++---- src/operation/apply.jl | 9 +++++---- src/operation/clean.jl | 12 ++++++------ src/operation/find.jl | 6 +++--- 7 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/graph/mute.jl b/src/graph/mute.jl index 4dbe93d..2018274 100644 --- a/src/graph/mute.jl +++ b/src/graph/mute.jl @@ -238,7 +238,7 @@ function invalidate_caches!(graph::DAG, operation::NodeReduction) delete!(graph.possible_operations, operation) for node in operation.input - node.nodeReduction = missing + node.node_reduction = missing end return nothing @@ -256,7 +256,7 @@ function invalidate_caches!(graph::DAG, operation::NodeSplit) # delete the operation from all caches of nodes involved in the operation # for node split there is only one node - operation.input.nodeSplit = missing + operation.input.node_split = missing return nothing end @@ -267,11 +267,11 @@ end Invalidate the operation caches of the given node through calls to the respective [`invalidate_caches!`](@ref) functions. """ function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode) - if !ismissing(node.nodeReduction) - invalidate_caches!(graph, node.nodeReduction) + if !ismissing(node.node_reduction) + invalidate_caches!(graph, node.node_reduction) end - if !ismissing(node.nodeSplit) - invalidate_caches!(graph, node.nodeSplit) + if !ismissing(node.node_split) + invalidate_caches!(graph, node.node_split) end return nothing end @@ -282,11 +282,11 @@ end Invalidate the operation caches of the given node through calls to the respective [`invalidate_caches!`](@ref) functions. 
""" function invalidate_operation_caches!(graph::DAG, node::DataTaskNode) - if !ismissing(node.nodeReduction) - invalidate_caches!(graph, node.nodeReduction) + if !ismissing(node.node_reduction) + invalidate_caches!(graph, node.node_reduction) end - if !ismissing(node.nodeSplit) - invalidate_caches!(graph, node.nodeSplit) + if !ismissing(node.node_split) + invalidate_caches!(graph, node.node_split) end return nothing end diff --git a/src/graph/validate.jl b/src/graph/validate.jl index 9caef4b..c32a138 100644 --- a/src/graph/validate.jl +++ b/src/graph/validate.jl @@ -43,8 +43,8 @@ function is_valid(graph::DAG) for node in graph.dirty_nodes @assert node in graph "Dirty Node is not part of the graph!" - @assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!" - @assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!" + @assert ismissing(node.node_reduction) "Dirty Node has a NodeReduction!" + @assert ismissing(node.node_split) "Dirty Node has a NodeSplit!" end @assert is_connected(graph) "Graph is not connected!" diff --git a/src/node/type.jl b/src/node/type.jl index fa0b400..8e608dd 100644 --- a/src/node/type.jl +++ b/src/node/type.jl @@ -27,8 +27,8 @@ Any node that transfers data and does no computation. `.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\ `.children`: A vector of tuples of the node's children (i.e. nodes that this one depends on) and their indices, indicating their order in the resulting function call passed to the task.\\ `.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\ -`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ -`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ +`.node_reduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. 
There can only be at most one.\\ +`.node_split`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ `.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\ """ mutable struct DataTaskNode{TaskType<:AbstractDataTask} <: Node @@ -44,10 +44,10 @@ mutable struct DataTaskNode{TaskType<:AbstractDataTask} <: Node # the NodeReduction involving this node, if it exists # Can't use the NodeReduction type here because it's not yet defined - nodeReduction::Union{Operation,Missing} + node_reduction::Union{Operation,Missing} # the NodeSplit involving this node, if it exists - nodeSplit::Union{Operation,Missing} + node_split::Union{Operation,Missing} # for input nodes we need a name for the node to distinguish between them name::String @@ -63,8 +63,8 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re `.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\ `.children`: A vector of tuples with the node's children (i.e. nodes that this one depends on) and their index, used to order the arguments for the [`AbstractComputeTask`](@ref).\\ `.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\ -`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ -`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ +`.node_reduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ +`.node_split`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ `.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref). 
""" mutable struct ComputeTaskNode{TaskType<:AbstractComputeTask} <: Node @@ -73,8 +73,8 @@ mutable struct ComputeTaskNode{TaskType<:AbstractComputeTask} <: Node children::Vector{Tuple{Node,Int}} id::Base.UUID - nodeReduction::Union{Operation,Missing} - nodeSplit::Union{Operation,Missing} + node_reduction::Union{Operation,Missing} + node_split::Union{Operation,Missing} # the device this node is assigned to execute on device::Union{AbstractDevice,Missing} diff --git a/src/node/validate.jl b/src/node/validate.jl index 04bf03f..8676f0b 100644 --- a/src/node/validate.jl +++ b/src/node/validate.jl @@ -22,11 +22,11 @@ function is_valid_node(graph::DAG, node::Node) @assert node in child.parents "Node is not a parent of its child!" end - #=if !ismissing(node.nodeReduction) - @assert is_valid(graph, node.nodeReduction) + #=if !ismissing(node.node_reduction) + @assert is_valid(graph, node.node_reduction) end - if !ismissing(node.nodeSplit) - @assert is_valid(graph, node.nodeSplit) + if !ismissing(node.node_split) + @assert is_valid(graph, node.node_split) end=# return true diff --git a/src/operation/apply.jl b/src/operation/apply.jl index 99dce26..2a17d19 100644 --- a/src/operation/apply.jl +++ b/src/operation/apply.jl @@ -182,13 +182,14 @@ function node_split!( get_snapshot_diff(graph) n1_parents = copy(parents(n1)) + local parent_indices = Dict() n1_children = copy(children(n1)) for parent in n1_parents - _remove_edge!(graph, n1, parent) + parent_indices[parent] = _remove_edge!(graph, n1, parent) end for (child, index) in n1_children - _remove_edge!(graph, child, n1) + @assert index == _remove_edge!(graph, child, n1) end _remove_node!(graph, n1) @@ -196,10 +197,10 @@ function node_split!( n_copy = copy(n1) _insert_node!(graph, n_copy) - _insert_edge!(graph, n_copy, parent) + _insert_edge!(graph, n_copy, parent, parent_indices[parent]) for (child, index) in n1_children - _insert_edge!(graph, child, n_copy) + _insert_edge!(graph, child, n_copy, index) end end diff --git 
a/src/operation/clean.jl b/src/operation/clean.jl index 9e0b997..8c415d8 100644 --- a/src/operation/clean.jl +++ b/src/operation/clean.jl @@ -7,7 +7,7 @@ Find node reductions involving the given node. The function pushes the found [`N """ function find_reductions!(graph::DAG, node::Node) # there can only be one reduction per node, avoid adding duplicates - if !ismissing(node.nodeReduction) + if !ismissing(node.node_reduction) return nothing end @@ -32,12 +32,12 @@ function find_reductions!(graph::DAG, node::Node) nr = NodeReduction(reductionVector) push!(graph.possible_operations.node_reductions, nr) for node in reductionVector - if !ismissing(node.nodeReduction) + if !ismissing(node.node_reduction) # it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now # this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possible_operations - invalidate_caches!(graph, node.nodeReduction) + invalidate_caches!(graph, node.node_reduction) end - node.nodeReduction = nr + node.node_reduction = nr end end @@ -50,14 +50,14 @@ end Find the node split of the given node. The function pushes the found [`NodeSplit`](@ref) (if any) everywhere it needs to be and returns nothing. """ function find_splits!(graph::DAG, node::Node) - if !ismissing(node.nodeSplit) + if !ismissing(node.node_split) return nothing end if (can_split(node)) ns = NodeSplit(node) push!(graph.possible_operations.node_splits, ns) - node.nodeSplit = ns + node.node_split = ns end return nothing diff --git a/src/operation/find.jl b/src/operation/find.jl index 20e9a2c..977a86b 100644 --- a/src/operation/find.jl +++ b/src/operation/find.jl @@ -9,7 +9,7 @@ Insert the given node reduction into its input nodes' operation caches. 
This is """ function insert_operation!(nr::NodeReduction) for n in nr.input - n.nodeReduction = nr + n.node_reduction = nr end return nothing end @@ -20,7 +20,7 @@ end Insert the given node split into its input node's operation cache. This is thread-safe. """ function insert_operation!(ns::NodeSplit) - ns.input.nodeSplit = ns + ns.input.node_split = ns return nothing end @@ -127,7 +127,7 @@ function generate_operations(graph::DAG) node_reductions = collect(trie) for nrVec in node_reductions - # parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is uniquely identifiable by its first element + # parent sets are ordered and any node can only be part of one node_reduction, so a NodeReduction is uniquely identifiable by its first element # this prevents duplicate node_reductions being generated lock(checkedNodesLock) if (nrVec[1] in checkedNodes)