diff --git a/src/fmri/operators/fourier.py b/src/fmri/operators/fourier.py index 65b88a4..368be60 100644 --- a/src/fmri/operators/fourier.py +++ b/src/fmri/operators/fourier.py @@ -10,6 +10,7 @@ import numpy as np from mrinufft import get_operator +from modopt.base.backend import get_array_module try: from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps @@ -171,7 +172,8 @@ def op(self, images): def adj_op(self, coeffs): """Apply Adjoint Operator.""" c = 1 if self.uses_sense else self.n_coils - final_image = np.empty((self.n_frames, c, *self.shape), dtype=np.complex64) + xp = get_array_module(coeffs) + final_image = xp.empty((self.n_frames, c, *self.shape), dtype=np.complex64) for i in range(len(coeffs)): final_image[i] = self.fourier_ops[i].adj_op(coeffs[i]) return final_image.squeeze() diff --git a/src/fmri/operators/gradient.py b/src/fmri/operators/gradient.py index 392c969..892bab7 100644 --- a/src/fmri/operators/gradient.py +++ b/src/fmri/operators/gradient.py @@ -210,6 +210,7 @@ def __init__(self, linear_op, fourier_op, verbose=0, **kwargs): n_channels = fourier_op.n_coils if not fourier_op.uses_sense else 1 coef = linear_op.op(np.squeeze(np.zeros((n_channels, *fourier_op.shape)))) self.linear_op_coeffs_shape = coef.shape + self.shape = coef.shape super().__init__( self._op_method, self._trans_op_method, diff --git a/src/fmri/operators/weighted.py b/src/fmri/operators/weighted.py index 5f574eb..c7896e7 100644 --- a/src/fmri/operators/weighted.py +++ b/src/fmri/operators/weighted.py @@ -382,6 +382,7 @@ def __init__( thresh_range="global", threshold_estimation="sure", threshold_scaler=1.0, + synthesis=True, **kwargs ): if linear is None: @@ -389,6 +390,7 @@ def __init__( self._n_op_calls = 0 self.cf_shape = coeffs_shape self._update_period = update_period + self.synthesis = synthesis if thresh_range not in ["bands", "scale", "global"]: raise ValueError("Unsupported threshold range.") @@ -451,10 +453,16 @@ def _op_method(self, input_data, extra_factor=1.0): Thresholded data """ + if not self.synthesis: + input_data = self._linear.op(input_data) if self._update_period == 0 and self._n_op_calls == 0: self.weights = self._auto_thresh(input_data) if self._update_period != 0 and self._n_op_calls % self._update_period == 0: self.weights = self._auto_thresh(input_data) self._n_op_calls += 1 - return super()._op_method(input_data, extra_factor=extra_factor) + threshed = super()._op_method(input_data, extra_factor=extra_factor) + + if not self.synthesis: + return self._linear.adj_op(threshed) + return threshed diff --git a/src/fmri/reconstructors/frame_based.py b/src/fmri/reconstructors/frame_based.py index 757b6b8..2da8bb5 100644 --- a/src/fmri/reconstructors/frame_based.py +++ b/src/fmri/reconstructors/frame_based.py @@ -5,7 +5,10 @@ """ -from modopt.base.backend import get_backend +import gc +from functools import cached_property + +from modopt.base.backend import get_backend, get_array_module import numpy as np import copy from tqdm.auto import tqdm, trange @@ -14,6 +17,10 @@ from .base import BaseFMRIReconstructor from .utils import OPTIMIZERS, initialize_opt +from modopt.opt.algorithms import POGM +from modopt.opt.linear import Identity +from ..optimizer import AccProxSVRG, MS2GD + class SequentialReconstructor(BaseFMRIReconstructor): """Sequential Reconstruction of fMRI data. @@ -210,3 +217,152 @@ def reconstruct( progbar.close() return final_estimate + + +class CustomGradAnalysis: + """Custom Gradient Analysis Operator.""" + + def __init__(self, fourier_op, obs_data): + self.fourier_op = fourier_op + self.obs_data = obs_data + self.shape = fourier_op.shape + + def get_grad(self, x): + """Get the gradient value""" + self.grad = self.fourier_op.data_consistency(x, self.obs_data) + return self.grad + + @cached_property + def spec_rad(self): + return self.fourier_op.get_lipschitz_cst() + + def inv_spec_rad(self): + return 1.0 / self.spec_rad + + def cost(self, x, *args, **kwargs): + xp = get_array_module(x) + cost = xp.linalg.norm(self.fourier_op.op(x) - self.obs_data) + if xp != np: + return cost.get() + return cost + + +class StochasticSequentialReconstructor(BaseFMRIReconstructor): + """Stochastic Sequential Reconstruction of fMRI data.""" + + def __init__( + self, + fourier_op, + space_linear_op, + space_prox_op, + progbar_disable=False, + compute_backend="numpy", + **kwargs, + ): + super().__init__(fourier_op, space_linear_op, space_prox_op, **kwargs) + + self.progbar_disable = progbar_disable + self.compute_backend = compute_backend + + def reconstruct( + self, + kspace_data, + x_init=None, + max_iter_per_frame=15, + grad_kwargs=None, + algorithm="accproxsvrg", + progbar_disable=False, + algorithm_kwargs=None, + ): + """Reconstruct using sequential method.""" + self.progbar_disable = progbar_disable + + if algorithm_kwargs is None: + algorithm_kwargs = {} + + xp, _ = get_backend(self.compute_backend) + # Create the gradients operators + grad_list = [] + for i, fop in enumerate(self.fourier_op.fourier_ops): + # L = fop.get_lipschitz_cst() + + # g = GradSynthesis( + # linear_op=self.space_linear_op, + # fourier_op=fop, + # verbose=self.verbose, + # dtype=kspace_data.dtype, + # lipschitz_cst=L, + # num_check_lips=0, # trust me + # input_data_writeable=True, + # ) + # g._obs_data = kspace_data[i, ...] + g = CustomGradAnalysis(fop, kspace_data[i, ...]) + grad_list.append(g) + + max_lip = max(g.spec_rad for g in grad_list) + + if algorithm == "accproxsvrg": + + opt = AccProxSVRG( + x=xp.zeros(grad_list[0].shape, dtype="complex64"), + grad_list=grad_list, + prox=self.space_prox_op, + step_size=1.0 / max_lip, + auto_iterate=False, + cost=None, + update_frequency=10, + compute_backend=self.compute_backend, + **algorithm_kwargs, + ) + + elif algorithm == "m2sg": + + opt = MS2GD( + x=xp.zeros(self.fourier_op.shape, dtype="complex64"), + grad_list=grad_list, + prox=self.space_prox_op, + step_size=1.0 / max_lip, + auto_iterate=False, + update_frequency=10, + cost=None, + **algorithm_kwargs, + ) + + opt.iterate(max_iter=20) + + x_anat = opt.x_final.squeeze() + + progbar_main = trange(len(kspace_data), disable=self.progbar_disable) + progbar = tqdm(total=max_iter_per_frame, disable=self.progbar_disable) + final_img = np.zeros( + (len(kspace_data), *self.fourier_op.shape), + dtype=self.fourier_op.cpx_dtype, + ) + del opt + gc.collect() + for i in progbar_main: # Parallel + + opt = POGM( + x_anat, + x_anat, + x_anat, + x_anat, + grad=grad_list[i], + prox=self.space_prox_op, + linear=Identity(), + beta=grad_list[i].inv_spec_rad, + compute_backend=self.compute_backend, + auto_iterate=False, + cost=None, + ) + opt.iterate(progbar=progbar, max_iter=max_iter_per_frame) + + progbar.reset(total=max_iter_per_frame) + img = opt.x_final + + if self.compute_backend == "cupy": + final_img[i] = img.get() + else: + final_img[i] = img + + return final_img, x_anat