Skip to content

Commit

Permalink
Add kernel-based matrix cumsum
Browse files Browse the repository at this point in the history
  • Loading branch information
rkierulf committed Jun 21, 2024
1 parent 3175e71 commit 987aef1
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ function run_spin_precession!(
sig::AbstractArray{Complex{T}},
M::Mag{T},
sim_method::BlochDict,
backend::KA.Backend
) where {T<:Real}
#Simulation
#Motion
Expand All @@ -43,7 +44,7 @@ function run_spin_precession!(
Bz = x .* seq.Gx' .+ y .* seq.Gy' .+ z .* seq.Gz' .+ p.Δw / T(2π * γ)
#Rotation
if is_ADC_on(seq)
ϕ = T(-2π * γ) .* cumtrapz(seq.Δt', Bz)
ϕ = T(-2π * γ) .* KomaMRIBase.cumtrapz(seq.Δt', Bz, backend)
else
ϕ = T(-2π * γ) .* trapz(seq.Δt', Bz)
end
Expand Down
3 changes: 2 additions & 1 deletion KomaMRICore/src/simulation/Bloch/BlochSimulationMethod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ function run_spin_precession!(
sig::AbstractArray{Complex{T}},
M::Mag{T},
sim_method::SimulationMethod,
backend::KA.Backend
) where {T<:Real}
#Simulation
#Motion
Expand All @@ -52,7 +53,7 @@ function run_spin_precession!(
Bz = x .* seq.Gx' .+ y .* seq.Gy' .+ z .* seq.Gz' .+ p.Δw / T(2π * γ)
#Rotation
if is_ADC_on(seq)
ϕ = T(-2π * γ) .* cumtrapz(seq.Δt', Bz)
ϕ = T(-2π * γ) .* KomaMRIBase.cumtrapz(seq.Δt', Bz, backend)
else
ϕ = T(-2π * γ) .* trapz(seq.Δt', Bz)
end
Expand Down
44 changes: 44 additions & 0 deletions KomaMRICore/src/simulation/Bloch/KernelFunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using KernelAbstractions: @index, @kernel

"""
cumsum2_kernel
Simple kernel function, computes the cumulative sum of each row of a matrix. Operates
in-place on the input matrix without allocating additional memory.
# Arguments
- 'A': matrix to compute cumsum on
"""
@kernel function cumsum_matrix_rows_kernel!(A)
i = @index(Global)

Check warning on line 13 in KomaMRICore/src/simulation/Bloch/KernelFunctions.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/Bloch/KernelFunctions.jl#L12-L13

Added lines #L12 - L13 were not covered by tests

for k 2:size(A, 2)
@inbounds A[i, k] += A[i, k-1]
end

Check warning on line 17 in KomaMRICore/src/simulation/Bloch/KernelFunctions.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRICore/src/simulation/Bloch/KernelFunctions.jl#L15-L17

Added lines #L15 - L17 were not covered by tests
end

"""
cumtrapz
A more efficient GPU implementation of the cumtrapz method defined in TrapezoidalIntegration.jl.
Uses a kernel to compute cumsum along the second dimension.
# Arguments
- `Δt`: (`1 x NΔt ::Matrix{Float64}`, `[s]`) delta time 1-row array
- `x`: (`Ns x (NΔt+1) ::Matrix{Float64}`, `[T]`) magnitude of the field Gx * x + Gy * y +
Gz * z
# Returns
- `y`: (`Ns x NΔt ::Matrix{Float64}`, `[T*s]`) matrix where every column is the
cumulative integration over time of (Gx * x + Gy * y + Gz * z) * Δt for every spin of a
phantom
"""
function KomaMRIBase.cumtrapz(Δt::AbstractArray{T}, x::AbstractArray{T}, backend::KA.GPU) where {T<:Real}
y = (x[:, 2:end] .+ x[:, 1:end-1]) .* (Δt / 2)
cumsum_matrix_rows_kernel!(backend)(y, ndrange=size(y,1))
KA.synchronize(backend)
return y
end

#If on CPU, forward call to cumtrapz in KomaMRIBase
KomaMRIBase.cumtrapz(Δt::AbstractArray{T}, x::AbstractArray{T}, backend::KA.CPU) where {T<:Real} = KomaMRIBase.cumtrapz(Δt, x)
14 changes: 9 additions & 5 deletions KomaMRICore/src/simulation/SimulatorCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ abstract type SimulationMethod end #get all available types by using subtypes(Ko
abstract type SpinStateRepresentation{T<:Real} end #get all available types by using subtypes(KomaMRI.SpinStateRepresentation)

#Defined methods:
include("Bloch/KernelFunctions.jl")
include("Bloch/BlochSimulationMethod.jl") #Defines Bloch simulation method
include("Bloch/BlochDictSimulationMethod.jl") #Defines BlochDict simulation method

Expand Down Expand Up @@ -82,15 +83,16 @@ function run_spin_precession_parallel!(
seq::DiscreteSequence{T},
sig::AbstractArray{Complex{T}},
Xt::SpinStateRepresentation{T},
sim_method::SimulationMethod;
sim_method::SimulationMethod,
backend::KA.Backend;
Nthreads=Threads.nthreads(),
) where {T<:Real}
parts = kfoldperm(length(obj), Nthreads)
dims = [Colon() for i in 1:(ndims(sig) - 1)] # :,:,:,... Ndim times

ThreadsX.foreach(enumerate(parts)) do (i, p)
run_spin_precession!(
@view(obj[p]), seq, @view(sig[dims..., i]), @view(Xt[p]), sim_method
@view(obj[p]), seq, @view(sig[dims..., i]), @view(Xt[p]), sim_method, backend
)
end

Expand Down Expand Up @@ -166,7 +168,8 @@ function run_sim_time_iter!(
seq::DiscreteSequence,
sig::AbstractArray{Complex{T}},
Xt::SpinStateRepresentation{T},
sim_method::SimulationMethod;
sim_method::SimulationMethod,
backend::KA.Backend;
Nblocks=1,
Nthreads=Threads.nthreads(),
parts=[1:length(seq)],
Expand All @@ -193,7 +196,7 @@ function run_sim_time_iter!(
rfs += 1
else
run_spin_precession_parallel!(
obj, seq_block, @view(sig[acq_samples, dims...]), Xt, sim_method; Nthreads
obj, seq_block, @view(sig[acq_samples, dims...]), Xt, sim_method, backend; Nthreads
)
end
samples += Nadc
Expand Down Expand Up @@ -373,7 +376,8 @@ function simulate(
seqd,
sig,
Xt,
sim_params["sim_method"];
sim_params["sim_method"],
backend;
Nblocks=length(parts),
Nthreads=sim_params["Nthreads"],
parts,
Expand Down

0 comments on commit 987aef1

Please sign in to comment.