diff --git a/src/ComputableDAGs.jl b/src/ComputableDAGs.jl index 63bc257..c0b8b7c 100644 --- a/src/ComputableDAGs.jl +++ b/src/ComputableDAGs.jl @@ -35,7 +35,6 @@ export get_operations # code generation related export get_compute_function -export gen_tape # estimator export cost_type, graph_cost, operation_effect @@ -48,8 +47,7 @@ export optimize_step!, optimize! export fixpoint_reached, optimize_to_fixpoint! # models -export AbstractModel, AbstractProblemInstance -export problem_instance, input_type, graph, input_expr +export graph, input_type, input_expr # machine info export Machine @@ -67,6 +65,7 @@ include("properties/type.jl") include("operation/type.jl") include("graph/type.jl") include("scheduler/type.jl") +include("code_gen/type.jl") include("trie.jl") include("utils.jl") @@ -125,7 +124,6 @@ include("devices/ext.jl") include("scheduler/interface.jl") include("scheduler/greedy.jl") -include("code_gen/type.jl") include("code_gen/utils.jl") include("code_gen/tape_machine.jl") include("code_gen/function.jl") diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index dc7c795..767dc9b 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -1,4 +1,4 @@ -function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function} +function expr_from_fc(fc::FunctionCall{VAL_T,<:Function}) where {VAL_T} if length(fc) == 1 func_call = Expr(:call, fc.func, fc.value_arguments[1]..., fc.arguments[1]...) 
else @@ -10,14 +10,35 @@ function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function} end function expr_from_fc(fc::FunctionCall{VAL_T,Expr}) where {VAL_T} - @assert length(fc) == 1 && isempty(fc.arguments[1]) && isempty(fc.value_arguments[1]) "function call assigning an expression has an unallowed combination of arguments, which is not allowed\n$fc" - return Expr(:(=), gen_access_expr(fc), fc.func) + @assert length(fc) == 1 && isempty(fc.value_arguments[1]) "function call assigning an expression cannot be vectorized and cannot contain value arguments\n$fc" + + fc_expr_in_let = Expr( + :let, + Expr(:block, fc.return_symbols[1]...), + fc.func, # anonymous function code block + ) + + func_call = Expr( + :call, # call + Expr( + :->, # anonymous function + Expr( + :tuple, # anonymous function arguments + fc.arguments[1]..., + ), + fc_expr_in_let, + ), + fc.arguments[1]..., # runtime arguments passed to the anonymous function + ) + + access_expr = gen_access_expr(fc) + return Expr(:(=), access_expr, func_call) end """ gen_input_assignment_code( input_symbols::Dict{String, Vector{Symbol}}, - instance::AbstractProblemInstance, + instance::Any, machine::Machine, input_type::Type, context_module::Module @@ -26,40 +47,22 @@ end Return a `Vector{Expr}` doing the input assignments from the given `problem_input` onto the `input_symbols`. 
""" function gen_input_assignment_code( - input_symbols::Dict{String,Vector{Symbol}}, - instance, - machine::Machine, - input_type::Type, - context_module::Module, + input_symbols::Dict{String,Vector{Symbol}}, instance, machine::Machine ) assign_inputs = Vector{FunctionCall}() for (name, symbols) in input_symbols for symbol in symbols device = entry_device(machine) - f_id = Symbol(to_var_name(UUIDs.uuid1(rng[threadid()]))) - - fc_setup = FunctionCall( - Expr(:->, :x, input_expr(instance, name, :x)), + fc = FunctionCall( + input_expr(instance, name, :input), (), - Symbol[], - Symbol[f_id], + Symbol[:input], + Symbol[symbol], Type[Nothing], device, ) - fc = FunctionCall( - _call, (), Symbol[f_id, :input], Symbol[symbol], Type[Nothing], device - ) - - ret_expr = Expr( - :call, Base.return_types, fc_setup.func, Expr(:tuple, input_type) - ) - ret_type = context_module.eval(ret_expr) - @assert length(ret_type) == 1 - fc.return_types = [ret_type[1]] - - push!(assign_inputs, fc_setup) push!(assign_inputs, fc) end end @@ -73,13 +76,15 @@ end Generate the function body from the given [`Tape`](@ref). ## Keyword Arguments -`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely. +`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers + in the function body, preventing some optimizations by the compiler and therefore greatly reducing + compile time. A value of 1 or less will disable the use of closures entirely. 
""" function gen_function_body(tape::Tape, context_module::Module; closures_size::Int) # only need to annotate types later when using closures - types = infer_types!(tape) + types = infer_types!(tape, context_module) - if closures_size >= 1 + if closures_size > 1 s = log(closures_size, length(tape.schedule)) closures_depth = ceil(Int, s) # tend towards more levels/smaller closures closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth)) @@ -113,15 +118,14 @@ function _gen_function_body( closured_fc_vec = FunctionCall[] for i in length(fc_vec):(-closures_size):1 e = i - b = max(i - closures_size, 1) + b = max(i - closures_size + 1, 1) code_block = fc_vec[b:e] - pushfirst!( - closured_fc_vec, - _closure_fc( - code_block, type_dict, machine, undefined_argument_symbols, context_module - ), + closure_fc = _closure_fc( + code_block, type_dict, machine, undefined_argument_symbols, context_module ) + + pushfirst!(closured_fc_vec, closure_fc) end return _gen_function_body( @@ -130,9 +134,16 @@ function _gen_function_body( end """ - _closure_fc() + _closure_fc( + code_block::AbstractVector{FunctionCall}, + types::Dict{Symbol,Type}, + machine::Machine, + undefined_argument_symbols::Set{Symbol}, + context_module::Module, + ) -From the given function calls, make and return a new function call representing all of them together. +From the given function calls, make and return 2 function calls representing all of them together. 2 function calls are necessary, one for setting up the anonymous +function and the second for calling it. The undefined_argument_symbols is the set of all Symbols that need to be returned if available inside the code_block. They get updated inside this function. """ function _closure_fc( @@ -168,29 +179,28 @@ function _closure_fc( arg_symbols_t = [arg_symbols_set...] ret_symbols_t = [ret_symbols_set...] 
- closure = context_module.eval( - Expr( # create the closure: () -> code block; return (locals) - :->, - Expr(:tuple, arg_symbols_t...), # closure arguments - Expr( # actual function body of the closure - :block, - expr_from_fc.(code_block)..., - Expr( - :return, # have to make sure to not return a tuple of length 1 - if length(ret_symbols_t) == 1 - ret_symbols_t[1] - else - Expr(:tuple, ret_symbols_t...) - end, - ), - ), + ret_types = (getindex.(Ref(types), ret_symbols_t)) + + fc_expr = Expr( # actual function body of the closure + :block, + expr_from_fc.(code_block)..., + Expr( + :return, # have to make sure to not return a tuple of length 1 + if length(ret_symbols_t) == 1 + ret_symbols_t[1] + else + Expr(:tuple, ret_symbols_t...) + end, ), ) - ret_types = (getindex.(Ref(types), ret_symbols_t)) - fc = FunctionCall( - closure, (), arg_symbols_t, ret_symbols_t, ret_types, entry_device(machine) + fc_expr, + (), + Symbol[arg_symbols_t...], + ret_symbols_t, + ret_types, + entry_device(machine), ) setdiff!(undefined_argument_symbols, ret_symbols_set) @@ -201,7 +211,7 @@ end """ gen_tape( graph::DAG, - instance::AbstractProblemInstance, + instance::Any, machine::Machine, context_module::Module, scheduler::AbstractScheduler = GreedyScheduler() @@ -233,10 +243,8 @@ function gen_tape( # get outSymbol outSym = Symbol(to_var_name(get_exit_node(graph).id)) - INPUT_T = input_type(instance) - assign_inputs = gen_input_assignment_code( - input_syms, instance, machine, INPUT_T, context_module - ) + assign_inputs = gen_input_assignment_code(input_syms, instance, machine) + INPUT_T = input_type(instance) return Tape{INPUT_T}(assign_inputs, function_body, outSym, instance, machine) end diff --git a/src/code_gen/type.jl b/src/code_gen/type.jl index 4e2f004..45e653a 100644 --- a/src/code_gen/type.jl +++ b/src/code_gen/type.jl @@ -1,12 +1,76 @@ +""" + FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}} + +Representation of a function call. 
Contains the function to call (or an expression of a value to assign), +value arguments of type `VAL_T`, argument symbols, the return symbol(s) and type(s) and the device to execute on. + +To support vectorization, i.e., calling the same function on multiple inputs (SIMD), the value arguments, arguments, +and return symbols are each vectors of the actual inputs. In the non-vectorized case, these `Vector`s simply always +have length 1. For this common case, a special constructor exists which automatically wraps each of these arguments +in a `Vector`. + +## Type Arguments +- `VAL_T<:Tuple`: A tuple of all the value arguments that are passed to the function when it's called. +- `FUNC_T<:Union{Function, Expr}`: The type of the function. `Function` is the default, but in some cases, an `Expr` + of a value can be necessary to assign to the return symbol. In this case, no value arguments are allowed. + +## Fields +- `func::FUNC_T`: The function to be called, or an expression containing a value to assign to the return_symbol. +- `value_arguments::Vector{VAL_T}`: The value arguments for the function call. These are passed *first* to the + function, in the order given here. The `Vector` contains the tuple of value arguments for each vectorization + member. +- `arguments::Vector{Vector{Symbol}}`: The first vector represents the vectorization, the second layer represents the + symbols that will be passed as arguments to the function call. +- `return_symbols::Vector{Vector{Symbol}}`: As with the arguments, the first vector level represents the vectorization, + the second represents the symbols that the results of the function call are assigned to. For most function calls, + there is only one return symbol. When using closures when generating a function body for a [`Tape`](@ref), the + option to have multiple return symbols is necessary. +- `return_types::Vector{<:Type}`: The types of the function call with the arguments provided. 
This field only contains + one level of Vector, because it is required that a `FunctionCall` is type stable, and therefore, the types of the + return symbols have to be equal for all members of a vectorization. The return type is initially set to `Nothing` + and later inferred and assigned by [`infer_types!`](@ref). +- `device::AbstractDevice`: The device that this function call is scheduled on. +""" +mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}} + func::FUNC_T + value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments + arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call + return_symbols::Vector{Vector{Symbol}} # the return symbols + return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability + device::AbstractDevice +end +function FunctionCall( + func::Union{Function,Expr}, + value_arguments::VAL_T, + arguments::Vector{Symbol}, + return_symbol::Vector{Symbol}, + return_types::Vector{<:Type}, + device::AbstractDevice, +) where {VAL_T<:Tuple} + # convenience constructor for function calls that do not use vectorization, which is most of the use cases + @assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types '$(length(return_types))' does not match the number of return symbols '$(length(return_symbol))'" + @assert func isa Function || length(value_arguments) == 0 "no value arguments are allowed for a an Expr FunctionCall, but got '$value_arguments'" + return FunctionCall( + func, [value_arguments], [arguments], [return_symbol], return_types, device + ) +end """ Tape{INPUT} -TODO: update docs -- `INPUT` the input type of the problem instance +Lowered representation of a computation, generated from a [`DAG`](@ref) through [`gen_tape`](@ref). 
+ +- `INPUT` the input type of the problem instance, see also the interface function [`input_type`](@ref) -- `code::Vector{Expr}`: The julia expression containing the code for the whole graph. -- `output_symbol::Symbol`: The symbol of the final calculated value +## Fields +- `input_assign_code::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the input assignments, + mapping part of the input of the computation to each DAG entry node. These functions are generated using + the interface function [`input_expr`](@ref). +- `schedule::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the function body of the computation. + There is one function call for each node in the [`DAG`](@ref). +- `output_symbol::Symbol`: The symbol of the final calculated value, which is returned. +- `instance::Any`: The instance that this tape is generated for. +- `machine::Machine`: The [`Machine`](@ref) that this tape is generated for. """ struct Tape{INPUT} input_assign_code::Vector{FunctionCall} diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl index 19eb56c..8207406 100644 --- a/src/code_gen/utils.jl +++ b/src/code_gen/utils.jl @@ -1,23 +1,22 @@ -""" - infer_types!(schedule::Vector{FunctionCall}) +function Base.length(fc::FunctionCall) + @assert length(fc.value_arguments) == length(fc.arguments) == length(fc.return_symbols) "function call length is undefined, got '$(length(fc.value_arguments))' tuples of value arguments, '$(length(fc.arguments))' tuples of arguments, and '$(length(fc.return_symbols))' return symbols" + return length(fc.value_arguments) +end -Infer the result type of each function call in the given schedule. Returns a dictionary with the result type for each [`Node`](@ref). This assumes that each node has only one statically inferrable return type and will throw an exceptin otherwise. -This also assumes that the given `Vector` contains a topological ordering of its nodes, such as returned by a call to [`schedule_dag`](@ref). 
+""" + infer_types!(tape::Tape, context_module::Module) -Also returns the inferred types as a `Dict{Symbol, Type}`. +Infer the result type of each function call in the given tape. Returns a dictionary with the result type for each symbol and sets each function call's return_types. +This function assumes that each [`FunctionCall`](@ref) has only one statically inferrable return type and will throw an exception otherwise. """ -function infer_types!(tape::Tape) +function infer_types!(tape::Tape, context_module::Module) known_result_types = Dict{Symbol,Type}() # the only initially known type known_result_types[:input] = input_type(tape.instance) for fc in tape.input_assign_code - if typeof(fc.func) isa Expr - continue - end - # for input assign code, the return types are set on construction - res_types = fc.return_types + res_types = result_types(fc, known_result_types, context_module) for (s, t) in Iterators.zip( Iterators.flatten(fc.return_symbols), Iterators.cycle(res_types, length(fc.return_symbols)), @@ -27,7 +26,7 @@ function infer_types!(tape::Tape) end for fc in tape.schedule - res_types = result_types(fc, known_result_types) + res_types = result_types(fc, known_result_types, context_module) fc.return_types = res_types for (s, t) in Iterators.zip( Iterators.flatten(fc.return_symbols), diff --git a/src/models/interface.jl b/src/models/interface.jl index 1672e40..81bf46c 100644 --- a/src/models/interface.jl +++ b/src/models/interface.jl @@ -1,44 +1,26 @@ - -""" - AbstractModel - -Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed. - -See also: [`problem_instance`](@ref) """ -abstract type AbstractModel end + input_type(problem_instance) -""" - problem_instance(::AbstractModel, ::Vararg) - -Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters. 
-""" -function problem_instance end +Return the input type for a specific `problem_instance`. This can be a specific type or a supertype for which all child types are expected to be implemented. -""" - AbstractProblemInstance - -Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model. -""" -abstract type AbstractProblemInstance end - -""" - input_type(problem::AbstractProblemInstance) - -Return the input type for a specific [`AbstractProblemInstance`](@ref). This can be a specific type or a supertype for which all child types are expected to work. +For more details on the `problem_instance`, please refer to the documentation. """ function input_type end """ - graph(::AbstractProblemInstance) + graph(problem_instance) + +Generate the [`DAG`](@ref) for the given `problem_instance`. Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names. -Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names. +For more details on the `problem_instance`, please refer to the documentation. """ function graph end """ - input_expr(instance::AbstractProblemInstance, name::String, input_symbol::Symbol) + input_expr(problem_instance, name::String, input_symbol::Symbol) + +For the given `problem_instance`, the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol. 
-For the given [`AbstractProblemInstance`](@ref), the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol. +For more details on the `problem_instance`, please refer to the documentation. """ function input_expr end diff --git a/src/properties/type.jl b/src/properties/type.jl index 7f7e001..afa9f19 100644 --- a/src/properties/type.jl +++ b/src/properties/type.jl @@ -3,12 +3,12 @@ Representation of a [`DAG`](@ref)'s properties. -# Fields: -`.data`: The total data transfer.\\ -`.compute_effort`: The total compute effort.\\ -`.compute_intensity`: The compute intensity, will always equal `.compute_effort / .data`.\\ -`.number_of_nodes`: Number of [`Node`](@ref)s.\\ -`.number_of_edges`: Number of [`Edge`](@ref)s. +## Fields: +- `data::Float64`: The total data transfer. +- `compute_effort::Float64`: The total compute effort. +- `compute_intensity::Float64`: The compute intensity, will always equal `compute_effort / data`. +- `number_of_nodes::Int`: Number of [`Node`](@ref)s. +- `number_of_edges::Int`: Number of [`Edge`](@ref)s. """ const GraphProperties = NamedTuple{ (:data, :compute_effort, :compute_intensity, :number_of_nodes, :number_of_edges), diff --git a/src/scheduler/interface.jl b/src/scheduler/interface.jl index 1dcbfbf..b452a36 100644 --- a/src/scheduler/interface.jl +++ b/src/scheduler/interface.jl @@ -1,11 +1,3 @@ - -""" - AbstractScheduler - -Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks. -""" -abstract type AbstractScheduler end - """ schedule_dag(::Scheduler, ::DAG, ::Machine) diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl index 28452a8..7eb4062 100644 --- a/src/scheduler/type.jl +++ b/src/scheduler/type.jl @@ -1,35 +1,6 @@ """ - FunctionCall{VAL_TYPES} + AbstractScheduler -Type representing a function call. 
Contains the function to call, argument symbols, the return symbol and the device to execute on. - -TODO: extend docs +Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks. """ -mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}} - func::FUNC_T - value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments - arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call - return_symbols::Vector{Vector{Symbol}} # the return symbols - return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability - device::AbstractDevice -end - -function FunctionCall( - func::Union{Function,Expr}, - value_arguments::VAL_T, - arguments::Vector{Symbol}, - return_symbol::Vector{Symbol}, - return_types::Vector{<:Type}, - device::AbstractDevice, -) where {VAL_T<:Tuple} - # convenience constructor for function calls that do not use vectorization, which is most of the use cases - @assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types $(length(return_types)) does not match the number of return symbols $(length(return_symbol))" - return FunctionCall( - func, [value_arguments], [arguments], [return_symbol], return_types, device - ) -end - -function Base.length(fc::FunctionCall) - @assert length(fc.value_arguments) == length(fc.arguments) == length(fc.return_symbols) "function call length is undefined, got $(length(fc.value_arguments)) tuples of value arguments, $(length(fc.arguments)) tuples of arguments, and $(length(return_symbols)) return symbols" - return length(fc.value_arguments) -end +abstract type AbstractScheduler end diff --git a/src/task/compute.jl b/src/task/compute.jl index c2e9fdd..b76e238 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -62,14 +62,8 @@ 
function _argument_types(known_res_types::Dict{Symbol,Type}, fc::FunctionCall) return getindex.(Ref(known_res_types), fc.arguments[1]) end -function result_types( - fc::FunctionCall{VAL_T,F_T}, known_res_types::Dict{Symbol,Type} -) where {VAL_T,F_T<:Function} - arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) - types = Base.return_types(fc.func, arg_types) - +function _validate_result_types(fc::FunctionCall, types, arg_types) N_RET = length(fc.return_types) - if length(types) > 1 throw( "failure during type inference: function call $(fc.func) with argument types $(arg_types) is type unstable, possible return types: $types", @@ -85,21 +79,60 @@ function result_types( end if (N_RET == 1) - return [types[1]] + return nothing end - if !(types[1] isa Tuple) || length(types[1].parameters) != N_RET + if !(types[1] <: Tuple) || length(types[1].parameters) != N_RET throw( - "failure durng type inference: function call $(fc.func) was expected to return a Tuple with $N_RET elements, but returns $(types[1])", + "failure during type inference: function call $(fc.func) was expected to return a Tuple with $N_RET elements, but returns $(types[1])", ) end + return nothing +end + +function result_types( + fc::FunctionCall{VAL_T,F_T}, known_res_types::Dict{Symbol,Type}, context_module::Module +) where {VAL_T,F_T<:Function} + arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) + types = Base.return_types(fc.func, arg_types) + + _validate_result_types(fc, types, arg_types) + + N_RET = length(fc.return_types) + if (N_RET == 1) + return [types[1]] + end + return [types[1].parameters...] 
end function result_types( - fc::FunctionCall{VAL_T,Expr}, known_res_types::Dict{Symbol,Type} + fc::FunctionCall{VAL_T,Expr}, known_res_types::Dict{Symbol,Type}, context_module::Module ) where {VAL_T} - # assume that the return type is already set - @assert length(fc.return_types) == 1 - return [fc.return_types[1]] + arg_types = _argument_types(known_res_types, fc) + ret_expr = Expr( + :call, + Base.return_types, # return types call + Expr( # function argument to return_types + :->, # anonymous function + Expr( + :tuple, # anonymous function arguments + fc.arguments[1]..., + ), + fc.func, # anonymous function code block + ), + Expr(:tuple, arg_types...), # types arguments to return_types + ) + types = context_module.eval(ret_expr) + + #@info "evaluation of expression\n$ret_expr\ngives\n$types" + + _validate_result_types(fc, types, arg_types) + + N_RET = length(fc.return_types) + if (N_RET == 1) + return [types[1]] + end + + return [types[1].parameters...] end