Skip to content

Commit

Permalink
WIP and renaming/docs/removing unused interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonReinhard committed Nov 27, 2024
1 parent ee213a3 commit 5e799e9
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 97 deletions.
6 changes: 2 additions & 4 deletions src/ComputableDAGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ export get_operations

# code generation related
export get_compute_function
export gen_tape

# estimator
export cost_type, graph_cost, operation_effect
Expand All @@ -48,8 +47,7 @@ export optimize_step!, optimize!
export fixpoint_reached, optimize_to_fixpoint!

# models
export AbstractModel, AbstractProblemInstance
export problem_instance, input_type, graph, input_expr
export graph, input_type, input_expr

# machine info
export Machine
Expand All @@ -67,6 +65,7 @@ include("properties/type.jl")
include("operation/type.jl")
include("graph/type.jl")
include("scheduler/type.jl")
include("code_gen/type.jl")

include("trie.jl")
include("utils.jl")
Expand Down Expand Up @@ -125,7 +124,6 @@ include("devices/ext.jl")
include("scheduler/interface.jl")
include("scheduler/greedy.jl")

include("code_gen/type.jl")
include("code_gen/utils.jl")
include("code_gen/tape_machine.jl")
include("code_gen/function.jl")
Expand Down
49 changes: 35 additions & 14 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end
"""
gen_input_assignment_code(
input_symbols::Dict{String, Vector{Symbol}},
instance::AbstractProblemInstance,
instance::Any,
machine::Machine,
input_type::Type,
context_module::Module
Expand Down Expand Up @@ -73,13 +73,15 @@ end
Generate the function body from the given [`Tape`](@ref).
## Keyword Arguments
`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely.
`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers
in the function body, preventing some optimizations by the compiler and therefore greatly reducing
compile time. A value of 0 will disable the use of closures entirely.
"""
function gen_function_body(tape::Tape, context_module::Module; closures_size::Int)
# only need to annotate types later when using closures
types = infer_types!(tape)

if closures_size >= 1
if closures_size > 1
s = log(closures_size, length(tape.schedule))
closures_depth = ceil(Int, s) # tend towards more levels/smaller closures
closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth))
Expand Down Expand Up @@ -116,12 +118,12 @@ function _gen_function_body(
b = max(i - closures_size, 1)
code_block = fc_vec[b:e]

pushfirst!(
closured_fc_vec,
_closure_fc(
code_block, type_dict, machine, undefined_argument_symbols, context_module
),
closure_fcs = _closure_fc(
code_block, type_dict, machine, undefined_argument_symbols, context_module
)

pushfirst!(closured_fc_vec, closure_fcs[2])
pushfirst!(closured_fc_vec, closure_fcs[1])
end

return _gen_function_body(
Expand All @@ -130,9 +132,16 @@ function _gen_function_body(
end

"""
_closure_fc()
_closure_fc(
code_block::AbstractVector{FunctionCall},
types::Dict{Symbol,Type},
machine::Machine,
undefined_argument_symbols::Set{Symbol},
context_module::Module,
)
From the given function calls, make and return a new function call representing all of them together.
From the given function calls, make and return 2 function calls representing all of them together. 2 function calls are necessary, one for setting up the anonymous
function and the second for calling it.
The undefined_argument_symbols is the set of all Symbols that need to be returned if available inside the code_block. They get updated inside this function.
"""
function _closure_fc(
Expand Down Expand Up @@ -168,7 +177,9 @@ function _closure_fc(
arg_symbols_t = [arg_symbols_set...]
ret_symbols_t = [ret_symbols_set...]

closure = context_module.eval(
f_id = Symbol(to_var_name(UUIDs.uuid1(rng[threadid()])))

closure_setup = FunctionCall(
Expr( # create the closure: () -> code block; return (locals)
:->,
Expr(:tuple, arg_symbols_t...), # closure arguments
Expand All @@ -185,23 +196,33 @@ function _closure_fc(
),
),
),
(),
Symbol[],
Symbol[f_id],
Type[Nothing],
entry_device(machine),
)

ret_types = (getindex.(Ref(types), ret_symbols_t))

fc = FunctionCall(
closure, (), arg_symbols_t, ret_symbols_t, ret_types, entry_device(machine)
_call,
(),
Symbol[f_id, arg_symbols_t...],
ret_symbols_t,
ret_types,
entry_device(machine),
)

setdiff!(undefined_argument_symbols, ret_symbols_set)

return fc
return (closure_setup, fc)
end

"""
gen_tape(
graph::DAG,
instance::AbstractProblemInstance,
instance::Any,
machine::Machine,
context_module::Module,
scheduler::AbstractScheduler = GreedyScheduler()
Expand Down
72 changes: 68 additions & 4 deletions src/code_gen/type.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,76 @@
"""
FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}}
Representation of a function call. Contains the function to call (or an expression of a value to assign),
value arguments of type `VAL_T`, argument symbols, the return symbol(s) and type(s) and the device to execute on.
To support vectorization, i.e., calling the same function on multiple inputs (SIMD), the value arguments, arguments,
and return symbols are each vectors of the actual inputs. In the non-vectorized case, these `Vector`s simply always
have length 1. For this common case, a special constructor exists which automatically wraps each of these arguments
in a `Vector`.
## Type Arguments
- `VAL_T<:Tuple`: A tuple of all the value arguments that are passed to the function when it's called.
- `FUNC_T<:Union{Function, Expr}`: The type of the function. `Function` is the default, but in some cases, an `Expr`
of a value can be necessary to assign to the return symbol. In this case, no arguments are allowed.
## Fields
- `func::FUNC_T`: The function to be called, or an expression containing a value to assign to the return_symbol.
- `value_arguments::Vector{VAL_T}`: The value arguments for the function call. These are passed *first* to the
function, in the order given here. The `Vector` contains the tuple of value arguments for each vectorization
member.
- `arguments::Vector{Vector{Symbol}}`: The first vector represents the vectorization, the second layer represents the
symbols that will be passed as arguments to the function call.
- `return_symbols::Vector{Vector{Symbol}}`: As with the arguments, the first vector level represents the vectorization,
the second represents the symbols that the results of the function call are assigned to. For most function calls,
there is only one return symbol. When using closures when generating a function body for a [`Tape`](@ref), the
option to have multiple return symbols is necessary.
- `return_types::Vector{<:Type}`: The types of the function call with the arguments provided. This field only contains
one level of Vector, because it is required that a `FunctionCall` is type stable, and therefore, the types of the
return symbols have to be equal for all members of a vectorization. The return type is initially set to `Nothing`
and later inferred and assigned by [`infer_types!`](@ref).
- `device::AbstractDevice`: The device that this function call is scheduled on.
"""
mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}}
func::FUNC_T
value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments
arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call
return_symbols::Vector{Vector{Symbol}} # the return symbols
return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability
device::AbstractDevice
end
function FunctionCall(
func::Union{Function,Expr},
value_arguments::VAL_T,
arguments::Vector{Symbol},
return_symbol::Vector{Symbol},
return_types::Vector{<:Type},
device::AbstractDevice,
) where {VAL_T<:Tuple}
# convenience constructor for function calls that do not use vectorization, which is most of the use cases
@assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types '$(length(return_types))' does not match the number of return symbols '$(length(return_symbol))'"
@assert func isa Function || length(value_arguments) == 0 && isempty(arguments) "no arguments are allowed for a an Expr FunctionCall, but got '$value_arguments' and '$arguments'"
return FunctionCall(
func, [value_arguments], [arguments], [return_symbol], return_types, device
)
end

"""
Tape{INPUT}
TODO: update docs
- `INPUT` the input type of the problem instance
Lowered representation of a computation, generated from a [`DAG`](@ref) through [`gen_tape`](@ref).
- `INPUT` the input type of the problem instance, see also the interface function [`input_type`](@ref)
- `code::Vector{Expr}`: The julia expression containing the code for the whole graph.
- `output_symbol::Symbol`: The symbol of the final calculated value
## Fields
- `input_assign_code::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the input assignments,
mapping part of the input of the computation to each DAG entry node. These functions are generated using
the interface function [`input_expr`](@ref).
- `schedule::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the function body of the computation.
There is one function call for each node in the [`DAG`](@ref).
- `output_symbol::Symbol`: The symbol of the final calculated value, which is returned.
- `instance::Any`: The instance that this tape is generated for.
- `machine::Machine`: The [`Machine`](@ref) that this tape is generated for.
"""
struct Tape{INPUT}
input_assign_code::Vector{FunctionCall}
Expand Down
5 changes: 5 additions & 0 deletions src/code_gen/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
function Base.length(fc::FunctionCall)
@assert length(fc.value_arguments) == length(fc.arguments) == length(fc.return_symbols) "function call length is undefined, got '$(length(fc.value_arguments))' tuples of value arguments, '$(length(fc.arguments))' tuples of arguments, and '$(length(return_symbols))' return symbols"
return length(fc.value_arguments)
end

"""
infer_types!(schedule::Vector{FunctionCall})
Expand Down
40 changes: 11 additions & 29 deletions src/models/interface.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,26 @@

"""
AbstractModel
Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed.
See also: [`problem_instance`](@ref)
"""
abstract type AbstractModel end
input_type(problem_instance)
"""
problem_instance(::AbstractModel, ::Vararg)
Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters.
"""
function problem_instance end
Return the input type for a specific `problem_instance`. This can be a specific type or a supertype for which all child types are expected to be implemented.
"""
AbstractProblemInstance
Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model.
"""
abstract type AbstractProblemInstance end

"""
input_type(problem::AbstractProblemInstance)
Return the input type for a specific [`AbstractProblemInstance`](@ref). This can be a specific type or a supertype for which all child types are expected to work.
For more details on the `problem_instance`, please refer to the documentation.
"""
function input_type end

"""
graph(::AbstractProblemInstance)
graph(problem_instance)
Generate the [`DAG`](@ref) for the given `problem_instance`. Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names.
Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names.
For more details on the `problem_instance`, please refer to the documentation.
"""
function graph end

"""
input_expr(instance::AbstractProblemInstance, name::String, input_symbol::Symbol)
input_expr(problem_instance, name::String, input_symbol::Symbol)
For the given `problem_instance`, the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol.
For the given [`AbstractProblemInstance`](@ref), the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol.
For more details on the `problem_instance`, please refer to the documentation.
"""
function input_expr end
12 changes: 6 additions & 6 deletions src/properties/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
Representation of a [`DAG`](@ref)'s properties.
# Fields:
`.data`: The total data transfer.\\
`.compute_effort`: The total compute effort.\\
`.compute_intensity`: The compute intensity, will always equal `.compute_effort / .data`.\\
`.number_of_nodes`: Number of [`Node`](@ref)s.\\
`.number_of_edges`: Number of [`Edge`](@ref)s.
## Fields:
- `data::Float64`: The total data transfer.
- `compute_effort::Float64`: The total compute effort.
- `compute_intensity::Float64`: The compute intensity, will always equal `compute_effort / data`.
- `number_of_nodes::Int`: Number of [`Node`](@ref)s.
- `number_of_edges::Int`: Number of [`Edge`](@ref)s.
"""
const GraphProperties = NamedTuple{
(:data, :compute_effort, :compute_intensity, :number_of_nodes, :number_of_edges),
Expand Down
8 changes: 0 additions & 8 deletions src/scheduler/interface.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@

"""
AbstractScheduler
Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
"""
abstract type AbstractScheduler end

"""
schedule_dag(::Scheduler, ::DAG, ::Machine)
Expand Down
35 changes: 3 additions & 32 deletions src/scheduler/type.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,6 @@
"""
FunctionCall{VAL_TYPES}
AbstractScheduler
Type representing a function call. Contains the function to call, argument symbols, the return symbol and the device to execute on.
TODO: extend docs
Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
"""
mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}}
func::FUNC_T
value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments
arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call
return_symbols::Vector{Vector{Symbol}} # the return symbols
return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability
device::AbstractDevice
end

function FunctionCall(
func::Union{Function,Expr},
value_arguments::VAL_T,
arguments::Vector{Symbol},
return_symbol::Vector{Symbol},
return_types::Vector{<:Type},
device::AbstractDevice,
) where {VAL_T<:Tuple}
# convenience constructor for function calls that do not use vectorization, which is most of the use cases
@assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types $(length(return_types)) does not match the number of return symbols $(length(return_symbol))"
return FunctionCall(
func, [value_arguments], [arguments], [return_symbol], return_types, device
)
end

function Base.length(fc::FunctionCall)
@assert length(fc.value_arguments) == length(fc.arguments) == length(fc.return_symbols) "function call length is undefined, got $(length(fc.value_arguments)) tuples of value arguments, $(length(fc.arguments)) tuples of arguments, and $(length(return_symbols)) return symbols"
return length(fc.value_arguments)
end
abstract type AbstractScheduler end

0 comments on commit 5e799e9

Please sign in to comment.