Skip to content

Commit

Permalink
Improve scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonReinhard committed Nov 17, 2024
1 parent edcdb31 commit 03cdf58
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
16 changes: 11 additions & 5 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,8 @@ function expr_from_fc(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_
func_call = Expr(
:call,
fc.func,
(
fc.value_arguments[1]...,
_gen_access_expr.(Ref(fc.device), fc.arguments[1])...,
)...,
fc.value_arguments[1]...,
_gen_access_expr.(Ref(fc.device), fc.arguments[1])...,
)
else
# TBW; dispatch to device specific vectorization
Expand Down Expand Up @@ -119,7 +117,13 @@ function gen_function_body(tape::Tape, context_module::Module; closures_size::In
# only need to annotate types later when using closures
types = infer_types!(tape)

# TODO calculate closures size better
if closures_size >= 1
s = log(closures_size, length(tape.schedule))
closures_depth = ceil(Int, s) # tend towards more levels/smaller closures
closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth))
end

@info "generating function body with closure size $closures_size"

return _gen_function_body(
tape.schedule, types, tape.machine, context_module; closures_size=closures_size
Expand All @@ -133,6 +137,7 @@ function _gen_function_body(
context_module::Module;
closures_size=0,
)
@info "generating function body from $(length(fc_vec)) function calls with closure size $closures_size"
if closures_size <= 1 || closures_size >= length(fc_vec)
return Expr(:block, expr_from_fc.(fc_vec)...)
end
Expand Down Expand Up @@ -251,6 +256,7 @@ function gen_tape(
context_module::Module,
scheduler::AbstractScheduler=GreedyScheduler(),
)
@debug "generating tape"
schedule = schedule_dag(scheduler, graph, machine)
function_body = lower(schedule, machine)

Expand Down
26 changes: 16 additions & 10 deletions src/scheduler/greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ A greedy implementation of a scheduler, creating a topological ordering of nodes
struct GreedyScheduler <: AbstractScheduler end

function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
node_queue = PriorityQueue{Node,Int}()
node_dict = Dict{Node,Int}() # dictionary of nodes with the number of not-yet-scheduled children
node_stack = Stack{Node}() # stack of currently schedulable nodes, i.e., nodes with all of their children already scheduled
# the stack makes sure that closely related nodes will be scheduled one after another

# use a priority equal to the number of unseen children -> 0 are nodes that can be added
for node in get_entry_nodes(graph)
enqueue!(node_queue, node => 0)
push!(node_stack, node)
end

schedule = Vector{Node}()
schedule = Node[]
sizehint!(schedule, length(graph.nodes))

# keep an accumulated cost of things scheduled to this device so far
Expand All @@ -24,9 +26,8 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
end

local node
while !isempty(node_queue)
@assert peek(node_queue)[2] == 0
node = dequeue!(node_queue)
while !isempty(node_stack)
node = pop!(node_stack)

# assign the device with lowest accumulated cost to the node (if it's a compute node)
if (isa(node, ComputeTaskNode))
Expand All @@ -37,15 +38,20 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)

push!(schedule, node)

# find all parent's priority, reduce by one if in the node_dict
# if it reaches zero, push onto node_stack
for parent in parents(node)
# reduce the priority of all parents by one
if (!haskey(node_queue, parent))
enqueue!(node_queue, parent => length(children(parent)) - 1)
parents_prio = get(node_dict, parent, length(children(parent))) - 1
if parents_prio == 0
delete!(node_dict, parent)
push!(node_stack, parent)
else
node_queue[parent] = node_queue[parent] - 1
node_dict[parent] = parents_prio
end
end
end

@assert isempty(node_dict) "found unschedulable nodes, this most likely means the graph has a cycle"

return schedule
end

0 comments on commit 03cdf58

Please sign in to comment.