Skip to content

Commit

Permalink
WIP fix input assignment code world age problem
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonReinhard committed Nov 27, 2024
1 parent 0171a87 commit 4003edc
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 22 deletions.
48 changes: 34 additions & 14 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function call_fc(fc::FunctionCall{VectorT,M}, cache::Dict{Symbol,Any}) where {Ve
end
=#

function expr_from_fc(fc::FunctionCall{VAL_T}) where {VAL_T}
function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function}
if length(fc) == 1
func_call = Expr(
:call,
Expand All @@ -68,11 +68,17 @@ function expr_from_fc(fc::FunctionCall{VAL_T}) where {VAL_T}
return Expr(:(=), access_expr, func_call)
end

function expr_from_fc(fc::FunctionCall{VAL_T,Expr}) where {VAL_T}
@assert length(fc) == 1 && isempty(fc.arguments[1]) && isempty(fc.value_arguments[1]) "function call assigning an expression has an unallowed combination of arguments, which is not allowed\n$fc"
return Expr(:(=), gen_access_expr(fc), fc.func)
end

"""
gen_input_assignment_code(
input_symbols::Dict{String, Vector{Symbol}},
instance::AbstractProblemInstance,
machine::Machine,
input_type::Type,
context_module::Module
)
Expand All @@ -82,24 +88,37 @@ function gen_input_assignment_code(
input_symbols::Dict{String,Vector{Symbol}},
instance,
machine::Machine,
input_type::Type,
context_module::Module,
)
assign_inputs = Vector{FunctionCall}()
for (name, symbols) in input_symbols
# make a function for this, since we can't use anonymous functions in the FunctionCall

for symbol in symbols
device = entry_device(machine)

fc = FunctionCall(
context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))),
f_id = Symbol(to_var_name(UUIDs.uuid1(rng[threadid()])))

fc_setup = FunctionCall(
Expr(:->, :x, input_expr(instance, name, :x)),
(),
[:input],
[symbol],
[Nothing],
Symbol[],
Symbol[f_id],
Type[Nothing],
device,
)

fc = FunctionCall(
_call, (), Symbol[f_id, :input], Symbol[symbol], Type[Nothing], device
)

ret_expr = Expr(
:call, Base.return_types, fc_setup.func, Expr(:tuple, input_type)
)
ret_type = context_module.eval(ret_expr)
@assert length(ret_type) == 1
fc.return_types = [ret_type[1]]

push!(assign_inputs, fc_setup)
push!(assign_inputs, fc)
end
end
Expand All @@ -125,7 +144,7 @@ function gen_function_body(tape::Tape, context_module::Module; closures_size::In
closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth))
end

@info "generating function body with closure size $closures_size"
@debug "generating function body with closure size $closures_size"

return _gen_function_body(
tape.schedule, types, tape.machine, context_module; closures_size=closures_size
Expand All @@ -139,7 +158,7 @@ function _gen_function_body(
context_module::Module;
closures_size=0,
)
@info "generating function body from $(length(fc_vec)) function calls with closure size $closures_size"
@debug "generating function body from $(length(fc_vec)) function calls with closure size $closures_size"
if closures_size <= 1 || closures_size >= length(fc_vec)
return Expr(:block, expr_from_fc.(fc_vec)...)
end
Expand Down Expand Up @@ -275,11 +294,12 @@ function gen_tape(
# get outSymbol
outSym = Symbol(to_var_name(get_exit_node(graph).id))

assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module)

return Tape{input_type(instance)}(
assign_inputs, function_body, outSym, instance, machine
INPUT_T = input_type(instance)
assign_inputs = gen_input_assignment_code(
input_syms, instance, machine, INPUT_T, context_module
)

return Tape{INPUT_T}(assign_inputs, function_body, outSym, instance, machine)
end

"""
Expand Down
7 changes: 5 additions & 2 deletions src/code_gen/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ function infer_types!(tape::Tape)
known_result_types[:input] = input_type(tape.instance)

for fc in tape.input_assign_code
res_types = result_types(fc, known_result_types)
fc.return_types = res_types
if typeof(fc.func) isa Expr
continue
end
# for input assign code, the return types are set on construction
res_types = fc.return_types
for (s, t) in Iterators.zip(
Iterators.flatten(fc.return_symbols),
Iterators.cycle(res_types, length(fc.return_symbols)),
Expand Down
8 changes: 4 additions & 4 deletions src/scheduler/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ Type representing a function call. Contains the function to call, argument symbo
TODO: extend docs
"""
mutable struct FunctionCall{VAL_T<:Tuple}
func::Function
mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}}
func::FUNC_T
value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments
arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call
return_symbols::Vector{Vector{Symbol}} # the return symbols
return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability
return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability
device::AbstractDevice
end

function FunctionCall(
func::Function,
func::Union{Function,Expr},
value_arguments::VAL_T,
arguments::Vector{Symbol},
return_symbol::Vector{Symbol},
Expand Down
12 changes: 10 additions & 2 deletions src/task/compute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ function _argument_types(known_res_types::Dict{Symbol,Type}, fc::FunctionCall)
end

function result_types(
fc::FunctionCall{VAL_T}, known_res_types::Dict{Symbol,Type}
) where {VAL_T}
fc::FunctionCall{VAL_T,F_T}, known_res_types::Dict{Symbol,Type}
) where {VAL_T,F_T<:Function}
arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...)
types = Base.return_types(fc.func, arg_types)

Expand Down Expand Up @@ -95,3 +95,11 @@ function result_types(
end
return [types[1].parameters...]
end

function result_types(
fc::FunctionCall{VAL_T,Expr}, known_res_types::Dict{Symbol,Type}
) where {VAL_T}
# assume that the return type is already set
@assert length(fc.return_types) == 1
return [fc.return_types[1]]
end
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,7 @@ Return the given vector as single String without quotation marks or brackets.
function unroll_symbol_vector(vec::VEC) where {VEC<:Union{AbstractVector,Tuple}}
return Expr(:tuple, vec...)
end

@inline function _call(f, args::Vararg)
return f(args...)
end

0 comments on commit 4003edc

Please sign in to comment.