diff --git a/Project.toml b/Project.toml index d488855..b2dd12c 100644 --- a/Project.toml +++ b/Project.toml @@ -22,14 +22,14 @@ CUDAExt = "CUDA" oneAPIExt = "oneAPI" [compat] -julia = "1.10" AMDGPU = "1" CUDA = "5" DataStructures = "0.18" NumaAllocators = "0.2" -oneAPI = "1" RuntimeGeneratedFunctions = "0.5" StaticArrays = "1" +julia = "1.10" +oneAPI = "1" [extras] SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/ext/devices/cuda/function.jl b/ext/devices/cuda/function.jl index e7c392e..7f0f889 100644 --- a/ext/devices/cuda/function.jl +++ b/ext/devices/cuda/function.jl @@ -7,7 +7,7 @@ function ComputableDAGs.kernel( init_caches = Expr(:block, tape.initCachesCode...) assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...) # TODO: use gen_function_body here - code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...) + code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...) function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) res_sym = eval( diff --git a/ext/devices/rocm/function.jl b/ext/devices/rocm/function.jl index ae8572b..64cfcfe 100644 --- a/ext/devices/rocm/function.jl +++ b/ext/devices/rocm/function.jl @@ -8,7 +8,7 @@ function ComputableDAGs.kernel( assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...) # TODO use gen_function_body here - code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...) + code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...) 
function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) res_sym = eval( diff --git a/src/ComputableDAGs.jl b/src/ComputableDAGs.jl index e9597ba..291fc4f 100644 --- a/src/ComputableDAGs.jl +++ b/src/ComputableDAGs.jl @@ -130,6 +130,7 @@ include("scheduler/interface.jl") include("scheduler/greedy.jl") include("code_gen/type.jl") +include("code_gen/utils.jl") include("code_gen/tape_machine.jl") include("code_gen/function.jl") diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index c87a8f3..b33ee1b 100644 --- a/src/code_gen/function.jl +++ b/src/code_gen/function.jl @@ -26,7 +26,7 @@ function get_compute_function( initCaches = Expr(:block, tape.initCachesCode...) assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...) - code = gen_function_body(tape.computeCode; closures_size=closures_size) + code = gen_function_body(tape; closures_size=closures_size) functionId = to_var_name(UUIDs.uuid1(rng[1])) resSym = eval( diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index c01a45a..8079903 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -121,14 +121,11 @@ function gen_input_assignment_code( device = entry_device(machine) fc = FunctionCall( - RuntimeGeneratedFunction( - @__MODULE__, - context_module, - Expr(:->, :x, input_expr(instance, name, :x)), - ), + context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))), SVector{0,Any}(), SVector{1,Symbol}(:input), symbol, + Nothing, device, ) @@ -140,40 +137,95 @@ function gen_input_assignment_code( end """ - gen_function_body(fc_vec::Vector{FunctionCall}; closures_size) + gen_function_body(tape::Tape; closures_size) -Generate the function body from the given `Vector` of [`FunctionCall`](@ref)s. +Generate the function body from the given [`Tape`](@ref). ## Keyword Arguments `closures_size`: The size of closures to generate (in lines of code). 
Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely. """ -function gen_function_body(fc_vec::Vector{FunctionCall}; closures_size::Int) +function gen_function_body(tape::Tape; closures_size::Int) + if closures_size > 1 + # only need to annotate types later when using closures + infer_types!(tape) + end + + fc_vec = tape.schedule + if (closures_size <= 1) return Expr(:block, expr_from_fc.(fc_vec)...) end closures = Vector{Expr}() - for i in 1:closures_size:length(fc_vec) - code_block = fc_vec[i:min(i + closures_size, length(fc_vec))] + # iterate from end to beginning + # this helps because we can collect all undefined arguments to the closures that have to be returned somewhere earlier + undefined_argument_symbols = Set{Symbol}() + # the final return symbol is the return of the entire generated function, it always has to be returned + push!(undefined_argument_symbols, eval(gen_access_expr(fc_vec[end]))) + + for i in length(fc_vec):(-closures_size):1 + e = i + b = max(i - closures_size + 1, 1) # first call of this block; consecutive blocks must not overlap + code_block = fc_vec[b:e] # collect `local var` statements that need to exist before the closure starts - # since the return symbols are always unique, this has to happen for each fc and there will be no duplicates local_inits = gen_local_init.(code_block) - closure = Expr( # call to the following closure (no arguments) - :call, - Expr( # create the closure: () -> code block; return nothing - :->, - :(), - Expr(# # actual function body of the closure - :block, - expr_from_fc.(code_block)..., - Expr(:return, :nothing), + return_symbols = eval.(gen_access_expr.(code_block)) + argument_symbols = Set{Symbol}() + + ret_symbols_set = Set(return_symbols) + for fc in code_block + for arg in fc.arguments + symbol = eval(_gen_access_expr(fc.device, fc.device.cacheStrategy, arg)) + + # symbol won't be defined if it is first calculated
in the closure + # so don't add it to the arguments in this case + if !(symbol in ret_symbols_set) + push!(argument_symbols, symbol) + end + end + end + union!(undefined_argument_symbols, argument_symbols) + + intersect!(ret_symbols_set, undefined_argument_symbols) + return_symbols = Symbol[ret_symbols_set...] + + argument_symbols = [argument_symbols...] # make sure there is an order (doesn't matter which) + + closure = Expr( + :block, + Expr( + :(=), + Expr(:tuple, return_symbols...), + Expr( + :call, # call to the following closure, passing the argument symbols listed below + Expr( # create the closure: (args) -> code block; return (locals) + :->, + Expr(:tuple, argument_symbols...), # arguments in the closure definition + Expr( # actual function body of the closure + :block, + local_inits..., # declare local variables with type information inside the closure + expr_from_fc.(code_block)..., + Expr(:return, Expr(:tuple, return_symbols...)), + ), + ), + argument_symbols..., # arguments to the closure call + ), + ), + ) + + setdiff!(undefined_argument_symbols, ret_symbols_set) + + #=Expr( + :macrocall, + Symbol("@closure"), + @__LINE__, + Expr( ) + )=# + # combine to one closure call, including all the local inits and the actual call to the closure - push!(closures, Expr(:block, local_inits..., closure)) + pushfirst!(closures, closure) end return Expr(:block, closures...)
@@ -200,25 +252,33 @@ function gen_tape( scheduler::AbstractScheduler=GreedyScheduler(), ) schedule = schedule_dag(scheduler, graph, machine) + function_body = lower(schedule, machine) - # get inSymbols - inputSyms = Dict{String,Vector{Symbol}}() + # get input symbols + input_syms = Dict{String,Vector{Symbol}}() for node in get_entry_nodes(graph) - if !haskey(inputSyms, node.name) - inputSyms[node.name] = Vector{Symbol}() + if !haskey(input_syms, node.name) + input_syms[node.name] = Vector{Symbol}() end - push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in")) + push!(input_syms[node.name], Symbol("$(to_var_name(node.id))_in")) end # get outSymbol outSym = Symbol(to_var_name(get_exit_node(graph).id)) - initCaches = gen_cache_init_code(machine) - assign_inputs = gen_input_assignment_code(inputSyms, instance, machine, context_module) + init_caches = gen_cache_init_code(machine) + assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module) return Tape{input_type(instance)}( - initCaches, assign_inputs, schedule, inputSyms, outSym, Dict(), instance, machine + init_caches, + assign_inputs, + function_body, + input_syms, + outSym, + Dict(), + instance, + machine, ) end @@ -228,6 +288,9 @@ end Execute the given tape with the given input. For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary. + +!!! warning + This is very slow and might not work. This is to be majorly revamped.
""" function execute_tape(tape::Tape, input) cache = Dict{Symbol,Any}() @@ -238,10 +301,12 @@ function execute_tape(tape::Tape, input) @eval $expr end + compute_code = tape.schedule + for function_call in tape.inputAssignCode call_fc(function_call, cache) end - for function_call in tape.computeCode + for function_call in compute_code call_fc(function_call, cache) end diff --git a/src/code_gen/type.jl b/src/code_gen/type.jl index 1b81166..06405b9 100644 --- a/src/code_gen/type.jl +++ b/src/code_gen/type.jl @@ -12,7 +12,7 @@ TODO: update docs struct Tape{INPUT} initCachesCode::Vector{Expr} inputAssignCode::Vector{FunctionCall} - computeCode::Vector{FunctionCall} + schedule::Vector{FunctionCall} inputSymbols::Dict{String,Vector{Symbol}} outputSymbol::Symbol cache::Dict{Symbol,Any} diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl new file mode 100644 index 0000000..8379b16 --- /dev/null +++ b/src/code_gen/utils.jl @@ -0,0 +1,45 @@ +""" + infer_types!(tape::Tape) + +Infer the result type of each function call in the given [`Tape`](@ref) and store it in the `return_type` field of the respective [`FunctionCall`](@ref). Returns `nothing`. This assumes that each function call has only one statically inferrable return type and will throw an exception otherwise. +This also assumes that the tape's schedule contains a topological ordering of its nodes, such as produced by a call to [`schedule_dag`](@ref) followed by [`lower`](@ref).
+""" +function infer_types!(tape::Tape) + known_result_types = Dict{Symbol,Type}() + + # the only initially known type + known_result_types[:input] = input_type(tape.instance) + + for fc in tape.inputAssignCode + res_type = result_type(fc, known_result_types) + fc.return_type = res_type + known_result_types[fc.return_symbol] = res_type + end + + for fc in tape.schedule + res_type = result_type(fc, known_result_types) + fc.return_type = res_type + known_result_types[fc.return_symbol] = res_type + end + + return nothing +end + +""" + lower(schedule::Vector{Node}, machine::Machine) + +After [`schedule_dag`](@ref) has made a schedule of nodes, this function lowers the vector of [`Node`](@ref)s into a vector of [`FunctionCall`](@ref)s. +""" +function lower(schedule::Vector{Node}, machine::Machine) + calls = Vector{FunctionCall}() + + for node in schedule + if (node isa DataTaskNode && length(children(node)) == 0) + push!(calls, get_init_function_call(node, entry_device(machine))) + else + push!(calls, get_function_call(node)...) + end + end + + return calls +end diff --git a/src/devices/numa/impl.jl b/src/devices/numa/impl.jl index 96a6504..0c61b75 100644 --- a/src/devices/numa/impl.jl +++ b/src/devices/numa/impl.jl @@ -94,7 +94,7 @@ Interface implementation, dispatched to from [`gen_local_init`](@ref).
""" function _gen_local_init(fc::FunctionCall, ::NumaNode, ::LocalVariables) s = Symbol("data_$(fc.return_symbol)") - quote_node = Expr(:local, s) # TODO: figure out how to get type info for this local variable + quote_node = Expr(:local, Expr(:(::), s, fc.return_type)) # `local s::T`, with T the return type stored on the function call return quote_node end diff --git a/src/node/create.jl b/src/node/create.jl index ea255b8..6331d39 100644 --- a/src/node/create.jl +++ b/src/node/create.jl @@ -1,4 +1,3 @@ - function DataTaskNode(t::AbstractDataTask, name="") return DataTaskNode( t, diff --git a/src/scheduler/greedy.jl b/src/scheduler/greedy.jl index d00cd99..63795f1 100644 --- a/src/scheduler/greedy.jl +++ b/src/scheduler/greedy.jl @@ -7,46 +7,42 @@ A greedy implementation of a scheduler, creating a topological ordering of nodes struct GreedyScheduler <: AbstractScheduler end function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine) - nodeQueue = PriorityQueue{Node,Int}() + node_queue = PriorityQueue{Node,Int}() # use a priority equal to the number of unseen children -> 0 are nodes that can be added for node in get_entry_nodes(graph) - enqueue!(nodeQueue, node => 0) + enqueue!(node_queue, node => 0) end - schedule = Vector{FunctionCall}() + schedule = Vector{Node}() sizehint!(schedule, length(graph.nodes)) # keep an accumulated cost of things scheduled to this device so far - deviceAccCost = PriorityQueue{AbstractDevice,Float64}() + device_acc_cost = PriorityQueue{AbstractDevice,Float64}() for device in machine.devices - enqueue!(deviceAccCost, device => 0) + enqueue!(device_acc_cost, device => 0) end - node = nothing - while !isempty(nodeQueue) - @assert peek(nodeQueue)[2] == 0 - node = dequeue!(nodeQueue) + local node + while !isempty(node_queue) + @assert peek(node_queue)[2] == 0 + node = dequeue!(node_queue) # assign the device with lowest accumulated cost to the node (if it's a compute node) if (isa(node, ComputeTaskNode)) - lowestDevice = peek(deviceAccCost)[1] - node.device = lowestDevice - deviceAccCost[lowestDevice] = compute_effort(task(node)) + lowest_device =
peek(deviceAccCost)[1] - node.device = lowestDevice - deviceAccCost[lowestDevice] = compute_effort(task(node)) + lowest_device = peek(device_acc_cost)[1] + node.device = lowest_device + device_acc_cost[lowest_device] = compute_effort(task(node)) end - if (node isa DataTaskNode && length(children(node)) == 0) - push!(schedule, get_init_function_call(node, entry_device(machine))) - else - push!(schedule, get_function_call(node)...) - end + push!(schedule, node) for parent in parents(node) # reduce the priority of all parents by one - if (!haskey(nodeQueue, parent)) - enqueue!(nodeQueue, parent => length(children(parent)) - 1) + if (!haskey(node_queue, parent)) + enqueue!(node_queue, parent => length(children(parent)) - 1) else - nodeQueue[parent] = nodeQueue[parent] - 1 + node_queue[parent] = node_queue[parent] - 1 end end end diff --git a/src/scheduler/interface.jl b/src/scheduler/interface.jl index b420788..1dcbfbf 100644 --- a/src/scheduler/interface.jl +++ b/src/scheduler/interface.jl @@ -15,6 +15,6 @@ The function assigns each [`ComputeTaskNode`](@ref) of the [`DAG`](@ref) to one [`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`AbstractDevice`](@ref) of their child to all [`AbstractDevice`](@ref)s of its parents. -Return a `Vector{FunctionCall}`. See [`FunctionCall`](@ref) +The produced schedule can be converted to [`FunctionCall`](@ref)s using [`lower`](@ref). """ function schedule_dag end diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl index 0f76d07..008f677 100644 --- a/src/scheduler/type.jl +++ b/src/scheduler/type.jl @@ -5,11 +5,12 @@ using StaticArrays Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on. 
""" -struct FunctionCall{VectorType<:AbstractVector,N} +mutable struct FunctionCall{VectorType<:AbstractVector,N} func::Function # TODO: this should be a tuple value_arguments::SVector{N,Any} # value arguments for the function call, will be prepended to the other arguments - arguments::VectorType # symbols of the inputs to the function call + arguments::VectorType # symbols of the inputs to the function call return_symbol::Symbol + return_type::Type device::AbstractDevice end diff --git a/src/task/compute.jl b/src/task/compute.jl index 45b353f..8593646 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -11,7 +11,7 @@ For ordinary compute or data tasks the vector will contain exactly one element. function get_function_call( t::CompTask, device::AbstractDevice, in_symbols::AbstractVector, out_symbol::Symbol ) where {CompTask<:AbstractComputeTask} - return [FunctionCall(compute, SVector{1,Any}(t), in_symbols, out_symbol, device)] + return [FunctionCall(compute, SVector{1,Any}(t), in_symbols, out_symbol, Any, device)] end function get_function_call(node::ComputeTaskNode) @@ -42,7 +42,7 @@ function get_function_call(node::ComputeTaskNode) end function get_function_call(node::DataTaskNode) - @assert length(children(node)) == 1 "trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1" + @assert length(children(node)) == 1 "trying to call get_function_call on a data task node that has $(length(node.children)) children instead of 1" # TODO: dispatch to device implementations generating the copy commands return [ @@ -51,19 +51,36 @@ function get_function_call(node::DataTaskNode) SVector{0,Any}(), SVector{1,Symbol}(Symbol(to_var_name(first(children(node))[1].id))), Symbol(to_var_name(node.id)), + Any, first(children(node))[1].device, ), ] end function get_init_function_call(node::DataTaskNode, device::AbstractDevice) - @assert isempty(children(node)) "trying to call get_init_expression on a data task node that is 
not an entry node." return FunctionCall( unpack_identity, SVector{0,Any}(), SVector{1,Symbol}(Symbol("$(to_var_name(node.id))_in")), Symbol(to_var_name(node.id)), + Any, device, ) end + +function result_type(fc::FunctionCall, known_res_types::Dict{Symbol,Type}) + argument_types = ( + typeof.(fc.value_arguments)..., getindex.(Ref(known_res_types), fc.arguments)... + ) + types = Base.return_types(fc.func, argument_types) + + if length(types) != 1 # zero or several candidate types: there is no unique result type + error( + "failure during type inference: function call $fc has no unique return type, possible return types: $types", + ) + end + + return types[1] +end