From 987aef1504352c8788f6f1b6cd028b68d0e41ba2 Mon Sep 17 00:00:00 2001 From: Ryan Kierulf Date: Thu, 20 Jun 2024 21:20:53 -0500 Subject: [PATCH] Add kernel-based matrix cumsum --- .../Bloch/BlochDictSimulationMethod.jl | 3 +- .../simulation/Bloch/BlochSimulationMethod.jl | 3 +- .../src/simulation/Bloch/KernelFunctions.jl | 44 +++++++++++++++++++ KomaMRICore/src/simulation/SimulatorCore.jl | 14 +++--- 4 files changed, 57 insertions(+), 7 deletions(-) create mode 100644 KomaMRICore/src/simulation/Bloch/KernelFunctions.jl diff --git a/KomaMRICore/src/simulation/Bloch/BlochDictSimulationMethod.jl b/KomaMRICore/src/simulation/Bloch/BlochDictSimulationMethod.jl index c9c89549c..3cf7abb6c 100644 --- a/KomaMRICore/src/simulation/Bloch/BlochDictSimulationMethod.jl +++ b/KomaMRICore/src/simulation/Bloch/BlochDictSimulationMethod.jl @@ -35,6 +35,7 @@ function run_spin_precession!( sig::AbstractArray{Complex{T}}, M::Mag{T}, sim_method::BlochDict, + backend::KA.Backend ) where {T<:Real} #Simulation #Motion @@ -43,7 +44,7 @@ function run_spin_precession!( Bz = x .* seq.Gx' .+ y .* seq.Gy' .+ z .* seq.Gz' .+ p.Δw / T(2π * γ) #Rotation if is_ADC_on(seq) - ϕ = T(-2π * γ) .* cumtrapz(seq.Δt', Bz) + ϕ = T(-2π * γ) .* KomaMRIBase.cumtrapz(seq.Δt', Bz, backend) else ϕ = T(-2π * γ) .* trapz(seq.Δt', Bz) end diff --git a/KomaMRICore/src/simulation/Bloch/BlochSimulationMethod.jl b/KomaMRICore/src/simulation/Bloch/BlochSimulationMethod.jl index 284da0c13..f60a8b278 100644 --- a/KomaMRICore/src/simulation/Bloch/BlochSimulationMethod.jl +++ b/KomaMRICore/src/simulation/Bloch/BlochSimulationMethod.jl @@ -44,6 +44,7 @@ function run_spin_precession!( sig::AbstractArray{Complex{T}}, M::Mag{T}, sim_method::SimulationMethod, + backend::KA.Backend ) where {T<:Real} #Simulation #Motion @@ -52,7 +53,7 @@ function run_spin_precession!( Bz = x .* seq.Gx' .+ y .* seq.Gy' .+ z .* seq.Gz' .+ p.Δw / T(2π * γ) #Rotation if is_ADC_on(seq) - ϕ = T(-2π * γ) .* cumtrapz(seq.Δt', Bz) + ϕ = T(-2π * γ) .* 
KomaMRIBase.cumtrapz(seq.Δt', Bz, backend) else ϕ = T(-2π * γ) .* trapz(seq.Δt', Bz) end diff --git a/KomaMRICore/src/simulation/Bloch/KernelFunctions.jl b/KomaMRICore/src/simulation/Bloch/KernelFunctions.jl new file mode 100644 index 000000000..78d4cafa2 --- /dev/null +++ b/KomaMRICore/src/simulation/Bloch/KernelFunctions.jl @@ -0,0 +1,44 @@ +using KernelAbstractions: @index, @kernel + +""" + cumsum_matrix_rows_kernel!(A) + +Simple kernel function, computes the cumulative sum of each row of a matrix. Operates +in-place on the input matrix without allocating additional memory. Work-item `i` accumulates row `i` from left to right. + +# Arguments +- `A`: matrix to compute cumsum on; overwritten with the row-wise cumulative sums +""" +@kernel function cumsum_matrix_rows_kernel!(A) + i = @index(Global) + + for k ∈ 2:size(A, 2) + @inbounds A[i, k] += A[i, k-1] + end +end + +""" + cumtrapz(Δt, x, backend) + +A more efficient GPU implementation of the cumtrapz method defined in TrapezoidalIntegration.jl. +Uses a kernel to compute cumsum along the second dimension. + +# Arguments +- `Δt`: (`1 x NΔt ::Matrix{Float64}`, `[s]`) delta time 1-row array +- `x`: (`Ns x (NΔt+1) ::Matrix{Float64}`, `[T]`) magnitude of the field Gx * x + Gy * y + + Gz * z + +# Returns +- `y`: (`Ns x NΔt ::Matrix{Float64}`, `[T*s]`) matrix where every column is the + cumulative integration over time of (Gx * x + Gy * y + Gz * z) * Δt for every spin of a + phantom +""" +function KomaMRIBase.cumtrapz(Δt::AbstractArray{T}, x::AbstractArray{T}, backend::KA.GPU) where {T<:Real} + y = (x[:, 2:end] .+ x[:, 1:end-1]) .* (Δt / 2) + cumsum_matrix_rows_kernel!(backend)(y, ndrange=size(y,1)) + KA.synchronize(backend) + return y +end + +# On a CPU backend, forward the call to the two-argument cumtrapz in KomaMRIBase +KomaMRIBase.cumtrapz(Δt::AbstractArray{T}, x::AbstractArray{T}, backend::KA.CPU) where {T<:Real} = KomaMRIBase.cumtrapz(Δt, x) \ No newline at end of file diff --git a/KomaMRICore/src/simulation/SimulatorCore.jl b/KomaMRICore/src/simulation/SimulatorCore.jl index d52797a33..689ad3cdd 100644 --- a/KomaMRICore/src/simulation/SimulatorCore.jl +++ 
b/KomaMRICore/src/simulation/SimulatorCore.jl @@ -2,6 +2,7 @@ abstract type SimulationMethod end #get all available types by using subtypes(Ko abstract type SpinStateRepresentation{T<:Real} end #get all available types by using subtypes(KomaMRI.SpinStateRepresentation) #Defined methods: +include("Bloch/KernelFunctions.jl") include("Bloch/BlochSimulationMethod.jl") #Defines Bloch simulation method include("Bloch/BlochDictSimulationMethod.jl") #Defines BlochDict simulation method @@ -82,7 +83,8 @@ function run_spin_precession_parallel!( seq::DiscreteSequence{T}, sig::AbstractArray{Complex{T}}, Xt::SpinStateRepresentation{T}, - sim_method::SimulationMethod; + sim_method::SimulationMethod, + backend::KA.Backend; Nthreads=Threads.nthreads(), ) where {T<:Real} parts = kfoldperm(length(obj), Nthreads) @@ -90,7 +92,7 @@ function run_spin_precession_parallel!( ThreadsX.foreach(enumerate(parts)) do (i, p) run_spin_precession!( - @view(obj[p]), seq, @view(sig[dims..., i]), @view(Xt[p]), sim_method + @view(obj[p]), seq, @view(sig[dims..., i]), @view(Xt[p]), sim_method, backend ) end @@ -166,7 +168,8 @@ function run_sim_time_iter!( seq::DiscreteSequence, sig::AbstractArray{Complex{T}}, Xt::SpinStateRepresentation{T}, - sim_method::SimulationMethod; + sim_method::SimulationMethod, + backend::KA.Backend; Nblocks=1, Nthreads=Threads.nthreads(), parts=[1:length(seq)], @@ -193,7 +196,7 @@ function run_sim_time_iter!( rfs += 1 else run_spin_precession_parallel!( - obj, seq_block, @view(sig[acq_samples, dims...]), Xt, sim_method; Nthreads + obj, seq_block, @view(sig[acq_samples, dims...]), Xt, sim_method, backend; Nthreads ) end samples += Nadc @@ -373,7 +376,8 @@ function simulate( seqd, sig, Xt, - sim_params["sim_method"]; + sim_params["sim_method"], + backend; Nblocks=length(parts), Nthreads=sim_params["Nthreads"], parts,