From 9b2a9c220c772a2e22662427b7871a28612f7e07 Mon Sep 17 00:00:00 2001 From: Pablo Villacorta Aylagas Date: Tue, 4 Jun 2024 11:39:31 +0200 Subject: [PATCH] Simplify adapt functions --- KomaMRICore/src/simulation/GPUFunctions.jl | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/KomaMRICore/src/simulation/GPUFunctions.jl b/KomaMRICore/src/simulation/GPUFunctions.jl index 1a38c16eb..0a05eeb23 100644 --- a/KomaMRICore/src/simulation/GPUFunctions.jl +++ b/KomaMRICore/src/simulation/GPUFunctions.jl @@ -45,15 +45,7 @@ _isleaf(x) = _isbitsarray(x) || isleaf(x) # GPU adaptor struct KomaCUDAAdaptor end adapt_storage(to::KomaCUDAAdaptor, x) = CUDA.cu(x) -adapt_storage(to::KomaCUDAAdaptor, x::NoMotion) = NoMotion{Float32}() -adapt_storage(to::KomaCUDAAdaptor, x::SimpleMotion) = f32(x) -function adapt_storage(to::KomaCUDAAdaptor, x::ArbitraryMotion) - fields = [] - for field in fieldnames(ArbitraryMotion) - push!(fields, f32(getfield(x, field))) - end - return ArbitraryMotion(fields...) -end +adapt_storage(to::KomaCUDAAdaptor, x::MotionModel) = f32(x) # Motion models are not passed to GPU """ @@ -105,15 +97,10 @@ adapt_storage(T::Type{<:Real}, xs::Real) = convert(T, xs) adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Complex}) = convert.(Complex{T}, xs) adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Bool}) = xs -adapt_storage(T::Type{<:Real}, xs::SimpleMotion) = SimpleMotion(paramtype(T, xs.types)) + adapt_storage(T::Type{<:Real}, xs::NoMotion) = NoMotion{T}() -function adapt_storage(T::Type{<:Real}, xs::ArbitraryMotion) - fields = [] - for field in fieldnames(ArbitraryMotion) - push!(fields, paramtype(T, getfield(xs, field))) - end - return ArbitraryMotion(fields...) -end +adapt_storage(T::Type{<:Real}, xs::SimpleMotion) = SimpleMotion(paramtype(T, xs.types)) +adapt_storage(T::Type{<:Real}, xs::ArbitraryMotion) = ArbitraryMotion( (paramtype.(Ref(T), getfield.(Ref(xs), fieldnames(ArbitraryMotion))))... ) """ f32(m)