Skip to content

Commit

Permalink
Merge pull request #405 from JuliaHealth/multigpu
Browse files Browse the repository at this point in the history
Extend GPU support to Metal, ROCm, and oneAPI backends
  • Loading branch information
cncastillo authored Jun 12, 2024
2 parents 50e2e5b + 64897c8 commit 931fec8
Show file tree
Hide file tree
Showing 13 changed files with 334 additions and 174 deletions.
18 changes: 17 additions & 1 deletion KomaMRICore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,35 @@ version = "0.8.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KomaMRIBase = "d0bc0b20-b151-4d03-b2a4-6ca51751cb9c"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
KomaAMDGPUExt = "AMDGPU"
KomaCUDAExt = "CUDA"
KomaMetalExt = "Metal"
KomaoneAPIExt = "oneAPI"

[compat]
Adapt = "3, 4"
AMDGPU = "0.9"
CUDA = "3, 4, 5"
Functors = "0.4"
KernelAbstractions = "0.9"
KomaMRIBase = "0.8"
Metal = "1"
oneAPI = "1"
Pkg = "1.4"
ProgressMeter = "1"
Reexport = "1"
Expand Down
24 changes: 24 additions & 0 deletions KomaMRICore/ext/KomaAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module KomaAMDGPUExt

using AMDGPU
import KomaMRICore

# Package extension that plugs AMDGPU's `ROCBackend` into the backend interface
# of KomaMRICore: display name, availability check, device selection and
# device listing.

# Display name used for this backend.
function KomaMRICore.name(::ROCBackend)
    return "AMDGPU"
end

# The backend is usable whenever AMDGPU reports itself as functional.
function KomaMRICore.isfunctional(::ROCBackend)
    return AMDGPU.functional()
end

# Device selection: accepts either a device index or a device handle.
function KomaMRICore.set_device!(::ROCBackend, dev_idx::Integer)
    return AMDGPU.device_id!(dev_idx)
end
function KomaMRICore.set_device!(::ROCBackend, dev::AMDGPU.HIPDevice)
    return AMDGPU.device!(dev)
end

# Name of the device returned by `AMDGPU.device()`.
function KomaMRICore.device_name(::ROCBackend)
    current = AMDGPU.device()
    return current.name
end

# Log all devices reported by AMDGPU. Labels are zero-based and the first
# entry of the list is tagged with `*`.
function KomaMRICore._print_devices(::ROCBackend)
    entries = map(enumerate(AMDGPU.devices())) do (i, dev)
        Symbol("($(i-1)$(i == 1 ? "*" : " "))") => dev.name
    end
    @info "$(length(AMDGPU.devices())) AMD capable device(s)." entries...
    return nothing
end

# Runs when the extension is loaded: registers the backend with KomaMRICore.
function __init__()
    loaded = KomaMRICore.LOADED_BACKENDS[]
    return push!(loaded, ROCBackend())
end

end
23 changes: 23 additions & 0 deletions KomaMRICore/ext/KomaCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module KomaCUDAExt

using CUDA
import KomaMRICore

# Package extension that plugs CUDA's `CUDABackend` into the backend interface
# of KomaMRICore: display name, availability check, device selection and
# device listing.

# Display name used for this backend.
function KomaMRICore.name(::CUDABackend)
    return "CUDA"
end

# The backend is usable whenever CUDA reports itself as functional.
function KomaMRICore.isfunctional(::CUDABackend)
    return CUDA.functional()
end

# Device selection is forwarded unchanged to `CUDA.device!`.
function KomaMRICore.set_device!(::CUDABackend, val)
    return CUDA.device!(val)
end

# Name of the device returned by `CUDA.device()`.
function KomaMRICore.device_name(::CUDABackend)
    current = CUDA.device()
    return CUDA.name(current)
end

# Log all devices reported by CUDA. Labels are zero-based and the first entry
# of the list is tagged with `*`.
function KomaMRICore._print_devices(::CUDABackend)
    entries = map(enumerate(CUDA.devices())) do (i, dev)
        Symbol("($(i-1)$(i == 1 ? "*" : " "))") => CUDA.name(dev)
    end
    @info "$(length(CUDA.devices())) CUDA capable device(s)." entries...
    return nothing
end

# Runs when the extension is loaded: registers the backend with KomaMRICore.
function __init__()
    loaded = KomaMRICore.LOADED_BACKENDS[]
    return push!(loaded, CUDABackend())
end

end
28 changes: 28 additions & 0 deletions KomaMRICore/ext/KomaMetalExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module KomaMetalExt

using Metal
import KomaMRICore

# Package extension that plugs Metal's `MetalBackend` into the backend
# interface of KomaMRICore: display name, availability check, device selection
# and device listing.

# Display name used for this backend.
function KomaMRICore.name(::MetalBackend)
    return "Metal"
end

# The backend is usable whenever Metal reports itself as functional.
function KomaMRICore.isfunctional(::MetalBackend)
    return Metal.functional()
end

# Selecting by index: index 1 is accepted silently, any other index only
# triggers a warning and leaves the device untouched.
function KomaMRICore.set_device!(::MetalBackend, device_index::Integer)
    device_index == 1 && return true
    @warn "Metal does not support multiple gpu devices. Ignoring the device setting."
end

# Selecting by device handle is forwarded to `Metal.device!`.
function KomaMRICore.set_device!(::MetalBackend, dev::Metal.MTLDevice)
    return Metal.device!(dev)
end

# Name of the device returned by `Metal.current_device()`, as a `String`.
function KomaMRICore.device_name(::MetalBackend)
    current = Metal.current_device()
    return String(current.name)
end

# Log the name of the current Metal device.
function KomaMRICore._print_devices(::MetalBackend)
    @info "Metal device type: $(KomaMRICore.device_name(MetalBackend()))"
    return nothing
end

# Temporary workaround for https://github.com/JuliaGPU/Metal.jl/issues/348:
# `cumsum` and `findall` are evaluated on a CPU copy of the data and the result
# is converted back to a Metal array. Once run_spin_excitation! and
# run_spin_precession! are kernel-based, these methods can be removed.
function Base.cumsum(x::MtlVector)
    host_result = cumsum(KomaMRICore.cpu(x))
    return convert(MtlVector, host_result)
end
function Base.cumsum(x::MtlArray{T}; dims) where T
    host_result = cumsum(KomaMRICore.cpu(x), dims=dims)
    return convert(MtlArray{T}, host_result)
end
function Base.findall(x::MtlVector{Bool})
    host_result = findall(KomaMRICore.cpu(x))
    return convert(MtlVector, host_result)
end

# Runs when the extension is loaded: registers the backend with KomaMRICore and
# warns about the CPU fallback above.
function __init__()
    loaded = KomaMRICore.LOADED_BACKENDS[]
    push!(loaded, MetalBackend())
    @warn "Due to https://github.com/JuliaGPU/Metal.jl/issues/348, some functions may need to run on the CPU. Performance may be impacted as a result."
    return nothing
end

end
23 changes: 23 additions & 0 deletions KomaMRICore/ext/KomaoneAPIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module KomaoneAPIExt

using oneAPI
import KomaMRICore

# Package extension that plugs oneAPI's `oneAPIBackend` into the backend
# interface of KomaMRICore: display name, availability check, device selection
# and device listing.

# Display name used for this backend.
function KomaMRICore.name(::oneAPIBackend)
    return "oneAPI"
end

# The backend is usable whenever oneAPI reports itself as functional.
function KomaMRICore.isfunctional(::oneAPIBackend)
    return oneAPI.functional()
end

# Device selection is forwarded unchanged to `oneAPI.device!`.
function KomaMRICore.set_device!(::oneAPIBackend, val)
    return oneAPI.device!(val)
end

# Name taken from the properties of the device returned by `oneAPI.device()`.
function KomaMRICore.device_name(::oneAPIBackend)
    props = oneAPI.properties(oneAPI.device())
    return props.name
end

# Log all devices reported by oneAPI. Labels are zero-based and the first
# entry of the list is tagged with `*`.
function KomaMRICore._print_devices(::oneAPIBackend)
    entries = map(enumerate(oneAPI.devices())) do (i, dev)
        Symbol("($(i-1)$(i == 1 ? "*" : " "))") => oneAPI.properties(dev).name
    end
    @info "$(length(oneAPI.devices())) oneAPI capable device(s)." entries...
    return nothing
end

# Runs when the extension is loaded: registers the backend with KomaMRICore.
function __init__()
    loaded = KomaMRICore.LOADED_BACKENDS[]
    return push!(loaded, oneAPIBackend())
end

end
4 changes: 2 additions & 2 deletions KomaMRICore/src/KomaMRICore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ module KomaMRICore

# General
import Base.*, Base.abs
import KernelAbstractions as KA
using Reexport
using ThreadsX
# Printing
using ProgressMeter
# Simulation
using CUDA

# KomaMRIBase
@reexport using KomaMRIBase
Expand All @@ -18,6 +17,7 @@ include("rawdata/ISMRMRD.jl")
include("datatypes/Spinor.jl")
include("other/DiffusionModel.jl")
# Simulator
include("simulation/Functors.jl")
include("simulation/GPUFunctions.jl")
include("simulation/SimulatorCore.jl")

Expand Down
94 changes: 94 additions & 0 deletions KomaMRICore/src/simulation/Functors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import Adapt: adapt, adapt_storage
import Functors: @functor, functor, fmap, isleaf

#Aux. functions deciding which values `fmap` treats as leaves when data is moved
#between devices: numeric arrays, arrays of bits-type elements and ranges are
#leaves; everything else defers to `Functors.isleaf`.
function _isleaf(x)
    return isleaf(x)
end
function _isleaf(::AbstractArray{<:Number})
    return true
end
function _isleaf(::AbstractArray{T}) where T
    return isbitstype(T)
end
function _isleaf(::AbstractRange)
    return true
end

"""
    gpu(x, backend::KA.GPU)

Try to move `x` to the GPU backend given by `backend`, adapting every leaf of
`x` to that backend.

This works for functions, and any struct marked with `@functor`.
Use [`cpu`](@ref) to copy back to ordinary `Array`s.
See also [`f32`](@ref) and [`f64`](@ref) to change element type only.

# Examples
```julia
x = gpu(x, CUDABackend())
```
"""
function gpu(x, backend::KA.GPU)
    return fmap(x; exclude=_isleaf) do leaf
        adapt(backend, leaf)
    end
end

# To CPU
"""
    cpu(x)

Try to move `x` to the CPU by adapting every leaf of `x` to `KA.CPU()`.

This works for functions, and any struct marked with `@functor`.
See also [`gpu`](@ref).

# Examples
```julia
x = x |> cpu
```
"""
function cpu(x)
    return fmap(x; exclude=_isleaf) do leaf
        adapt(KA.CPU(), leaf)
    end
end

#Precision
# `paramtype(T, m)` walks `m` with `fmap` and adapts every leaf to the real
# type `T`. The `adapt_storage` methods below define the conversion per leaf:
#   - real scalars and real arrays are converted to `T`,
#   - complex arrays are converted to `Complex{T}`,
#   - `Bool` arrays are returned untouched,
#   - motion types are rebuilt from their converted contents.
# NOTE(review): the `Real` / `AbstractArray` methods extend `Adapt.adapt_storage`
# on argument types this package does not own — confirm this is intended.
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
adapt_storage(T::Type{<:Real}, xs::Real) = convert(T, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Complex}) = convert.(Complex{T}, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Bool}) = xs
adapt_storage(T::Type{<:Real}, xs::SimpleMotion) = SimpleMotion(paramtype(T, xs.types))
adapt_storage(T::Type{<:Real}, xs::NoMotion) = NoMotion{T}()
function adapt_storage(T::Type{<:Real}, xs::ArbitraryMotion)
    # `fieldnames` returns a tuple, so mapping over it yields a tuple of the
    # converted fields in declaration order (no untyped `Vector{Any}` buffer).
    converted = map(fieldnames(ArbitraryMotion)) do name
        paramtype(T, getfield(xs, name))
    end
    return ArbitraryMotion(converted...)
end

"""
    f32(m)

Convert the `eltype` of the model's parameters to `Float32`.
Recurses into structs marked with `@functor`.

See also [`f64`](@ref).
"""
f32(m) = paramtype(Float32, m)

"""
    f64(m)

Convert the `eltype` of the model's parameters to `Float64` (which is Koma's default).
Recurses into structs marked with `@functor`.

See also [`f32`](@ref).
"""
f64(m) = paramtype(Float64, m)

#The `@functor` macro marks a type so that `fmap` recurses into it; this is what
#lets `gpu`, `cpu`, `f32` and `f64` reach the parameters stored in these structs.
@functor Phantom

#NOTE(review): presumably the motion types held by `SimpleMotion` — confirm.
@functor Translation
@functor Rotation
@functor HeartBeat
@functor PeriodicTranslation
@functor PeriodicRotation
@functor PeriodicHeartBeat

@functor Spinor
@functor DiscreteSequence

#Public API of this file: device movement and precision conversion.
export gpu, cpu, f32, f64
Loading

0 comments on commit 931fec8

Please sign in to comment.