diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index f1391b1..94b9f54 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -1,3 +1,4 @@ +#= # TODO: do this with macros function call_fc( fc::FunctionCall{VectorT,0}, cache::Dict{Symbol,Any} @@ -49,8 +50,9 @@ function call_fc(fc::FunctionCall{VectorT,M}, cache::Dict{Symbol,Any}) where {Ve ) return nothing end +=# -function expr_from_fc(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_RET} +function expr_from_fc(fc::FunctionCall{VAL_T}) where {VAL_T} if length(fc) == 1 func_call = Expr( :call, @@ -92,9 +94,9 @@ function gen_input_assignment_code( fc = FunctionCall( context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))), (), - (:input,), - (symbol,), - (Nothing,), + [:input], + [symbol], + [Nothing], device, ) @@ -203,8 +205,8 @@ function _closure_fc( setdiff!(arg_symbols_set, ret_symbols_set) intersect!(ret_symbols_set, undefined_argument_symbols) - arg_symbols_t = (arg_symbols_set...,) - ret_symbols_t = (ret_symbols_set...,) + arg_symbols_t = [arg_symbols_set...] + ret_symbols_t = [ret_symbols_set...] closure = context_module.eval( Expr( # create the closure: () -> code block; return (locals) diff --git a/src/devices/impl.jl b/src/devices/impl.jl index 1c31f46..35236c2 100644 --- a/src/devices/impl.jl +++ b/src/devices/impl.jl @@ -33,19 +33,21 @@ end Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_access_expr`](@ref). """ -function gen_access_expr(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_RET} - vec = Expr[] - for ret_symbols in fc.return_symbols - push!(vec, unroll_symbol_vector(_gen_access_expr.(Ref(fc.device), ret_symbols))) +function gen_access_expr(fc::FunctionCall{VAL_T}) where {VAL_T} + if length(fc.return_types) != 1 + # general case + vec = Expr[] + for ret_symbols in fc.return_symbols + push!(vec, unroll_symbol_vector(_gen_access_expr.(Ref(fc.device), ret_symbols))) + end + if length(vec) > 1 + return unroll_symbol_vector(vec) + else + return vec[1] + end end - if length(vec) > 1 - return unroll_symbol_vector(vec) - else - return vec[1] - end -end -function gen_access_expr(fc::FunctionCall{VAL_T,N_ARG,1}) where {VAL_T,N_ARG} + # no vectorization case vec = Symbol[] for ret_symbols in fc.return_symbols push!(vec, _gen_access_expr.(Ref(fc.device), ret_symbols[1])) diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl index 4a441c6..03fc918 100644 --- a/src/scheduler/type.jl +++ b/src/scheduler/type.jl @@ -5,24 +5,25 @@ Type representing a function call. Contains the function to call, argument symbo TODO: extend docs """ -mutable struct FunctionCall{VAL_T<:Tuple,N_ARG,N_RET} +mutable struct FunctionCall{VAL_T<:Tuple} func::Function - value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments - arguments::Vector{NTuple{N_ARG,Symbol}} # symbols of the inputs to the function call - return_symbols::Vector{NTuple{N_RET,Symbol}} # the return symbols - return_types::NTuple{N_RET,Type} # the return type of the function call(s); there can only be one return type since we require type stability + value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments + arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call + return_symbols::Vector{Vector{Symbol}} # the return symbols + return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability device::AbstractDevice end function FunctionCall( func::Function, value_arguments::VAL_T, - arguments::NTuple{N_ARG,Symbol}, - return_symbol::NTuple{N_RET,Symbol}, - return_types::NTuple{N_RET,Type}, + arguments::Vector{Symbol}, + return_symbol::Vector{Symbol}, + return_types::Vector{<:Type}, device::AbstractDevice, -) where {VAL_T<:Tuple,N_ARG,N_RET} +) where {VAL_T<:Tuple} # convenience constructor for function calls that do not use vectorization, which is most of the use cases + @assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types $(length(return_types)) does not match the number of return symbols $(length(return_symbol))" return FunctionCall( func, [value_arguments], [arguments], [return_symbol], return_types, device ) diff --git a/src/task/compute.jl b/src/task/compute.jl index 39c2092..07559d9 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -12,7 +12,7 @@ function get_function_call( in_symbols::NTuple{N,Symbol}, out_symbol::Symbol, ) where {N} - return FunctionCall(compute, (t,), in_symbols, (out_symbol,), (Any,), device) + return FunctionCall(compute, (t,), [in_symbols...], [out_symbol], [Any], device) end function get_function_call(node::ComputeTaskNode) @@ -37,9 +37,9 @@ function get_function_call(node::DataTaskNode) return FunctionCall( identity, (), - (Symbol(to_var_name(first(children(node))[1].id)),), - (Symbol(to_var_name(node.id)),), - (Any,), + [Symbol(to_var_name(first(children(node))[1].id))], + [Symbol(to_var_name(node.id))], + [Any], first(children(node))[1].device, ) end @@ -50,9 +50,9 @@ function get_init_function_call(node::DataTaskNode, device::AbstractDevice) return FunctionCall( identity, (), - (Symbol("$(to_var_name(node.id))_in"),), - (Symbol(to_var_name(node.id)),), - (Any,), + [Symbol("$(to_var_name(node.id))_in")], + [Symbol(to_var_name(node.id))], + [Any], device, ) end @@ -63,33 +63,12 @@ function _argument_types(known_res_types::Dict{Symbol,Type}, fc::FunctionCall) end function result_types( - fc::FunctionCall{VAL_T,N_ARG,1}, known_res_types::Dict{Symbol,Type} -) where {VAL_T,N_ARG} + fc::FunctionCall{VAL_T}, known_res_types::Dict{Symbol,Type} +) where {VAL_T} arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) types = Base.return_types(fc.func, arg_types) - if length(types) > 1 - throw( - "failure during type inference: function call $fc with argument types $(arg_types) is type unstable, possible return types: $types", - ) - end - if isempty(types) - throw( - "failure during type inference: function call $fc with argument types $(arg_types) has no return types, this is likely because no method matches the arguments", - ) - end - if types[1] == Any - @warn "inferred return type 'Any' in task $fc with argument types $(arg_types)" - end - - return (types[1],) -end - -function result_types( - fc::FunctionCall{VAL_T,N_ARG,N_RET}, known_res_types::Dict{Symbol,Type} -) where {VAL_T,N_ARG,N_RET} - arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...) - types = Base.return_types(fc.func, arg_types) + N_RET = length(fc.return_types) if length(types) > 1 throw( @@ -104,11 +83,15 @@ function result_types( if types[1] == Any @warn "inferred return type 'Any' in task $fc with argument types $(arg_types)" end + + if (N_RET == 1) + return [types[1]] + end + if !(types[1] isa Tuple) || length(types[1].parameters) != N_RET throw( "failure durng type inference: function call $(fc.func) was expected to return a Tuple with $N_RET elements, but returns $(types[1])", ) end - - return (types[1].parameters...,) + return [types[1].parameters...] end