-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #405 from JuliaHealth/multigpu
Extend GPU support to Metal, ROCm, and oneAPI backends
- Loading branch information
Showing
13 changed files
with
334 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
module KomaAMDGPUExt | ||
|
||
using AMDGPU | ||
import KomaMRICore | ||
|
||
KomaMRICore.name(::ROCBackend) = "AMDGPU" | ||
KomaMRICore.isfunctional(::ROCBackend) = AMDGPU.functional() | ||
KomaMRICore.set_device!(::ROCBackend, dev_idx::Integer) = AMDGPU.device_id!(dev_idx) | ||
KomaMRICore.set_device!(::ROCBackend, dev::AMDGPU.HIPDevice) = AMDGPU.device!(dev) | ||
KomaMRICore.device_name(::ROCBackend) = AMDGPU.device().name | ||
|
||
function KomaMRICore._print_devices(::ROCBackend) | ||
devices = [ | ||
Symbol("($(i-1)$(i == 1 ? "*" : " "))") => d.name for | ||
(i, d) in enumerate(AMDGPU.devices()) | ||
] | ||
@info "$(length(AMDGPU.devices())) AMD capable device(s)." devices... | ||
end | ||
|
||
function __init__() | ||
push!(KomaMRICore.LOADED_BACKENDS[], ROCBackend()) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
module KomaCUDAExt | ||
|
||
using CUDA | ||
import KomaMRICore | ||
|
||
KomaMRICore.name(::CUDABackend) = "CUDA" | ||
KomaMRICore.isfunctional(::CUDABackend) = CUDA.functional() | ||
KomaMRICore.set_device!(::CUDABackend, val) = CUDA.device!(val) | ||
KomaMRICore.device_name(::CUDABackend) = CUDA.name(CUDA.device()) | ||
|
||
function KomaMRICore._print_devices(::CUDABackend) | ||
devices = [ | ||
Symbol("($(i-1)$(i == 1 ? "*" : " "))") => CUDA.name(d) for | ||
(i, d) in enumerate(CUDA.devices()) | ||
] | ||
@info "$(length(CUDA.devices())) CUDA capable device(s)." devices... | ||
end | ||
|
||
function __init__() | ||
push!(KomaMRICore.LOADED_BACKENDS[], CUDABackend()) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
module KomaMetalExt | ||
|
||
using Metal | ||
import KomaMRICore | ||
|
||
KomaMRICore.name(::MetalBackend) = "Metal" | ||
KomaMRICore.isfunctional(::MetalBackend) = Metal.functional() | ||
KomaMRICore.set_device!(::MetalBackend, device_index::Integer) = device_index == 1 || @warn "Metal does not support multiple gpu devices. Ignoring the device setting." | ||
KomaMRICore.set_device!(::MetalBackend, dev::Metal.MTLDevice) = Metal.device!(dev) | ||
KomaMRICore.device_name(::MetalBackend) = String(Metal.current_device().name) | ||
|
||
function KomaMRICore._print_devices(::MetalBackend) | ||
@info "Metal device type: $(KomaMRICore.device_name(MetalBackend()))" | ||
end | ||
|
||
#Temporary workaround for https://github.com/JuliaGPU/Metal.jl/issues/348 | ||
#Once run_spin_excitation! and run_spin_precession! are kernel-based, this code | ||
#can be removed | ||
Base.cumsum(x::MtlVector) = convert(MtlVector, cumsum(KomaMRICore.cpu(x))) | ||
Base.cumsum(x::MtlArray{T}; dims) where T = convert(MtlArray{T}, cumsum(KomaMRICore.cpu(x), dims=dims)) | ||
Base.findall(x::MtlVector{Bool}) = convert(MtlVector, findall(KomaMRICore.cpu(x))) | ||
|
||
function __init__() | ||
push!(KomaMRICore.LOADED_BACKENDS[], MetalBackend()) | ||
@warn "Due to https://github.com/JuliaGPU/Metal.jl/issues/348, some functions may need to run on the CPU. Performance may be impacted as a result." | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
module KomaoneAPIExt | ||
|
||
using oneAPI | ||
import KomaMRICore | ||
|
||
KomaMRICore.name(::oneAPIBackend) = "oneAPI" | ||
KomaMRICore.isfunctional(::oneAPIBackend) = oneAPI.functional() | ||
KomaMRICore.set_device!(::oneAPIBackend, val) = oneAPI.device!(val) | ||
KomaMRICore.device_name(::oneAPIBackend) = oneAPI.properties(oneAPI.device()).name | ||
|
||
function KomaMRICore._print_devices(::oneAPIBackend) | ||
devices = [ | ||
Symbol("($(i-1)$(i == 1 ? "*" : " "))") => oneAPI.properties(d).name for | ||
(i, d) in enumerate(oneAPI.devices()) | ||
] | ||
@info "$(length(oneAPI.devices())) oneAPI capable device(s)." devices... | ||
end | ||
|
||
function __init__() | ||
push!(KomaMRICore.LOADED_BACKENDS[], oneAPIBackend()) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import Adapt: adapt, adapt_storage | ||
import Functors: @functor, functor, fmap, isleaf | ||
|
||
#Aux. funcitons to check if the variable we want to convert to CuArray is numeric | ||
_isleaf(x) = isleaf(x) | ||
_isleaf(::AbstractArray{<:Number}) = true | ||
_isleaf(::AbstractArray{T}) where T = isbitstype(T) | ||
_isleaf(::AbstractRange) = true | ||
|
||
""" | ||
gpu(x) | ||
Tries to move `x` to the GPU backend specified in the 'backend' parameter. | ||
This works for functions, and any struct marked with `@functor`. | ||
Use [`cpu`](@ref) to copy back to ordinary `Array`s. | ||
See also [`f32`](@ref) and [`f64`](@ref) to change element type only. | ||
# Examples | ||
```julia | ||
x = gpu(x, CUDABackend()) | ||
``` | ||
""" | ||
function gpu(x, backend::KA.GPU) | ||
return fmap(x -> adapt(backend, x), x; exclude=_isleaf) | ||
end | ||
|
||
# To CPU | ||
""" | ||
cpu(x) | ||
Tries to move object to CPU. This works for functions, and any struct marked with `@functor`. | ||
See also [`gpu`](@ref). | ||
# Examples | ||
```julia | ||
x = x |> cpu | ||
``` | ||
""" | ||
cpu(x) = fmap(x -> adapt(KA.CPU(), x), x, exclude=_isleaf) | ||
|
||
#Precision | ||
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m) | ||
adapt_storage(T::Type{<:Real}, xs::Real) = convert(T, xs) | ||
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) | ||
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Complex}) = convert.(Complex{T}, xs) | ||
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Bool}) = xs | ||
adapt_storage(T::Type{<:Real}, xs::SimpleMotion) = SimpleMotion(paramtype(T, xs.types)) | ||
adapt_storage(T::Type{<:Real}, xs::NoMotion) = NoMotion{T}() | ||
function adapt_storage(T::Type{<:Real}, xs::ArbitraryMotion) | ||
fields = [] | ||
for field in fieldnames(ArbitraryMotion) | ||
push!(fields, paramtype(T, getfield(xs, field))) | ||
end | ||
return ArbitraryMotion(fields...) | ||
end | ||
|
||
""" | ||
f32(m) | ||
Converts the `eltype` of model's parameters to `Float32` | ||
Recurses into structs marked with `@functor`. | ||
See also [`f64`](@ref). | ||
""" | ||
f32(m) = paramtype(Float32, m) | ||
|
||
""" | ||
f64(m) | ||
Converts the `eltype` of model's parameters to `Float64` (which is Koma's default).. | ||
Recurses into structs marked with `@functor`. | ||
See also [`f32`](@ref). | ||
""" | ||
f64(m) = paramtype(Float64, m) | ||
|
||
#The functor macro makes it easier to call a function in all the parameters | ||
@functor Phantom | ||
|
||
@functor Translation | ||
@functor Rotation | ||
@functor HeartBeat | ||
@functor PeriodicTranslation | ||
@functor PeriodicRotation | ||
@functor PeriodicHeartBeat | ||
|
||
@functor Spinor | ||
@functor DiscreteSequence | ||
|
||
export gpu, cpu, f32, f64 |
Oops, something went wrong.