diff --git a/ext/devices/cuda/function.jl b/ext/devices/cuda/function.jl index ef7920a..1bf5a5d 100644 --- a/ext/devices/cuda/function.jl +++ b/ext/devices/cuda/function.jl @@ -1,4 +1,6 @@ -function ComputableDAGs.kernel(::Type{CUDAGPU}, graph::DAG, instance) +function ComputableDAGs.kernel( + ::Type{CUDAGPU}, graph::DAG, instance, context_module::Module +) machine = cpu_st() tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module) diff --git a/ext/devices/rocm/function.jl b/ext/devices/rocm/function.jl index 6990ba1..2ecca3b 100644 --- a/ext/devices/rocm/function.jl +++ b/ext/devices/rocm/function.jl @@ -1,4 +1,6 @@ -function ComputableDAGs.kernel(::Type{ROCmGPU}, graph::DAG, instance) +function ComputableDAGs.kernel( + ::Type{ROCmGPU}, graph::DAG, instance, context_module::Module +) machine = cpu_st() tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module)