Skip to content

Commit

Permalink
Simplify adapt functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pvillacorta committed Jun 4, 2024
1 parent 60e5743 commit 9b2a9c2
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions KomaMRICore/src/simulation/GPUFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,7 @@ _isleaf(x) = _isbitsarray(x) || isleaf(x)
# GPU adaptor
struct KomaCUDAAdaptor end
adapt_storage(to::KomaCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::KomaCUDAAdaptor, x::NoMotion) = NoMotion{Float32}()
adapt_storage(to::KomaCUDAAdaptor, x::SimpleMotion) = f32(x)
function adapt_storage(to::KomaCUDAAdaptor, x::ArbitraryMotion)
fields = []
for field in fieldnames(ArbitraryMotion)
push!(fields, f32(getfield(x, field)))
end
return ArbitraryMotion(fields...)
end
adapt_storage(to::KomaCUDAAdaptor, x::MotionModel) = f32(x) # Motion models are not passed to GPU

Check warning on line 48 in KomaMRICore/src/simulation/GPUFunctions.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/GPUFunctions.jl#L48

Added line #L48 was not covered by tests


"""
Expand Down Expand Up @@ -105,15 +97,10 @@ adapt_storage(T::Type{<:Real}, xs::Real) = convert(T, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Complex}) = convert.(Complex{T}, xs)
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Bool}) = xs
adapt_storage(T::Type{<:Real}, xs::SimpleMotion) = SimpleMotion(paramtype(T, xs.types))

adapt_storage(T::Type{<:Real}, xs::NoMotion) = NoMotion{T}()
function adapt_storage(T::Type{<:Real}, xs::ArbitraryMotion)
fields = []
for field in fieldnames(ArbitraryMotion)
push!(fields, paramtype(T, getfield(xs, field)))
end
return ArbitraryMotion(fields...)
end
adapt_storage(T::Type{<:Real}, xs::SimpleMotion) = SimpleMotion(paramtype(T, xs.types))
adapt_storage(T::Type{<:Real}, xs::ArbitraryMotion) = ArbitraryMotion( (paramtype.(Ref(T), getfield.(Ref(xs), fieldnames(ArbitraryMotion))))... )

"""
f32(m)
Expand Down

0 comments on commit 9b2a9c2

Please sign in to comment.