Skip to content

Commit

Permalink
Change Function Calls to mostly use Vectors instead of Tuples
Browse files Browse the repository at this point in the history
This makes small function calls slower, but large function calls much much faster to compile
  • Loading branch information
AntonReinhard committed Nov 18, 2024
1 parent 03cdf58 commit 0171a87
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 59 deletions.
14 changes: 8 additions & 6 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#=
# TODO: do this with macros
function call_fc(
fc::FunctionCall{VectorT,0}, cache::Dict{Symbol,Any}
Expand Down Expand Up @@ -49,8 +50,9 @@ function call_fc(fc::FunctionCall{VectorT,M}, cache::Dict{Symbol,Any}) where {Ve
)
return nothing
end
=#

function expr_from_fc(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_RET}
function expr_from_fc(fc::FunctionCall{VAL_T}) where {VAL_T}
if length(fc) == 1
func_call = Expr(
:call,
Expand Down Expand Up @@ -92,9 +94,9 @@ function gen_input_assignment_code(
fc = FunctionCall(
context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))),
(),
(:input,),
(symbol,),
(Nothing,),
[:input],
[symbol],
[Nothing],
device,
)

Expand Down Expand Up @@ -203,8 +205,8 @@ function _closure_fc(
setdiff!(arg_symbols_set, ret_symbols_set)
intersect!(ret_symbols_set, undefined_argument_symbols)

arg_symbols_t = (arg_symbols_set...,)
ret_symbols_t = (ret_symbols_set...,)
arg_symbols_t = [arg_symbols_set...]
ret_symbols_t = [ret_symbols_set...]

closure = context_module.eval(
Expr( # create the closure: () -> code block; return (locals)
Expand Down
24 changes: 13 additions & 11 deletions src/devices/impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,21 @@ end
Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_access_expr`](@ref).
"""
function gen_access_expr(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_RET}
vec = Expr[]
for ret_symbols in fc.return_symbols
push!(vec, unroll_symbol_vector(_gen_access_expr.(Ref(fc.device), ret_symbols)))
function gen_access_expr(fc::FunctionCall{VAL_T}) where {VAL_T}
if length(fc.return_types) != 1
# general case
vec = Expr[]
for ret_symbols in fc.return_symbols
push!(vec, unroll_symbol_vector(_gen_access_expr.(Ref(fc.device), ret_symbols)))
end
if length(vec) > 1
return unroll_symbol_vector(vec)
else
return vec[1]
end
end
if length(vec) > 1
return unroll_symbol_vector(vec)
else
return vec[1]
end
end

function gen_access_expr(fc::FunctionCall{VAL_T,N_ARG,1}) where {VAL_T,N_ARG}
# no vectorization case
vec = Symbol[]
for ret_symbols in fc.return_symbols
push!(vec, _gen_access_expr.(Ref(fc.device), ret_symbols[1]))
Expand Down
19 changes: 10 additions & 9 deletions src/scheduler/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,25 @@ Type representing a function call. Contains the function to call, argument symbo
TODO: extend docs
"""
mutable struct FunctionCall{VAL_T<:Tuple,N_ARG,N_RET}
mutable struct FunctionCall{VAL_T<:Tuple}
func::Function
value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments
arguments::Vector{NTuple{N_ARG,Symbol}} # symbols of the inputs to the function call
return_symbols::Vector{NTuple{N_RET,Symbol}} # the return symbols
return_types::NTuple{N_RET,Type} # the return type of the function call(s); there can only be one return type since we require type stability
value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments
arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call
return_symbols::Vector{Vector{Symbol}} # the return symbols
return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability
device::AbstractDevice
end

function FunctionCall(
func::Function,
value_arguments::VAL_T,
arguments::NTuple{N_ARG,Symbol},
return_symbol::NTuple{N_RET,Symbol},
return_types::NTuple{N_RET,Type},
arguments::Vector{Symbol},
return_symbol::Vector{Symbol},
return_types::Vector{<:Type},
device::AbstractDevice,
) where {VAL_T<:Tuple,N_ARG,N_RET}
) where {VAL_T<:Tuple}
# convenience constructor for function calls that do not use vectorization, which is most of the use cases
@assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types $(length(return_types)) does not match the number of return symbols $(length(return_symbol))"
return FunctionCall(
func, [value_arguments], [arguments], [return_symbol], return_types, device
)
Expand Down
49 changes: 16 additions & 33 deletions src/task/compute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function get_function_call(
in_symbols::NTuple{N,Symbol},
out_symbol::Symbol,
) where {N}
return FunctionCall(compute, (t,), in_symbols, (out_symbol,), (Any,), device)
return FunctionCall(compute, (t,), [in_symbols...], [out_symbol], [Any], device)
end

function get_function_call(node::ComputeTaskNode)
Expand All @@ -37,9 +37,9 @@ function get_function_call(node::DataTaskNode)
return FunctionCall(
identity,
(),
(Symbol(to_var_name(first(children(node))[1].id)),),
(Symbol(to_var_name(node.id)),),
(Any,),
[Symbol(to_var_name(first(children(node))[1].id))],
[Symbol(to_var_name(node.id))],
[Any],
first(children(node))[1].device,
)
end
Expand All @@ -50,9 +50,9 @@ function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
return FunctionCall(
identity,
(),
(Symbol("$(to_var_name(node.id))_in"),),
(Symbol(to_var_name(node.id)),),
(Any,),
[Symbol("$(to_var_name(node.id))_in")],
[Symbol(to_var_name(node.id))],
[Any],
device,
)
end
Expand All @@ -63,33 +63,12 @@ function _argument_types(known_res_types::Dict{Symbol,Type}, fc::FunctionCall)
end

function result_types(
fc::FunctionCall{VAL_T,N_ARG,1}, known_res_types::Dict{Symbol,Type}
) where {VAL_T,N_ARG}
fc::FunctionCall{VAL_T}, known_res_types::Dict{Symbol,Type}
) where {VAL_T}
arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...)
types = Base.return_types(fc.func, arg_types)

if length(types) > 1
throw(
"failure during type inference: function call $fc with argument types $(arg_types) is type unstable, possible return types: $types",
)
end
if isempty(types)
throw(
"failure during type inference: function call $fc with argument types $(arg_types) has no return types, this is likely because no method matches the arguments",
)
end
if types[1] == Any
@warn "inferred return type 'Any' in task $fc with argument types $(arg_types)"
end

return (types[1],)
end

function result_types(
fc::FunctionCall{VAL_T,N_ARG,N_RET}, known_res_types::Dict{Symbol,Type}
) where {VAL_T,N_ARG,N_RET}
arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...)
types = Base.return_types(fc.func, arg_types)
N_RET = length(fc.return_types)

if length(types) > 1
throw(
Expand All @@ -104,11 +83,15 @@ function result_types(
if types[1] == Any
@warn "inferred return type 'Any' in task $fc with argument types $(arg_types)"
end

if (N_RET == 1)
return [types[1]]
end

if !(types[1] isa Tuple) || length(types[1].parameters) != N_RET
throw(
"failure durng type inference: function call $(fc.func) was expected to return a Tuple with $N_RET elements, but returns $(types[1])",
)
end

return (types[1].parameters...,)
return [types[1].parameters...]
end

0 comments on commit 0171a87

Please sign in to comment.