From 03cdf5871a6533d650d2b420e39058eeb0a2b4e7 Mon Sep 17 00:00:00 2001 From: AntonReinhard Date: Sun, 17 Nov 2024 03:45:21 +0100 Subject: [PATCH] Improve scheduler --- src/code_gen/tape_machine.jl | 16 +++++++++++----- src/scheduler/greedy.jl | 26 ++++++++++++++++---------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index 2e76b72..f1391b1 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -55,10 +55,8 @@ function expr_from_fc(fc::FunctionCall{VAL_T,N_ARG,N_RET}) where {VAL_T,N_ARG,N_ func_call = Expr( :call, fc.func, - ( - fc.value_arguments[1]..., - _gen_access_expr.(Ref(fc.device), fc.arguments[1])..., - )..., + fc.value_arguments[1]..., + _gen_access_expr.(Ref(fc.device), fc.arguments[1])..., ) else # TBW; dispatch to device specific vectorization @@ -119,7 +117,13 @@ function gen_function_body(tape::Tape, context_module::Module; closures_size::In # only need to annotate types later when using closures types = infer_types!(tape) - # TODO calculate closures size better + if closures_size >= 1 + s = log(closures_size, length(tape.schedule)) + closures_depth = ceil(Int, s) # tend towards more levels/smaller closures + closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth)) + end + + @info "generating function body with closure size $closures_size" return _gen_function_body( tape.schedule, types, tape.machine, context_module; closures_size=closures_size @@ -133,6 +137,7 @@ function _gen_function_body( context_module::Module; closures_size=0, ) + @info "generating function body from $(length(fc_vec)) function calls with closure size $closures_size" if closures_size <= 1 || closures_size >= length(fc_vec) return Expr(:block, expr_from_fc.(fc_vec)...) end @@ -251,6 +256,7 @@ function gen_tape( context_module::Module, scheduler::AbstractScheduler=GreedyScheduler(), ) + @debug "generating tape" schedule = schedule_dag(scheduler, graph, machine) function_body = lower(schedule, machine) diff --git a/src/scheduler/greedy.jl b/src/scheduler/greedy.jl index 63795f1..61077b1 100644 --- a/src/scheduler/greedy.jl +++ b/src/scheduler/greedy.jl @@ -7,14 +7,16 @@ A greedy implementation of a scheduler, creating a topological ordering of nodes struct GreedyScheduler <: AbstractScheduler end function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine) - node_queue = PriorityQueue{Node,Int}() + node_dict = Dict{Node,Int}() # dictionary of nodes with the number of not-yet-scheduled children + node_stack = Stack{Node}() # stack of currently schedulable nodes, i.e., nodes with all of their children already scheduled + # the stack makes sure that closely related nodes will be scheduled one after another # use a priority equal to the number of unseen children -> 0 are nodes that can be added for node in get_entry_nodes(graph) - enqueue!(node_queue, node => 0) + push!(node_stack, node) end - schedule = Vector{Node}() + schedule = Node[] sizehint!(schedule, length(graph.nodes)) # keep an accumulated cost of things scheduled to this device so far @@ -24,9 +26,8 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine) end local node - while !isempty(node_queue) - @assert peek(node_queue)[2] == 0 - node = dequeue!(node_queue) + while !isempty(node_stack) + node = pop!(node_stack) # assign the device with lowest accumulated cost to the node (if it's a compute node) if (isa(node, ComputeTaskNode)) @@ -37,15 +38,20 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine) push!(schedule, node) + # find all parent's priority, reduce by one if in the node_dict + # if it reaches zero, push onto node_stack for parent in parents(node) - # reduce the priority of all parents by one - if (!haskey(node_queue, parent)) - enqueue!(node_queue, parent => length(children(parent)) - 1) + parents_prio = get(node_dict, parent, length(children(parent))) - 1 + if parents_prio == 0 + delete!(node_dict, parent) + push!(node_stack, parent) else - node_queue[parent] = node_queue[parent] - 1 + node_dict[parent] = parents_prio end end end + @assert isempty(node_dict) "found unschedulable nodes, this most likely means the graph has a cycle" + return schedule end