Skip to content

Commit

Permalink
Fix type inference ability and make closures into anonymous functions
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonReinhard committed Oct 4, 2024
1 parent 41a1bcd commit c8b9d79
Show file tree
Hide file tree
Showing 14 changed files with 188 additions and 64 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ CUDAExt = "CUDA"
oneAPIExt = "oneAPI"

[compat]
julia = "1.10"
AMDGPU = "1"
CUDA = "5"
DataStructures = "0.18"
NumaAllocators = "0.2"
oneAPI = "1"
RuntimeGeneratedFunctions = "0.5"
StaticArrays = "1"
julia = "1.10"
oneAPI = "1"

[extras]
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand Down
2 changes: 1 addition & 1 deletion ext/devices/cuda/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function ComputableDAGs.kernel(
init_caches = Expr(:block, tape.initCachesCode...)
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
# TODO: use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...)
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
Expand Down
2 changes: 1 addition & 1 deletion ext/devices/rocm/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function ComputableDAGs.kernel(
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)

# TODO use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...)
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
Expand Down
1 change: 1 addition & 0 deletions src/ComputableDAGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ include("scheduler/interface.jl")
include("scheduler/greedy.jl")

include("code_gen/type.jl")
include("code_gen/utils.jl")
include("code_gen/tape_machine.jl")
include("code_gen/function.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/code_gen/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function get_compute_function(

initCaches = Expr(:block, tape.initCachesCode...)
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
code = gen_function_body(tape.computeCode; closures_size=closures_size)
code = gen_function_body(tape; closures_size=closures_size)

functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(
Expand Down
125 changes: 95 additions & 30 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,11 @@ function gen_input_assignment_code(
device = entry_device(machine)

fc = FunctionCall(
RuntimeGeneratedFunction(
@__MODULE__,
context_module,
Expr(:->, :x, input_expr(instance, name, :x)),
),
context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))),
SVector{0,Any}(),
SVector{1,Symbol}(:input),
symbol,
Nothing,
device,
)

Expand All @@ -140,40 +137,95 @@ function gen_input_assignment_code(
end

"""
gen_function_body(fc_vec::Vector{FunctionCall}; closures_size)
gen_function_body(tape::Tape; closures_size)
Generate the function body from the given `Vector` of [`FunctionCall`](@ref)s.
Generate the function body from the given [`Tape`](@ref).
## Keyword Arguments
`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely.
"""
function gen_function_body(fc_vec::Vector{FunctionCall}; closures_size::Int)
function gen_function_body(tape::Tape; closures_size::Int)
if closures_size > 1
# only need to annotate types later when using closures
infer_types!(tape)
end

fc_vec = tape.schedule

if (closures_size <= 1)
return Expr(:block, expr_from_fc.(fc_vec)...)
end

closures = Vector{Expr}()
for i in 1:closures_size:length(fc_vec)
code_block = fc_vec[i:min(i + closures_size, length(fc_vec))]
# iterate from end to beginning
# this helps because we can collect all undefined arguments to the closures that have to be returned somewhere earlier
undefined_argument_symbols = Set{Symbol}()
# the final return symbol is the return of the entire generated function, it always has to be returned
push!(undefined_argument_symbols, eval(gen_access_expr(fc_vec[end])))

for i in length(fc_vec):(-closures_size):1
e = i
b = max(i - closures_size, 1)
code_block = fc_vec[b:e]

# collect `local var` statements that need to exist before the closure starts
# since the return symbols are always unique, this has to happen for each fc and there will be no duplicates
local_inits = gen_local_init.(code_block)

closure = Expr( # call to the following closure (no arguments)
:call,
Expr( # create the closure: () -> code block; return nothing
:->,
:(),
Expr(# # actual function body of the closure
:block,
expr_from_fc.(code_block)...,
Expr(:return, :nothing),
return_symbols = eval.(gen_access_expr.(code_block))
argument_symbols = Set{Symbol}()

ret_symbols_set = Set(return_symbols)
for fc in code_block
for arg in fc.arguments
symbol = eval(_gen_access_expr(fc.device, fc.device.cacheStrategy, arg))

# symbol won't be defined if it is first calculated in the closure
# so don't add it to the arguments in this case
if !(symbol in ret_symbols_set)
push!(argument_symbols, symbol)
end
end
end
union!(undefined_argument_symbols, argument_symbols)

intersect!(ret_symbols_set, undefined_argument_symbols)
return_symbols = Symbol[ret_symbols_set...]

argument_symbols = [argument_symbols...] # make sure there is an order (doesn't matter which)

closure = Expr(
:block,
Expr(
:(=),
Expr(:tuple, return_symbols...),
Expr(
:call, # call to the following closure (no arguments)
Expr( # create the closure: (args) -> code block; return (locals)
:->,
Expr(:tuple, argument_symbols...), # arguments in the closure definition
Expr( # actual function body of the closure
:block,
local_inits..., # declare local variables with type information inside the closure
expr_from_fc.(code_block)...,
Expr(:return, Expr(:tuple, return_symbols...)),
),
),
argument_symbols..., # arguments to the closure call
),
),
)

setdiff!(undefined_argument_symbols, ret_symbols_set)

#=Expr(
:macrocall,
Symbol("@closure"),
@__LINE__,
Expr( <closure...> )
)=#

# combine to one closure call, including all the local inits and the actual call to the closure
push!(closures, Expr(:block, local_inits..., closure))
pushfirst!(closures, closure)
end

return Expr(:block, closures...)
Expand All @@ -200,25 +252,33 @@ function gen_tape(
scheduler::AbstractScheduler=GreedyScheduler(),
)
schedule = schedule_dag(scheduler, graph, machine)
function_body = lower(schedule, machine)

# get inSymbols
inputSyms = Dict{String,Vector{Symbol}}()
# get input symbols
input_syms = Dict{String,Vector{Symbol}}()
for node in get_entry_nodes(graph)
if !haskey(inputSyms, node.name)
inputSyms[node.name] = Vector{Symbol}()
if !haskey(input_syms, node.name)
input_syms[node.name] = Vector{Symbol}()
end

push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
push!(input_syms[node.name], Symbol("$(to_var_name(node.id))_in"))
end

# get outSymbol
outSym = Symbol(to_var_name(get_exit_node(graph).id))

initCaches = gen_cache_init_code(machine)
assign_inputs = gen_input_assignment_code(inputSyms, instance, machine, context_module)
init_caches = gen_cache_init_code(machine)
assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module)

return Tape{input_type(instance)}(
initCaches, assign_inputs, schedule, inputSyms, outSym, Dict(), instance, machine
init_caches,
assign_inputs,
function_body,
input_syms,
outSym,
Dict(),
instance,
machine,
)
end

Expand All @@ -228,6 +288,9 @@ end
Execute the given tape with the given input.
For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.
!!! warning
This is very slow and might not work. This is to be majorly revamped.
"""
function execute_tape(tape::Tape, input)
cache = Dict{Symbol,Any}()
Expand All @@ -238,10 +301,12 @@ function execute_tape(tape::Tape, input)
@eval $expr
end

compute_code = tape.schedule

for function_call in tape.inputAssignCode
call_fc(function_call, cache)
end
for function_call in tape.computeCode
for function_call in compute_code
call_fc(function_call, cache)
end

Expand Down
2 changes: 1 addition & 1 deletion src/code_gen/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ TODO: update docs
struct Tape{INPUT}
initCachesCode::Vector{Expr}
inputAssignCode::Vector{FunctionCall}
computeCode::Vector{FunctionCall}
schedule::Vector{FunctionCall}
inputSymbols::Dict{String,Vector{Symbol}}
outputSymbol::Symbol
cache::Dict{Symbol,Any}
Expand Down
45 changes: 45 additions & 0 deletions src/code_gen/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
infer_types!(schedule::Vector{FunctionCall})
Infer the result type of each function call in the given schedule. Returns a dictionary with the result type for each [`Node`](@ref). This assumes that each node has only one statically inferrable return type and will throw an exceptin otherwise.
This also assumes that the given `Vector` contains a topological ordering of its nodes, such as returned by a call to [`schedule_dag`](@ref).
"""
function infer_types!(tape::Tape)
known_result_types = Dict{Symbol,Type}()

# the only initially known type
known_result_types[:input] = input_type(tape.instance)

for fc in tape.inputAssignCode
res_type = result_type(fc, known_result_types)
fc.return_type = res_type
known_result_types[fc.return_symbol] = res_type
end

for fc in tape.schedule
res_type = result_type(fc, known_result_types)
fc.return_type = res_type
known_result_types[fc.return_symbol] = res_type
end

return nothing
end

"""
lower(schedule::Vector{Node}, machine::Machine)
After [`schedule_dag`](@ref) has made a schedule of nodes, this function lowers the vector of [`Node`](@ref)s into a vector of [`FunctionCall`](@ref)s.
"""
function lower(schedule::Vector{Node}, machine::Machine)
calls = Vector{FunctionCall}()

for node in schedule
if (node isa DataTaskNode && length(children(node)) == 0)
push!(calls, get_init_function_call(node, entry_device(machine)))
else
push!(calls, get_function_call(node)...)
end
end

return calls
end
2 changes: 1 addition & 1 deletion src/devices/numa/impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ Interface implementation, dispatched to from [`gen_local_init`](@ref).
"""
function _gen_local_init(fc::FunctionCall, ::NumaNode, ::LocalVariables)
s = Symbol("data_$(fc.return_symbol)")
quote_node = Expr(:local, s) # TODO: figure out how to get type info for this local variable
quote_node = Expr(:local, s, :(::), Symbol(fc.return_type)) # TODO: figure out how to get type info for this local variable
return quote_node
end

Expand Down
1 change: 0 additions & 1 deletion src/node/create.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

function DataTaskNode(t::AbstractDataTask, name="")
return DataTaskNode(
t,
Expand Down
36 changes: 16 additions & 20 deletions src/scheduler/greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,42 @@ A greedy implementation of a scheduler, creating a topological ordering of nodes
struct GreedyScheduler <: AbstractScheduler end

function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
nodeQueue = PriorityQueue{Node,Int}()
node_queue = PriorityQueue{Node,Int}()

# use a priority equal to the number of unseen children -> 0 are nodes that can be added
for node in get_entry_nodes(graph)
enqueue!(nodeQueue, node => 0)
enqueue!(node_queue, node => 0)
end

schedule = Vector{FunctionCall}()
schedule = Vector{Node}()
sizehint!(schedule, length(graph.nodes))

# keep an accumulated cost of things scheduled to this device so far
deviceAccCost = PriorityQueue{AbstractDevice,Float64}()
device_acc_cost = PriorityQueue{AbstractDevice,Float64}()
for device in machine.devices
enqueue!(deviceAccCost, device => 0)
enqueue!(device_acc_cost, device => 0)
end

node = nothing
while !isempty(nodeQueue)
@assert peek(nodeQueue)[2] == 0
node = dequeue!(nodeQueue)
local node
while !isempty(node_queue)
@assert peek(node_queue)[2] == 0
node = dequeue!(node_queue)

# assign the device with lowest accumulated cost to the node (if it's a compute node)
if (isa(node, ComputeTaskNode))
lowestDevice = peek(deviceAccCost)[1]
node.device = lowestDevice
deviceAccCost[lowestDevice] = compute_effort(task(node))
lowest_device = peek(device_acc_cost)[1]
node.device = lowest_device
device_acc_cost[lowest_device] = compute_effort(task(node))
end

if (node isa DataTaskNode && length(children(node)) == 0)
push!(schedule, get_init_function_call(node, entry_device(machine)))
else
push!(schedule, get_function_call(node)...)
end
push!(schedule, node)

for parent in parents(node)
# reduce the priority of all parents by one
if (!haskey(nodeQueue, parent))
enqueue!(nodeQueue, parent => length(children(parent)) - 1)
if (!haskey(node_queue, parent))
enqueue!(node_queue, parent => length(children(parent)) - 1)
else
nodeQueue[parent] = nodeQueue[parent] - 1
node_queue[parent] = node_queue[parent] - 1
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/scheduler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ The function assigns each [`ComputeTaskNode`](@ref) of the [`DAG`](@ref) to one
[`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`AbstractDevice`](@ref) of their child to all [`AbstractDevice`](@ref)s of its parents.
Return a `Vector{FunctionCall}`. See [`FunctionCall`](@ref)
The produced schedule can be converted to [`FunctionCall`](@ref)s using [`lower`](@ref).
"""
function schedule_dag end
Loading

0 comments on commit c8b9d79

Please sign in to comment.