From c0c4692db3466ac9ca3c83e240fb6da2c0458c46 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 16 Oct 2024 13:51:08 +0200 Subject: [PATCH] update --- src/fmri/operators/gradient.py | 45 +++++++++++++++++- src/fmri/operators/weighted.py | 2 - src/fmri/optimizer.py | 26 +++++------ src/fmri/reconstructors/frame_based.py | 64 ++++++++++---------------- 4 files changed, 81 insertions(+), 56 deletions(-) diff --git a/src/fmri/operators/gradient.py b/src/fmri/operators/gradient.py index 892bab7..bd6f0d5 100644 --- a/src/fmri/operators/gradient.py +++ b/src/fmri/operators/gradient.py @@ -3,10 +3,13 @@ Adapted from pysap-mri and Modopt libraries. """ +from functools import cached_property + import numpy as np +import cupy as cp from modopt.math.matrix import PowerMethod -from modopt.opt.gradient import GradBasic -from modopt.base.backend import get_backend +from modopt.opt.gradient import GradBasic, GradParent +from modopt.base.backend import get_backend, get_array_module def check_lipschitz_cst(f, x_shape, x_dtype, lipschitz_cst, max_nb_of_iter=10): @@ -224,3 +227,41 @@ def _op_method(self, data): def _trans_op_method(self, data): return self.linear_op.op(self.fourier_op.adj_op(data)) + + +class CustomGradAnalysis(GradParent): + """Custom Gradient Analysis Operator.""" + + def __init__(self, fourier_op, obs_data, obs_data_gpu=None, lazy=True): + self.fourier_op = fourier_op + self._grad_data_type = np.complex64 + self._obs_data = obs_data + if obs_data_gpu is None: + self.obs_data_gpu = cp.array(obs_data) + elif isinstance(obs_data_gpu, cp.ndarray): + self.obs_data_gpu = obs_data_gpu + else: + raise ValueError("Invalid data type for obs_data_gpu") + self.lazy = lazy + self.shape = fourier_op.shape + + def get_grad(self, x): + """Get the gradient value""" + if self.lazy: + self.obs_data_gpu.set(self.obs_data) + self.grad = self.fourier_op.data_consistency(x, self.obs_data_gpu) + return self.grad + + @cached_property + def spec_rad(self): + return self.fourier_op.get_lipschitz_cst() + + def inv_spec_rad(self): + return 1.0 / self.spec_rad + + def cost(self, x, *args, **kwargs): + xp = get_array_module(x) + cost = xp.linalg.norm(self.fourier_op.op(x) - self.obs_data) + if xp != np: + return cost.get() + return cost diff --git a/src/fmri/operators/weighted.py b/src/fmri/operators/weighted.py index c7896e7..2e2263f 100644 --- a/src/fmri/operators/weighted.py +++ b/src/fmri/operators/weighted.py @@ -430,8 +430,6 @@ def _auto_thresh(self, input_data): weights = self._thresh_scale(weights, self._n_op_calls) else: weights *= self._thresh_scale - xp = get_array_module(weights) - logger.info(xp.unique(weights)) return weights def _op_method(self, input_data, extra_factor=1.0): diff --git a/src/fmri/optimizer.py b/src/fmri/optimizer.py index 4644857..b1a5e7d 100644 --- a/src/fmri/optimizer.py +++ b/src/fmri/optimizer.py @@ -31,7 +31,7 @@ class AccProxSVRG(SetUp): def __init__( self, x, - grad_list, + fourier_op_list, prox, cost="auto", step_size=1.0, @@ -79,7 +79,7 @@ def _update(self): self._v_tld = self.xp.zeros_like(self._v_tld) # Compute the average gradient. for g in self._grad_ops: - self._v_tld += g.get_grad(self._x_old) + self._v_tld += g.get_grad(self._x_tld) self._v_tld /= len(self._grad_ops) self.xp.copyto(self._x_old, self._x_tld) @@ -89,16 +89,17 @@ def _update(self): self.xp.copyto(self._v, self._v_tld) self._v *= self.batch_size for g in gIk: - self._v += g.get_grad(self._x_tld) - self._v -= g.get_grad(self._y) + self._v -= g.get_grad(self._x_tld) + self._v += g.get_grad(self._y) self._v *= self.step_size / self.batch_size - self._x_new = self._y - self._v # Reuse the array + self.xp.copyto(self._x_new, self._y) + self._x_new -= self._v # Reuse the array self._x_new = self._prox.op(self._x_new, extra_factor=self.step_size) - self._v = self._x_new - self._x_old # Reuse the array - - self._y = self._x_new + self.beta * self._v + self.xp.copyto(self._v, self._x_new) + self._v -= self._x_old # Reuse the array + self.xp.copyto(self._y, self._x_new) + self._y += self.beta * self._v self.xp.copyto(self._x_old, self._x_new) - self.xp.copyto(self._x_tld, self._x_new) # Test cost function for convergence. @@ -184,14 +185,13 @@ def __init__( super().__init__(**kwargs) # Set the initial variable values - self._check_input_data(x) self.step_size = step_size self.update_frequency = update_frequency self.batch_size = batch_size self._grad_ops = grad_list - self._prox_op = prox + self._prox = prox self._rng = np.random.default_rng(seed) @@ -213,9 +213,9 @@ def _update(self): self._g += g.get_grad(self._x) self._g /= len(self._grad_ops) self.xp.copyto(self._y, self._x) - tk = self.rng.randint(1, self.update_frequency) + tk = self._rng.integers(1, self.update_frequency) for _ in range(tk): - Ak = self.rng.choices(self._grad_ops, k=self.batch_size) + Ak = self._rng.choice(self._grad_ops, size=self.batch_size, replace=False) self.xp.copyto(self._g_sto, self._g) self._g_sto *= self.batch_size for g in Ak: diff --git a/src/fmri/reconstructors/frame_based.py b/src/fmri/reconstructors/frame_based.py index 2da8bb5..72e5be0 100644 --- a/src/fmri/reconstructors/frame_based.py +++ b/src/fmri/reconstructors/frame_based.py @@ -5,6 +5,9 @@ """ +import cupy as cp +import logging + import gc from functools import cached_property @@ -13,14 +16,17 @@ import copy from tqdm.auto import tqdm, trange -from ..operators.gradient import GradAnalysis, GradSynthesis +from ..operators.gradient import GradAnalysis, GradSynthesis, CustomGradAnalysis from .base import BaseFMRIReconstructor from .utils import OPTIMIZERS, initialize_opt from modopt.opt.algorithms import POGM from modopt.opt.linear import Identity +from modopt.opt.gradient import GradParent from ..optimizer import AccProxSVRG, MS2GD +logger = logging.getLogger("pysap-fmri") + class SequentialReconstructor(BaseFMRIReconstructor): """Sequential Reconstruction of fMRI data. @@ -107,6 +113,8 @@ def reconstruct( final_estimate[i, ...] = x_iter # Progressbar update progbar.close() + + logger.info("final prox weight: %f ", xp.unique(self.space_prox_op.weights)) return final_estimate def _reconstruct_frame( @@ -219,34 +227,6 @@ def reconstruct( return final_estimate -class CustomGradAnalysis: - """Custom Gradient Analysis Operator.""" - - def __init__(self, fourier_op, obs_data): - self.fourier_op = fourier_op - self.obs_data = obs_data - self.shape = fourier_op.shape - - def get_grad(self, x): - """Get the gradient value""" - self.grad = self.fourier_op.data_consistency(x, self.obs_data) - return self.grad - - @cached_property - def spec_rad(self): - return self.fourier_op.get_lipschitz_cst() - - def inv_spec_rad(self): - return 1.0 / self.spec_rad - - def cost(self, x, *args, **kwargs): - xp = get_array_module(x) - cost = xp.linalg.norm(self.fourier_op.op(x) - self.obs_data) - if xp != np: - return cost.get() - return cost - - class StochasticSequentialReconstructor(BaseFMRIReconstructor): """Stochastic Sequential Reconstruction of fMRI data.""" @@ -255,12 +235,18 @@ def __init__( fourier_op, space_linear_op, space_prox_op, + space_prox_op_refine=None, progbar_disable=False, compute_backend="numpy", **kwargs, ): super().__init__(fourier_op, space_linear_op, space_prox_op, **kwargs) + if space_prox_op_refine is None: + self.space_prox_op_refine = space_prox_op + else: + self.space_prox_op_refine = space_prox_op_refine + self.progbar_disable = progbar_disable self.compute_backend = compute_backend @@ -269,6 +255,7 @@ def reconstruct( kspace_data, x_init=None, max_iter_per_frame=15, + max_iter_stochastic=20, grad_kwargs=None, algorithm="accproxsvrg", progbar_disable=False, @@ -283,6 +270,7 @@ def reconstruct( xp, _ = get_backend(self.compute_backend) # Create the gradients operators grad_list = [] + tmp_ksp = cp.zeros_like(kspace_data[0]) for i, fop in enumerate(self.fourier_op.fourier_ops): # L = fop.get_lipschitz_cst() @@ -296,7 +284,7 @@ def reconstruct( # input_data_writeable=True, # ) # g._obs_data = kspace_data[i, ...] - g = CustomGradAnalysis(fop, kspace_data[i, ...]) + g = CustomGradAnalysis(fop, kspace_data[i, ...], obs_data_gpu=tmp_ksp) grad_list.append(g) max_lip = max(g.spec_rad for g in grad_list) @@ -307,10 +295,9 @@ def reconstruct( x=xp.zeros(grad_list[0].shape, dtype="complex64"), grad_list=grad_list, prox=self.space_prox_op, - step_size=1.0 / max_lip, + step_size=1.0 / 2 * max_lip, auto_iterate=False, cost=None, - update_frequency=10, compute_backend=self.compute_backend, **algorithm_kwargs, ) @@ -323,12 +310,11 @@ def reconstruct( prox=self.space_prox_op, step_size=1.0 / max_lip, auto_iterate=False, - update_frequency=10, cost=None, **algorithm_kwargs, ) - opt.iterate(max_iter=20) + opt.iterate(max_iter=max_iter_stochastic) x_anat = opt.x_final.squeeze() @@ -348,7 +334,7 @@ def reconstruct( x_anat, x_anat, grad=grad_list[i], - prox=self.space_prox_op, + prox=self.space_prox_op_refine, linear=Identity(), beta=grad_list[i].inv_spec_rad, compute_backend=self.compute_backend, @@ -360,9 +346,9 @@ def reconstruct( progbar.reset(total=max_iter_per_frame) img = opt.x_final - if self.compute_backend == "cupy": - final_img[i] = img.get() - else: - final_img[i] = img + if self.compute_backend == "cupy": + final_img[i] = img.get().squeeze() + else: + final_img[i] = img return final_img, x_anat