Remove occurrences of CacheStrategy everywhere (#41)
AntonReinhard authored Nov 1, 2024
1 parent 9a93eb8 commit 67a05d7
Showing 16 changed files with 24 additions and 222 deletions.
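In short: devices no longer carry a CacheStrategy. The field is removed from all device structs, the CACHE_STRATEGIES registry and default_strategy are deleted, and _gen_access_expr and _gen_local_init lose their strategy argument. A minimal before/after sketch of the user-visible change, based on the constructors touched below (device_handle is a placeholder):

    # Before this commit: every device carried an explicit cache strategy.
    gpu = CUDAGPU(device_handle, default_strategy(CUDAGPU), -1)

    # After this commit: only the device handle and a FLOPS estimate remain.
    gpu = CUDAGPU(device_handle, -1)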
1 change: 0 additions & 1 deletion ext/AMDGPUExt.jl
@@ -8,7 +8,6 @@ function __init__()
     @debug "Loading AMDGPUExt"

     push!(ComputableDAGs.DEVICE_TYPES, ROCmGPU)
-    ComputableDAGs.CACHE_STRATEGIES[ROCmGPU] = [LocalVariables()]

     return nothing
 end
1 change: 0 additions & 1 deletion ext/CUDAExt.jl
@@ -8,7 +8,6 @@ function __init__()
     @debug "Loading CUDAExt"

     push!(ComputableDAGs.DEVICE_TYPES, CUDAGPU)
-    ComputableDAGs.CACHE_STRATEGIES[CUDAGPU] = [LocalVariables()]

     return nothing
 end
6 changes: 1 addition & 5 deletions ext/devices/cuda/function.jl
@@ -4,17 +4,14 @@ function ComputableDAGs.kernel(
     machine = cpu_st()
     tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module)

-    init_caches = Expr(:block, tape.initCachesCode...)
     assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
     # TODO: use gen_function_body here
     code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

     function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
     res_sym = eval(
         ComputableDAGs._gen_access_expr(
-            ComputableDAGs.entry_device(tape.machine),
-            ComputableDAGs.entry_device(tape.machine).cacheStrategy,
-            tape.outputSymbol,
+            ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
         ),
     )
     expr = Meta.parse(
@@ -24,7 +21,6 @@ function ComputableDAGs.kernel(
             return
         end
         @inline data_input = input_vector[id]
-        $(init_caches)
         $(assign_inputs)
         $code
         @inline output_vector[id] = $res_sym
4 changes: 1 addition & 3 deletions ext/devices/cuda/impl.jl
@@ -1,5 +1,3 @@
-ComputableDAGs.default_strategy(::Type{CUDAGPU}) = LocalVariables()
-
 function ComputableDAGs.measure_device!(device::CUDAGPU; verbose::Bool)
     verbose && @info "Measuring CUDA GPU $(device.device)"

@@ -23,7+21,7 @@ function ComputableDAGs.get_devices(::Type{CUDAGPU}; verbose::Bool=false)
     CUDADevices = CUDA.devices()
     verbose && @info "Found $(length(CUDADevices)) CUDA devices"
     for device in CUDADevices
-        push!(devices, CUDAGPU(device, default_strategy(CUDAGPU), -1))
+        push!(devices, CUDAGPU(device, -1))
     end

     return devices
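For context, a hedged usage sketch of the updated get_devices (assuming the CUDA extension is activated by importing CUDA; verbose is the only keyword visible in this hunk):

    using CUDA
    using ComputableDAGs

    # Each returned entry is now CUDAGPU(device, -1), with no cache strategy attached.
    gpus = ComputableDAGs.get_devices(CUDAGPU; verbose=true)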
4 changes: 1 addition & 3 deletions ext/devices/oneapi/impl.jl
@@ -1,5 +1,3 @@
-ComputableDAGs.default_strategy(::Type{oneAPIGPU}) = LocalVariables()
-
 function ComputableDAGs.measure_device!(device::oneAPIGPU; verbose::Bool)
     verbose && @info "Measuring oneAPI GPU $(device.device)"

@@ -23,7+21,7 @@ function ComputableDAGs.get_devices(::Type{oneAPIGPU}; verbose::Bool=false)
     oneAPIDevices = oneAPI.devices()
     verbose && @info "Found $(length(oneAPIDevices)) oneAPI devices"
     for device in oneAPIDevices
-        push!(devices, oneAPIGPU(device, default_strategy(oneAPIGPU), -1))
+        push!(devices, oneAPIGPU(device, -1))
     end

     return devices
6 changes: 1 addition & 5 deletions ext/devices/rocm/function.jl
@@ -4,7 +4,6 @@ function ComputableDAGs.kernel(
     machine = cpu_st()
     tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module)

-    init_caches = Expr(:block, tape.initCachesCode...)
     assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)

     # TODO use gen_function_body here
@@ -13,9 +12,7 @@ function ComputableDAGs.kernel(
     function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
     res_sym = eval(
         ComputableDAGs._gen_access_expr(
-            ComputableDAGs.entry_device(tape.machine),
-            ComputableDAGs.entry_device(tape.machine).cacheStrategy,
-            tape.outputSymbol,
+            ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
         ),
     )
     expr = Meta.parse(
@@ -25,7 +22,6 @@ function ComputableDAGs.kernel(
             return
         end
         @inline data_input = input_vector[id]
-        $(init_caches)
         $(assign_inputs)
         $code
         @inline output_vector[id] = $res_sym
4 changes: 1 addition & 3 deletions ext/devices/rocm/impl.jl
@@ -1,5 +1,3 @@
-ComputableDAGs.default_strategy(::Type{ROCmGPU}) = LocalVariables()
-
 function ComputableDAGs.measure_device!(device::ROCmGPU; verbose::Bool)
     verbose && @info "Measuring ROCm GPU $(device.device)"

@@ -23,7+21,7 @@ function ComputableDAGs.get_devices(::Type{ROCmGPU}; verbose::Bool=false)
     AMDDevices = AMDGPU.devices()
     verbose && @info "Found $(length(AMDDevices)) AMD devices"
     for device in AMDDevices
-        push!(devices, ROCmGPU(device, default_strategy(ROCmGPU), -1))
+        push!(devices, ROCmGPU(device, -1))
     end

     return devices
1 change: 0 additions & 1 deletion ext/oneAPIExt.jl
@@ -8,7 +8,6 @@ function __init__()
     @debug "Loading oneAPIExt"

     push!(ComputableDAGs.DEVICE_TYPES, oneAPIGPU)
-    ComputableDAGs.CACHE_STRATEGIES[oneAPIGPU] = [LocalVariables()]

     return nothing
 end
2 changes: 0 additions & 2 deletions src/ComputableDAGs.jl
@@ -57,8 +57,6 @@ export problem_instance, input_type, graph, input_expr
 export Machine
 export NumaNode
 export get_machine_info, cpu_st
-export CacheStrategy, default_strategy
-export LocalVariables, Dictionary

 # GPU Extensions
 export kernel, CUDAGPU, ROCmGPU, oneAPIGPU
11 changes: 2 additions & 9 deletions src/code_gen/function.jl
@@ -24,18 +24,11 @@ function get_compute_function(
 )
     tape = gen_tape(graph, instance, machine, context_module)

-    initCaches = Expr(:block, tape.initCachesCode...)
     assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
     code = gen_function_body(tape; closures_size=closures_size)

     functionId = to_var_name(UUIDs.uuid1(rng[1]))
-    resSym = eval(
-        _gen_access_expr(
-            entry_device(tape.machine),
-            entry_device(tape.machine).cacheStrategy,
-            tape.outputSymbol,
-        ),
-    )
+    resSym = eval(_gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
     expr = #
         Expr(
             :function, # function definition
@@ -44,7 +37,7 @@ function get_compute_function(
                 Symbol("compute_$functionId"),
                 Expr(:(::), :data_input, input_type(instance)),
             ), # function name and parameters
-            Expr(:block, initCaches, assignInputs, code, Expr(:return, resSym)), # function body
+            Expr(:block, assignInputs, code, Expr(:return, resSym)), # function body
         )

     return RuntimeGeneratedFunction(@__MODULE__, context_module, expr)
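A sketch of how the regenerated function is used after this change; graph, instance, and data are placeholders, and the positional signature beyond what this hunk shows (graph, instance, machine, context_module, plus the closures_size keyword) is an assumption:

    machine = cpu_st()
    compute = get_compute_function(graph, instance, machine, @__MODULE__)

    # The generated body is now just assignInputs; code; return resSym,
    # with no cache-initialization block at the top.
    result = compute(data)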
42 changes: 4 additions & 38 deletions src/code_gen/tape_machine.jl
@@ -52,11 +52,7 @@ end

 function expr_from_fc(fc::FunctionCall{VectorT,0}) where {VectorT}
     func_call = Expr(
-        :call,
-        fc.func,
-        eval.(
-            _gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments)
-        )...,
+        :call, fc.func, eval.(_gen_access_expr.(Ref(fc.device), fc.arguments))...
     )
     access_expr = eval(gen_access_expr(fc))

@@ -73,30 +69,13 @@ function expr_from_fc(fc::FunctionCall{VectorT,M}) where {VectorT,M}
         :call,
         fc.func,
         fc.value_arguments...,
-        eval.(
-            _gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments)
-        )...,
+        eval.(_gen_access_expr.(Ref(fc.device), fc.arguments))...,
     )
     access_expr = eval(gen_access_expr(fc))

     return Expr(:(=), access_expr, func_call)
 end

-"""
-    gen_cache_init_code(machine::Machine)
-
-For each [`AbstractDevice`](@ref) in the given [`Machine`](@ref), returning a `Vector{Expr}` doing the initialization.
-"""
-function gen_cache_init_code(machine::Machine)
-    initialize_caches = Vector{Expr}()
-
-    for device in machine.devices
-        push!(initialize_caches, gen_cache_init_code(device))
-    end
-
-    return initialize_caches
-end
-
 """
     gen_input_assignment_code(
         input_symbols::Dict{String, Vector{Symbol}},
@@ -176,7 +155,7 @@ function gen_function_body(tape::Tape; closures_size::Int)
     ret_symbols_set = Set(return_symbols)
     for fc in code_block
         for arg in fc.arguments
-            symbol = eval(_gen_access_expr(fc.device, fc.device.cacheStrategy, arg))
+            symbol = eval(_gen_access_expr(fc.device, arg))

             # symbol won't be defined if it is first calculated in the closure
             # so don't add it to the arguments in this case
@@ -255,18 +234,10 @@ function gen_tape(
     # get outSymbol
     outSym = Symbol(to_var_name(get_exit_node(graph).id))

-    init_caches = gen_cache_init_code(machine)
     assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module)

     return Tape{input_type(instance)}(
-        init_caches,
-        assign_inputs,
-        function_body,
-        input_syms,
-        outSym,
-        Dict(),
-        instance,
-        machine,
+        assign_inputs, function_body, input_syms, outSym, instance, machine
     )
 end

@@ -275,8 +246,6 @@ end
 Execute the given tape with the given input.

-For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.
-
 !!! warning
     This is very slow and might not work. This is to be majorly revamped.
 """
@@ -285,9 +254,6 @@ function execute_tape(tape::Tape, input)
     cache[:input] = input
     # simply execute all the code snippets here
     @assert typeof(input) <: input_type(tape.instance) "expected tape input type to fit $(input_type(tape.instance)) but got $(typeof(input))"
-    for expr in tape.initCachesCode
-        @eval $expr
-    end

     compute_code = tape.schedule

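A sketch of the slow debugging path that the last hunk updates, with graph, instance, and input as placeholders; execute_tape still builds a Dict-based cache internally, it just no longer evaluates per-device cache-initialization code first:

    tape = ComputableDAGs.gen_tape(graph, instance, cpu_st(), @__MODULE__)
    result = ComputableDAGs.execute_tape(tape, input)  # slow, see the warning above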
2 changes: 0 additions & 2 deletions src/code_gen/type.jl
@@ -10,12 +10,10 @@ TODO: update docs
 - `outputSymbol::Symbol`: The symbol of the final calculated value
 """
 struct Tape{INPUT}
-    initCachesCode::Vector{Expr}
     inputAssignCode::Vector{FunctionCall}
     schedule::Vector{FunctionCall}
     inputSymbols::Dict{String,Vector{Symbol}}
     outputSymbol::Symbol
-    cache::Dict{Symbol,Any}
     instance::Any
     machine::Machine
 end
3 changes: 0 additions & 3 deletions src/devices/ext.jl
@@ -11,7 +11,6 @@ Representation of a specific CUDA GPU that code can run on. Implements the [`Abs
 """
 mutable struct CUDAGPU <: AbstractGPU
     device::Any # CuDevice
-    cacheStrategy::CacheStrategy
     FLOPS::Float64
 end

@@ -25,7 +24,6 @@ Representation of a specific Intel GPU that code can run on. Implements the [`Ab
 """
 mutable struct oneAPIGPU <: AbstractGPU
     device::Any # oneAPI.oneL0.ZeDevice
-    cacheStrategy::CacheStrategy
     FLOPS::Float64
 end

@@ -39,6 +37,5 @@ Representation of a specific AMD GPU that code can run on. Implements the [`Abst
 """
 mutable struct ROCmGPU <: AbstractGPU
     device::Any # HIPDevice
-    cacheStrategy::CacheStrategy
     FLOPS::Float64
 end
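With the cacheStrategy field gone, constructing a device by hand reduces to the handle plus a FLOPS estimate. A sketch, assuming the CUDA extension is loaded; the -1 sentinel mirrors the "not yet measured" value used by get_devices above:

    using CUDA
    gpu = CUDAGPU(first(CUDA.devices()), -1)  # device handle + FLOPS placeholder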
41 changes: 3 additions & 38 deletions src/devices/impl.jl
@@ -18,49 +18,14 @@ function entry_device(machine::Machine)
     return machine.devices[1]
 end

-"""
-    strategies(t::Type{T}) where {T <: AbstractDevice}
-
-Return a vector of available [`CacheStrategy`](@ref)s for the given [`AbstractDevice`](@ref).
-The caching strategies are used in code generation.
-"""
-function strategies(t::Type{T}) where {T<:AbstractDevice}
-    if !haskey(CACHE_STRATEGIES, t)
-        error("Trying to get strategies for $T, but it has no strategies defined!")
-    end
-
-    return CACHE_STRATEGIES[t]
-end
-
-"""
-    cache_strategy(device::AbstractDevice)
-
-Returns the cache strategy set for this device.
-"""
-function cache_strategy(device::AbstractDevice)
-    return device.cacheStrategy
-end
-
-"""
-    set_cache_strategy(device::AbstractDevice, cacheStrategy::CacheStrategy)
-
-Sets the device's cache strategy. After this call, [`cache_strategy`](@ref) should return `cacheStrategy` on the given device.
-"""
-function set_cache_strategy(device::AbstractDevice, cacheStrategy::CacheStrategy)
-    device.cacheStrategy = cacheStrategy
-    return nothing
-end
-
 """
     cpu_st()

 A function returning a [`Machine`](@ref) that only has a single thread of one CPU.
 It is the simplest machine definition possible and produces a simple function when used with [`get_compute_function`](@ref).
 """
 function cpu_st()
-    return Machine(
-        [NumaNode(0, 1, default_strategy(NumaNode), -1.0, UUIDs.uuid1())], [-1.0;;]
-    )
+    return Machine([NumaNode(0, 1, -1.0, UUIDs.uuid1())], [-1.0;;])
 end

 """
@@ -69,7 +34,7 @@ end
 Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_access_expr`(@ref).
 """
 function gen_access_expr(fc::FunctionCall)
-    return _gen_access_expr(fc.device, fc.device.cacheStrategy, fc.return_symbol)
+    return _gen_access_expr(fc.device, fc.return_symbol)
 end

 """
@@ -78,5 +43,5 @@ end
 Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_local_init`(@ref).
 """
 function gen_local_init(fc::FunctionCall)
-    return _gen_local_init(fc, fc.device, fc.device.cacheStrategy)
+    return _gen_local_init(fc, fc.device)
 end
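After this commit the code-generation interface dispatches on the device alone. A hypothetical sketch of what an implementation for a custom device type might now look like; MyDevice and the local-variable convention in the body are assumptions for illustration, not the package's actual implementation:

    struct MyDevice <: ComputableDAGs.AbstractDevice
        FLOPS::Float64
    end

    # Map a tape symbol to the expression generated code uses to access it.
    function ComputableDAGs._gen_access_expr(::MyDevice, symbol::Symbol)
        return Meta.quot(symbol)  # treat every value as a plain local variable (assumed)
    end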