Skip to content

Commit

Permalink
renaming/docs/removing unused interfaces
Browse files Browse the repository at this point in the history
fix type stability in closures
  • Loading branch information
AntonReinhard committed Nov 28, 2024
1 parent ee213a3 commit e6c8fc0
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 170 deletions.
6 changes: 2 additions & 4 deletions src/ComputableDAGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ export get_operations

# code generation related
export get_compute_function
export gen_tape

# estimator
export cost_type, graph_cost, operation_effect
Expand All @@ -48,8 +47,7 @@ export optimize_step!, optimize!
export fixpoint_reached, optimize_to_fixpoint!

# models
export AbstractModel, AbstractProblemInstance
export problem_instance, input_type, graph, input_expr
export graph, input_type, input_expr

# machine info
export Machine
Expand All @@ -67,6 +65,7 @@ include("properties/type.jl")
include("operation/type.jl")
include("graph/type.jl")
include("scheduler/type.jl")
include("code_gen/type.jl")

include("trie.jl")
include("utils.jl")
Expand Down Expand Up @@ -125,7 +124,6 @@ include("devices/ext.jl")
include("scheduler/interface.jl")
include("scheduler/greedy.jl")

include("code_gen/type.jl")
include("code_gen/utils.jl")
include("code_gen/tape_machine.jl")
include("code_gen/function.jl")
Expand Down
132 changes: 70 additions & 62 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function}
function expr_from_fc(fc::FunctionCall{VAL_T,<:Function}) where {VAL_T}
if length(fc) == 1
func_call = Expr(:call, fc.func, fc.value_arguments[1]..., fc.arguments[1]...)
else
Expand All @@ -10,14 +10,35 @@ function expr_from_fc(fc::FunctionCall{VAL_T,F_T}) where {VAL_T,F_T<:Function}
end

function expr_from_fc(fc::FunctionCall{VAL_T,Expr}) where {VAL_T}
@assert length(fc) == 1 && isempty(fc.arguments[1]) && isempty(fc.value_arguments[1]) "function call assigning an expression has an unallowed combination of arguments, which is not allowed\n$fc"
return Expr(:(=), gen_access_expr(fc), fc.func)
@assert length(fc) == 1 && isempty(fc.value_arguments[1]) "function call assigning an expression cannot be vectorized and cannot contain value arguments\n$fc"

fc_expr_in_let = Expr(
:let,
Expr(:block, fc.return_symbols[1]...),
fc.func, # anonymous function code block
)

func_call = Expr(
:call, # call
Expr(
:->, # anonymous function
Expr(
:tuple, # anonymous function arguments
fc.arguments[1]...,
),
fc_expr_in_let,
),
fc.arguments[1]..., # runtime arguments passed to the anonymous function
)

access_expr = gen_access_expr(fc)
return Expr(:(=), access_expr, func_call)
end

"""
gen_input_assignment_code(
input_symbols::Dict{String, Vector{Symbol}},
instance::AbstractProblemInstance,
instance::Any,
machine::Machine,
input_type::Type,
context_module::Module
Expand All @@ -26,40 +47,22 @@ end
Return a `Vector{Expr}` doing the input assignments from the given `problem_input` onto the `input_symbols`.
"""
function gen_input_assignment_code(
input_symbols::Dict{String,Vector{Symbol}},
instance,
machine::Machine,
input_type::Type,
context_module::Module,
input_symbols::Dict{String,Vector{Symbol}}, instance, machine::Machine
)
assign_inputs = Vector{FunctionCall}()
for (name, symbols) in input_symbols
for symbol in symbols
device = entry_device(machine)

f_id = Symbol(to_var_name(UUIDs.uuid1(rng[threadid()])))

fc_setup = FunctionCall(
Expr(:->, :x, input_expr(instance, name, :x)),
fc = FunctionCall(
input_expr(instance, name, :input),
(),
Symbol[],
Symbol[f_id],
Symbol[:input],
Symbol[symbol],
Type[Nothing],
device,
)

fc = FunctionCall(
_call, (), Symbol[f_id, :input], Symbol[symbol], Type[Nothing], device
)

ret_expr = Expr(
:call, Base.return_types, fc_setup.func, Expr(:tuple, input_type)
)
ret_type = context_module.eval(ret_expr)
@assert length(ret_type) == 1
fc.return_types = [ret_type[1]]

push!(assign_inputs, fc_setup)
push!(assign_inputs, fc)
end
end
Expand All @@ -73,13 +76,15 @@ end
Generate the function body from the given [`Tape`](@ref).
## Keyword Arguments
`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely.
`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers
in the function body, preventing some optimizations by the compiler and therefore greatly reducing
compile time. A value of 0 will disable the use of closures entirely.
"""
function gen_function_body(tape::Tape, context_module::Module; closures_size::Int)
# only need to annotate types later when using closures
types = infer_types!(tape)
types = infer_types!(tape, context_module)

if closures_size >= 1
if closures_size > 1
s = log(closures_size, length(tape.schedule))
closures_depth = ceil(Int, s) # tend towards more levels/smaller closures
closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth))
Expand Down Expand Up @@ -113,15 +118,14 @@ function _gen_function_body(
closured_fc_vec = FunctionCall[]
for i in length(fc_vec):(-closures_size):1
e = i
b = max(i - closures_size, 1)
b = max(i - closures_size + 1, 1)
code_block = fc_vec[b:e]

pushfirst!(
closured_fc_vec,
_closure_fc(
code_block, type_dict, machine, undefined_argument_symbols, context_module
),
closure_fc = _closure_fc(
code_block, type_dict, machine, undefined_argument_symbols, context_module
)

pushfirst!(closured_fc_vec, closure_fc)
end

return _gen_function_body(
Expand All @@ -130,9 +134,16 @@ function _gen_function_body(
end

"""
_closure_fc()
_closure_fc(
code_block::AbstractVector{FunctionCall},
types::Dict{Symbol,Type},
machine::Machine,
undefined_argument_symbols::Set{Symbol},
context_module::Module,
)
From the given function calls, make and return a new function call representing all of them together.
From the given function calls, make and return 2 function calls representing all of them together. 2 function calls are necessary, one for setting up the anonymous
function and the second for calling it.
The undefined_argument_symbols is the set of all Symbols that need to be returned if available inside the code_block. They get updated inside this function.
"""
function _closure_fc(
Expand Down Expand Up @@ -168,29 +179,28 @@ function _closure_fc(
arg_symbols_t = [arg_symbols_set...]
ret_symbols_t = [ret_symbols_set...]

closure = context_module.eval(
Expr( # create the closure: () -> code block; return (locals)
:->,
Expr(:tuple, arg_symbols_t...), # closure arguments
Expr( # actual function body of the closure
:block,
expr_from_fc.(code_block)...,
Expr(
:return, # have to make sure to not return a tuple of length 1
if length(ret_symbols_t) == 1
ret_symbols_t[1]
else
Expr(:tuple, ret_symbols_t...)
end,
),
),
ret_types = (getindex.(Ref(types), ret_symbols_t))

fc_expr = Expr( # actual function body of the closure
:block,
expr_from_fc.(code_block)...,
Expr(
:return, # have to make sure to not return a tuple of length 1
if length(ret_symbols_t) == 1
ret_symbols_t[1]
else
Expr(:tuple, ret_symbols_t...)
end,
),
)

ret_types = (getindex.(Ref(types), ret_symbols_t))

fc = FunctionCall(
closure, (), arg_symbols_t, ret_symbols_t, ret_types, entry_device(machine)
fc_expr,
(),
Symbol[arg_symbols_t...],
ret_symbols_t,
ret_types,
entry_device(machine),
)

setdiff!(undefined_argument_symbols, ret_symbols_set)
Expand All @@ -201,7 +211,7 @@ end
"""
gen_tape(
graph::DAG,
instance::AbstractProblemInstance,
instance::Any,
machine::Machine,
context_module::Module,
scheduler::AbstractScheduler = GreedyScheduler()
Expand Down Expand Up @@ -233,10 +243,8 @@ function gen_tape(
# get outSymbol
outSym = Symbol(to_var_name(get_exit_node(graph).id))

INPUT_T = input_type(instance)
assign_inputs = gen_input_assignment_code(
input_syms, instance, machine, INPUT_T, context_module
)
assign_inputs = gen_input_assignment_code(input_syms, instance, machine)

INPUT_T = input_type(instance)
return Tape{INPUT_T}(assign_inputs, function_body, outSym, instance, machine)
end
72 changes: 68 additions & 4 deletions src/code_gen/type.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,76 @@
"""
FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}}
Representation of a function call. Contains the function to call (or an expression of a value to assign),
value arguments of type `VAL_T`, argument symbols, the return symbol(s) and type(s) and the device to execute on.
To support vectorization, i.e., calling the same function on multiple inputs (SIMD), the value arguments, arguments,
and return symbols are each vectors of the actual inputs. In the non-vectorized case, these `Vector`s simply always
have length 1. For this common case, a special constructor exists which automatically wraps each of these arguments
in a `Vector`.
## Type Arguments
- `VAL_T<:Tuple`: A tuple of all the value arguments that are passed to the function when it's called.
- `FUNC_T<:Union{Function, Expr}`: The type of the function. `Function` is the default, but in some cases, an `Expr`
of a value can be necessary to assign to the return symbol. In this case, no arguments are allowed.
## Fields
- `func::FUNC_T`: The function to be called, or an expression containing a value to assign to the return_symbol.
- `value_arguments::Vector{VAL_T}`: The value arguments for the function call. These are passed *first* to the
function, in the order given here. The `Vector` contains the tuple of value arguments for each vectorization
member.
- `arguments::Vector{Vector{Symbol}}`: The first vector represents the vectorization, the second layer represents the
symbols that will be passed as arguments to the function call.
- `return_symbols::Vector{Vector{Symbol}}`: As with the arguments, the first vector level represents the vectorization,
the second represents the symbols that the results of the function call are assigned to. For most function calls,
there is only one return symbol. When using closures when generating a function body for a [`Tape`](@ref), the
option to have multiple return symbols is necessary.
- `return_types::Vector{<:Type}`: The types of the function call with the arguments provided. This field only contains
one level of Vector, because it is required that a `FunctionCall` is type stable, and therefore, the types of the
return symbols have to be equal for all members of a vectorization. The return type is initially set to `Nothing`
and later inferred and assigned by [`infer_types!`](@ref).
- `device::AbstractDevice`: The device that this function call is scheduled on.
"""
mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}}
func::FUNC_T
value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments
arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call
return_symbols::Vector{Vector{Symbol}} # the return symbols
return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability
device::AbstractDevice
end
function FunctionCall(
func::Union{Function,Expr},
value_arguments::VAL_T,
arguments::Vector{Symbol},
return_symbol::Vector{Symbol},
return_types::Vector{<:Type},
device::AbstractDevice,
) where {VAL_T<:Tuple}
# convenience constructor for function calls that do not use vectorization, which is most of the use cases
@assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types '$(length(return_types))' does not match the number of return symbols '$(length(return_symbol))'"
@assert func isa Function || length(value_arguments) == 0 "no value arguments are allowed for a an Expr FunctionCall, but got '$value_arguments'"
return FunctionCall(
func, [value_arguments], [arguments], [return_symbol], return_types, device
)
end

"""
Tape{INPUT}
TODO: update docs
- `INPUT` the input type of the problem instance
Lowered representation of a computation, generated from a [`DAG`](@ref) through [`gen_tape`](@ref).
- `INPUT` the input type of the problem instance, see also the interface function [`input_type`](@ref)
- `code::Vector{Expr}`: The julia expression containing the code for the whole graph.
- `output_symbol::Symbol`: The symbol of the final calculated value
## Fields
- `input_assign_code::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the input assignments,
mapping part of the input of the computation to each DAG entry node. These functions are generated using
the interface function [`input_expr`](@ref).
- `schedule::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the function body of the computation.
There is one function call for each node in the [`DAG`](@ref).
- `output_symbol::Symbol`: The symbol of the final calculated value, which is returned.
- `instance::Any`: The instance that this tape is generated for.
- `machine::Machine`: The [`Machine`](@ref) that this tape is generated for.
"""
struct Tape{INPUT}
input_assign_code::Vector{FunctionCall}
Expand Down
23 changes: 11 additions & 12 deletions src/code_gen/utils.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
"""
infer_types!(schedule::Vector{FunctionCall})
function Base.length(fc::FunctionCall)
@assert length(fc.value_arguments) == length(fc.arguments) == length(fc.return_symbols) "function call length is undefined, got '$(length(fc.value_arguments))' tuples of value arguments, '$(length(fc.arguments))' tuples of arguments, and '$(length(return_symbols))' return symbols"
return length(fc.value_arguments)
end

Infer the result type of each function call in the given schedule. Returns a dictionary with the result type for each [`Node`](@ref). This assumes that each node has only one statically inferrable return type and will throw an exceptin otherwise.
This also assumes that the given `Vector` contains a topological ordering of its nodes, such as returned by a call to [`schedule_dag`](@ref).
"""
infer_types!(tape::Tape, context_module::Module)
Also returns the inferred types as a `Dict{Symbol, Type}`.
Infer the result type of each function call in the given tape. Returns a dictionary with the result type for each symbol and sets each function call's return_types.
This function assumes that each [`FunctionCall`](@ref) has only one statically inferrable return type and will throw an exception otherwise.
"""
function infer_types!(tape::Tape)
function infer_types!(tape::Tape, context_module::Module)
known_result_types = Dict{Symbol,Type}()

# the only initially known type
known_result_types[:input] = input_type(tape.instance)

for fc in tape.input_assign_code
if typeof(fc.func) isa Expr
continue
end
# for input assign code, the return types are set on construction
res_types = fc.return_types
res_types = result_types(fc, known_result_types, context_module)
for (s, t) in Iterators.zip(
Iterators.flatten(fc.return_symbols),
Iterators.cycle(res_types, length(fc.return_symbols)),
Expand All @@ -27,7 +26,7 @@ function infer_types!(tape::Tape)
end

for fc in tape.schedule
res_types = result_types(fc, known_result_types)
res_types = result_types(fc, known_result_types, context_module)
fc.return_types = res_types
for (s, t) in Iterators.zip(
Iterators.flatten(fc.return_symbols),
Expand Down
Loading

0 comments on commit e6c8fc0

Please sign in to comment.