From 59e43801bf2fde06b3816f6c671c486afde7a7e8 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 26 Jan 2023 11:38:01 +0100 Subject: [PATCH 01/20] add auto threshold operator. --- mri/operators/proximity/weighted.py | 122 ++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index a84cc5d2..9b932e17 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -1,4 +1,5 @@ import numpy as np +import scipy as sp from modopt.opt.proximity import SparseThreshold from modopt.opt.linear import Identity @@ -78,3 +79,124 @@ def mu(self, w): if self.zero_weight_coarse: weights_init[:np.prod(self.cf_shape[0])] = 0 self.weights = weights_init + + +class AutoWeightedSparseThreshold(WeightedSparseThreshold): + """This WeightedSparseThreshold uses the universal threshold rules to """ + def __init__(self, coeffs_shape, linear=Identity(), update_period=0, + sigma_estimation="global", threshold_estimation="sure", **kwargs): + self._n_op_calls = 0 + self._sigma_estimation = sigma_estimation + self._update_period = update_period + + self._thresh_estimation = threshold_estimation + + weights_init = np.zeros(np.sum(np.prod(self.cf_shape, axis=-1))) + super().__init__(weights=weights_init, + weight_type="scale_based", + **kwargs) + + def _auto_thresh_scale(self, input_data, sigma=None): + """Determine a threshold value adapted to denoise the data of a specific scale. + + Parameters + ---------- + input_data: numpy.ndarray + data that should be thresholded. + sigma: float + Estimation of the noise standard deviation. + Returns + ------- + float + The estimated threshold. + + Raises + ------ + ValueError is method is not supported. + Notes + ----- + The choice of the threshold makes the assumptions of a white additive gaussian noise. + """ + + #tmp = np.sort(input_data.flatten()) + tmp = input_data.flatten() + # use the robust estimator to estimate the noise variance. + med = np.median(tmp) + if sigma is None: + sigma = np.median(np.abs(tmp-med)) / 0.6745 + N = len(input_data) + j = np.log2(N) + + uni_threshold = np.sqrt(2*np.log(N)) + + if self._thresh_estimation == "universal": + return sigma * uni_threshold, sigma + elif self._thresh_estimation == "sure": + tmp = tmp **2 + eps2 = (sigma ** 2) /N + # TODO: optimize the estimation + def _sure(t): + e2t2 = eps2 * (t**2) + return (N * eps2 + + np.sum(np.minimum(tmp, e2t2)) + - 2*eps2*np.sum(tmp <= e2t2) + ) + + thresh = sp.optimize.minimize_scalar( + _sure, + method="bounded", + bounds=[0, uni_threshold]).x + sj2 = np.sum(tmp/eps2 - 1)/N + if sj2 >= 3*j/np.sqrt(2*N): + return thresh, sigma + else: + return uni_threshold, sigma + + else: + raise ValueError("Unknown method name") + + def _auto_thresh(self, input_data): + """Determines the threshold for every scale (except the coarse one) using the provided method. + + Parameters + ---------- + input_data: list of numpy.ndarray + + Returns + ------- + thresh_list: list + list of threshold for every scale + """ + sigma = None + thresh_list = [] + for band_idx in range(len(input_data)-1, 1, -1): + thresh_value, sigma_est = self._auto_thresh_scale(input_data[band_idx], sigma=sigma) + if self._sigma_estimation == "global" and band_idx == len(input_data)-1: + sigma = sigma_est + thresh_list.append(thresh_value) + return thresh_list + + + def _op_method(self, input_data, extra_factor=1.0): + """Operator. + + This method returns the input data thresholded by the weights. 
+ The weights are computed using the universal threshold rule. + + Parameters + ---------- + input_data : numpy.ndarray + Input data array + extra_factor : float + Additional multiplication factor (default is ``1.0``) + + Returns + ------- + numpy.ndarray + Thresholded data + + """ + if (self._update_period == 0 and self._n_op_calls == 0) or (self._n_op_calls % self._update_period == 0) : + self.mu , sigma = self._auto_thresh(input_data) + self._n_op_calls += 1 + return super()._op_method(input_data, extra_factor=extra_factor) From 0e91cec2eb6431fe41dada2723889bcbeab59e74 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 26 Jan 2023 22:09:28 +0100 Subject: [PATCH 02/20] fix deprecated dtype (error with numpy 1.24) --- mri/optimizers/forward_backward.py | 2 +- mri/optimizers/primal_dual.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mri/optimizers/forward_backward.py b/mri/optimizers/forward_backward.py index 67f14645..0ac23488 100644 --- a/mri/optimizers/forward_backward.py +++ b/mri/optimizers/forward_backward.py @@ -69,7 +69,7 @@ def fista(gradient_op, linear_op, prox_op, cost_op, kspace_generator=None, estim if x_init is None: x_init = np.squeeze(np.zeros((gradient_op.linear_op.n_coils, *gradient_op.fourier_op.shape), - dtype=np.complex)) + dtype=np.complex64)) alpha_init = linear_op.op(x_init) # Welcome message diff --git a/mri/optimizers/primal_dual.py b/mri/optimizers/primal_dual.py index 4b9d27fc..9bd38f91 100644 --- a/mri/optimizers/primal_dual.py +++ b/mri/optimizers/primal_dual.py @@ -101,7 +101,7 @@ def condatvu(gradient_op, linear_op, dual_regularizer, cost_op, kspace_generator if x_init is None: x_init = np.squeeze(np.zeros((linear_op.n_coils, *gradient_op.fourier_op.shape), - dtype=np.complex)) + dtype=np.complex64)) primal = x_init dual = linear_op.op(primal) weights = dual From e50a818e84f0838eeca1a5513616dfddb728a121 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 26 Jan 2023 22:33:54 +0100 Subject: [PATCH 03/20] fix bugs. --- mri/operators/proximity/weighted.py | 30 ++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index 9b932e17..e6cc5164 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -91,9 +91,10 @@ def __init__(self, coeffs_shape, linear=Identity(), update_period=0, self._thresh_estimation = threshold_estimation - weights_init = np.zeros(np.sum(np.prod(self.cf_shape, axis=-1))) + weights_init = np.zeros(np.sum(np.prod(coeffs_shape, axis=-1))) super().__init__(weights=weights_init, - weight_type="scale_based", + coeffs_shape=coeffs_shape, + weight_type="custom", **kwargs) def _auto_thresh_scale(self, input_data, sigma=None): @@ -168,13 +169,28 @@ def _auto_thresh(self, input_data): list of threshold for every scale """ sigma = None - thresh_list = [] - for band_idx in range(len(input_data)-1, 1, -1): - thresh_value, sigma_est = self._auto_thresh_scale(input_data[band_idx], sigma=sigma) + thresh_list = np.zeros(len(self.cf_shape)) + + # reverse order to get the finest scale first. 
+ end=len(input_data) + for band_idx in range(len(self.cf_shape)-1, 1, -1): + band_size = np.prod(self.cf_shape[band_idx]) + thresh_value, sigma_est = self._auto_thresh_scale(input_data[end-band_size:end], sigma=sigma) + end= end-band_size if self._sigma_estimation == "global" and band_idx == len(input_data)-1: sigma = sigma_est - thresh_list.append(thresh_value) - return thresh_list + thresh_list[band_idx] = thresh_value + + # replicate the threshold for every subband + weights = np.zeros(np.sum(np.prod(self.cf_shape, axis=-1))) + + start=0 + for thresh, shape in zip(thresh_list, self.cf_shape): + size = np.prod(shape) + weights[start:start+size] = thresh + start += size + return weights, sigma + def _op_method(self, input_data, extra_factor=1.0): From f03605424dff3d4361c6a9e674b9abdc42c1817f Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 26 Jan 2023 22:34:01 +0100 Subject: [PATCH 04/20] add auto sure example. --- ...cartesian_reconstruction_auto_threshold.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 examples/cartesian_reconstruction_auto_threshold.py diff --git a/examples/cartesian_reconstruction_auto_threshold.py b/examples/cartesian_reconstruction_auto_threshold.py new file mode 100644 index 00000000..b4a4d77b --- /dev/null +++ b/examples/cartesian_reconstruction_auto_threshold.py @@ -0,0 +1,101 @@ +""" +Neuroimaging cartesian reconstruction +===================================== + +Author: Pierre-Antoine Comby / Chaithya G R + +In this tutorial we will reconstruct an MRI image from the sparse kspace +measurements. +Moreover we will see the benefit of automatic tuning of the regularisation parameters. + +Import neuroimaging data +------------------------ + +We use the toy datasets available in pysap, more specifically a 2D brain slice +and the cartesian acquisition scheme. +""" + +# Package import +from mri.operators import FFT, WaveletN +from mri.operators.utils import convert_mask_to_locations +from mri.reconstructors import SingleChannelReconstructor +from mri.operators.proximity.weighted import AutoWeightedSparseThreshold +import pysap +from pysap.data import get_sample_data + +# Third party import +from modopt.opt.proximity import SparseThreshold +from modopt.opt.linear import Identity +from modopt.math.metrics import ssim +import numpy as np + +# Loading input data +image = get_sample_data('2d-mri') + +# Obtain K-Space Cartesian Mask +mask = get_sample_data("cartesian-mri-mask") + +# View Input +# image.show() +# mask.show() + +#%% +# Generate the kspace +# ------------------- +# +# From the 2D brain slice and the acquisition mask, we retrospectively +# undersample the k-space using a cartesian acquisition mask +# We then reconstruct the zero order solution as a baseline + + +# Get the locations of the kspace samples +kspace_loc = convert_mask_to_locations(mask.data) +# Generate the subsampled kspace +fourier_op = FFT(samples=kspace_loc, shape=image.shape) +kspace_data = fourier_op.op(image) + +# Zero order solution +image_rec0 = pysap.Image(data=fourier_op.adj_op(kspace_data), + metadata=image.metadata) +# image_rec0.show() + +# Calculate SSIM +base_ssim = ssim(image_rec0, image) +print(base_ssim) + +#%% +# FISTA optimization +# ------------------ +# +# We now want to refine the zero order solution using a FISTA optimization. 
+# The cost function is set to Proximity Cost + Gradient Cost + +# Setup the operators +linear_op = WaveletN(wavelet_name="sym8", nb_scales=4) +coeffs = linear_op.op(image_rec0) +regularizer_op = AutoWeightedSparseThreshold( + coeffs.shape, linear=Identity(), + update_period=5, + sigma_estimation="global", + threshold_estimation="sure", + thresh_type="soft" +) +# Setup Reconstructor +reconstructor = SingleChannelReconstructor( + fourier_op=fourier_op, + linear_op=linear_op, + regularizer_op=regularizer_op, + gradient_formulation='synthesis', + verbose=1, +) +# Start Reconstruction +x_final, costs, metrics = reconstructor.reconstruct( + kspace_data=kspace_data, + optimization_alg='fista', + num_iterations=200, +) +image_rec = pysap.Image(data=np.abs(x_final)) +# image_rec.show() +# Calculate SSIM +recon_ssim = ssim(image_rec, image) +print('The Reconstruction SSIM is : ' + str(recon_ssim)) From 8cb19e193a2aa2c08aa385e5d2eceab2dd234f78 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 31 Jan 2023 12:02:33 +0100 Subject: [PATCH 05/20] rework of the eestimation function. --- mri/operators/proximity/weighted.py | 271 ++++++++++++++++++---------- 1 file changed, 172 insertions(+), 99 deletions(-) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index e6cc5164..d2b32f40 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -1,9 +1,10 @@ import numpy as np -import scipy as sp from modopt.opt.proximity import SparseThreshold from modopt.opt.linear import Identity +from pysap.base.utils import flatten, unflatten + class WeightedSparseThreshold(SparseThreshold): """This is a weighted version of `SparseThreshold` in ModOpt. @@ -81,117 +82,186 @@ def mu(self, w): self.weights = weights_init -class AutoWeightedSparseThreshold(WeightedSparseThreshold): - """This WeightedSparseThreshold uses the universal threshold rules to """ +def _sigma_mad(data): + """Return a robust estimation of the variance. + + It assums that is a sparse vector polluted by gaussian noise. + """ + return np.median(np.abs(data - np.median(data)))/0.6745 + # return np.median(np.abs(data))/0.6745 + +def _sure_est(data): + """Return an estimation of the threshold computed using the SURE method.""" + dataf = data.flatten() + n = dataf.size + data_sorted = np.sort(np.abs(dataf))**2 + idx = np.arange(n-1, -1, -1) + tmp = np.cumsum(data_sorted) + idx * data_sorted + + risk = (n - (2 * np.arange(n)) + tmp) / n + ibest = np.argmin(risk) + + return np.sqrt(data_sorted[ibest]) + +def _thresh_select(data, thresh_est): + """ + Threshold selection for denoising. + + It assumes that data has a white noise of N(0,1) + """ + n, j = data.size, np.ceil(np.log2(data.size)) + universal_thr = np.sqrt(2*np.log(n)) + + if thresh_est == "sure": + thr = _sure_est(data) + if thresh_est == "universal": + thr = universal_thr + if thresh_est == "hybrid-sure": + eta = np.linalg.norm(data.flatten()) ** 2 /n - 1 + if eta < j ** (1.5) / np.sqrt(n): + thr = universal_thr + else: + test_th = _sure_est(data) + thr = min(test_th, universal_thr) + return thr + +def _wavelet_noise_estimate(wavelet_coefs, coeffs_shape, sigma_est): + r"""Return an estimate of the noise variance in each band. + + Parameters + ---------- + wavelet_bands: list + list of array + sigma_est: str + Estimation method, available are "band", "level", "level-shared", "global" + Returns + ------- + numpy.ndarray + Estimation of the variance for each wavelet bands. 
+ + Notes + ----- + This methods makes several assumptions: + + - The wavelet coefficient are ordered by scale, and the scale are ordered by size. + - At each scale, the subbands should have the same shape. + + The variance estimation can be done: + + - On each band (eg LH, HL and HH band of each level) + - On each level, using the HH band. + - On each level, using all the available coefficient (estimating jointly on LH, HL and HH) + + For the selected data band(s) the variance is estimated using the MAD estimator: + + .. math:: + \hat{\sigma} = \textrm{median}(|x|) / 0.6745 + + """ + sigma_ret = np.ones(len(coeffs_shape)) + sigma_ret[0] = np.NaN + start = 0 + stop = 0 + if sigma_est is None: + return sigma_ret + if sigma_est == "band": + for i in range(1, len(coeffs_shape)): + stop += np.prod(coeffs_shape[i]) + sigma_ret[i] = _sigma_mad(wavelet_coefs[start:stop]) + start = stop + if sigma_est == "level": + # use the diagonal coefficient to estimate the variance of the level. + # it assumes that the band of the same level have the same shape. + start = np.prod(coeffs_shape[0]) + for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)): + scale_sz = np.prod(scale_shape) + matched_bands = np.all(scale_shape == coeffs_shape[1:], axis=1) + band_per_level = np.sum(matched_bands) + start = start + scale_sz * (band_per_level-1) + stop = start + scale_sz * band_per_level + sigma_ret[1+i*(band_per_level):1+(i+1)*band_per_level] = _sigma_mad(wavelet_coefs[start:stop]) + start = stop + if sigma_est == "level-shared": + start = np.prod(coeffs_shape[0]) + for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)): + scale_sz = np.prod(scale_shape) + band_per_level = np.sum(scale_shape == coeffs_shape) + stop = start + scale_sz * band_per_level + sigma_ret[i:i+band_per_level] = _sigma_mad(wavelet_coefs[start:stop]) + start = stop + if sigma_est == "global": + sigma_ret *= _sigma_mad(wavelet_coefs[-np.prod(coeffs_shape[-1]):]) + sigma_ret[0] = np.NaN + return sigma_ret + +class AutoWeightedSparseThreshold(SparseThreshold): + """Automatic Weighting of Sparse coefficient. + + This proximty automatically determines the threshold for Sparse (e.g. Wavelet based) + coefficients. + + The weight are computed on first call, and updated every ``update_period`` calls. + Note that the coarse/approximation scale will not be thresholded. + + Parameters + ---------- + coeffs_shape: list of tuple + list of shape for the subbands. + linear: LinearOperator + Required for cost estimation. + update_period: int + Estimation of the weight update period. + threshold_estimation: str + threshold estimation method. Available are "sure", "hybrid-sure" and "universal" + sigma_estimation: str + noise std estimation method. Available are "global", "level" and "level_shared" + thresh_type: str + "hard" or "soft" thresholding. 
+ """ def __init__(self, coeffs_shape, linear=Identity(), update_period=0, sigma_estimation="global", threshold_estimation="sure", **kwargs): self._n_op_calls = 0 - self._sigma_estimation = sigma_estimation + self.cf_shape = coeffs_shape self._update_period = update_period + + if sigma_estimation not in ["bands", "level", "global"]: + raise ValueError("Unsupported sigma estimation method") + if threshold_estimation not in ["sure", "hybrid-sure", "universal"]: + raise ValueError("Unsupported threshold estimation method.") + + self._sigma_estimation = sigma_estimation self._thresh_estimation = threshold_estimation + weights_init = np.zeros(np.sum(np.prod(coeffs_shape, axis=-1))) super().__init__(weights=weights_init, - coeffs_shape=coeffs_shape, - weight_type="custom", + linear=linear, **kwargs) - def _auto_thresh_scale(self, input_data, sigma=None): - """Determine a threshold value adapted to denoise the data of a specific scale. - - Parameters - ---------- - input_data: numpy.ndarray - data that should be thresholded. - sigma: float - Estimation of the noise standard deviation. - Returns - ------- - float - The estimated threshold. - - Raises - ------ - ValueError is method is not supported. - Notes - ----- - The choice of the threshold makes the assumptions of a white additive gaussian noise. - """ - - #tmp = np.sort(input_data.flatten()) - tmp = input_data.flatten() - # use the robust estimator to estimate the noise variance. - med = np.median(tmp) - if sigma is None: - sigma = np.median(np.abs(tmp-med)) / 0.6745 - N = len(input_data) - j = np.log2(N) - - uni_threshold = np.sqrt(2*np.log(N)) - - if self._thresh_estimation == "universal": - return sigma * uni_threshold, sigma - elif self._thresh_estimation == "sure": - tmp = tmp **2 - eps2 = (sigma ** 2) /N - # TODO: optimize the estimation - def _sure(t): - e2t2 = eps2 * (t**2) - return (N * eps2 - + np.sum(np.minimum(tmp, e2t2)) - - 2*eps2*np.sum(tmp <= e2t2) - ) - - thresh = sp.optimize.minimize_scalar( - _sure, - method="bounded", - bounds=[0, uni_threshold]).x - sj2 = np.sum(tmp/eps2 - 1)/N - if sj2 >= 3*j/np.sqrt(2*N): - return thresh, sigma - else: - return uni_threshold, sigma - - else: - raise ValueError("Unknown method name") - def _auto_thresh(self, input_data): - """Determines the threshold for every scale (except the coarse one) using the provided method. + """Compute the best weights for the input_data.""" - Parameters - ---------- - input_data: list of numpy.ndarray + # Estimate the noise std for each band. - Returns - ------- - thresh_list: list - list of threshold for every scale - """ - sigma = None - thresh_list = np.zeros(len(self.cf_shape)) - - # reverse order to get the finest scale first. 
- end=len(input_data) - for band_idx in range(len(self.cf_shape)-1, 1, -1): - band_size = np.prod(self.cf_shape[band_idx]) - thresh_value, sigma_est = self._auto_thresh_scale(input_data[end-band_size:end], sigma=sigma) - end= end-band_size - if self._sigma_estimation == "global" and band_idx == len(input_data)-1: - sigma = sigma_est - thresh_list[band_idx] = thresh_value - - # replicate the threshold for every subband - weights = np.zeros(np.sum(np.prod(self.cf_shape, axis=-1))) - - start=0 - for thresh, shape in zip(thresh_list, self.cf_shape): - size = np.prod(shape) - weights[start:start+size] = thresh - start += size - return weights, sigma - + sigma_bands = _wavelet_noise_estimate(input_data, self.cf_shape, self._sigma_estimation) + weights = np.zeros_like(input_data) + + # compute the threshold for each subband + start = np.prod(self.cf_shape[0]) + stop = start + ts = [] + for i in range(1, len(self.cf_shape)): + stop = start + np.prod(self.cf_shape[i]) + t = sigma_bands[i] * _thresh_select( + input_data[start:stop] / sigma_bands[i], + self._thresh_estimation + ) + ts.append(t) + weights[start:stop] = t + start = stop + return weights def _op_method(self, input_data, extra_factor=1.0): """Operator. @@ -212,7 +282,10 @@ def _op_method(self, input_data, extra_factor=1.0): Thresholded data """ - if (self._update_period == 0 and self._n_op_calls == 0) or (self._n_op_calls % self._update_period == 0) : - self.mu , sigma = self._auto_thresh(input_data) + if ( + (self._update_period == 0 and self._n_op_calls == 0) + or (self._n_op_calls % self._update_period == 0) + ): + self.weights = self._auto_thresh(input_data) self._n_op_calls += 1 return super()._op_method(input_data, extra_factor=extra_factor) From be1f667b27dfdb73f06226949d44864d99fe7f42 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 3 Feb 2023 10:41:47 +0100 Subject: [PATCH 06/20] cleaning. --- mri/operators/proximity/weighted.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index d2b32f40..177163ee 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -3,8 +3,6 @@ from modopt.opt.proximity import SparseThreshold from modopt.opt.linear import Identity -from pysap.base.utils import flatten, unflatten - class WeightedSparseThreshold(SparseThreshold): """This is a weighted version of `SparseThreshold` in ModOpt. @@ -87,8 +85,8 @@ def _sigma_mad(data): It assums that is a sparse vector polluted by gaussian noise. 
""" - return np.median(np.abs(data - np.median(data)))/0.6745 - # return np.median(np.abs(data))/0.6745 + # return np.median(np.abs(data - np.median(data)))/0.6745 + return np.median(np.abs(data))/0.6745 def _sure_est(data): """Return an estimation of the threshold computed using the SURE method.""" @@ -109,7 +107,7 @@ def _thresh_select(data, thresh_est): It assumes that data has a white noise of N(0,1) """ - n, j = data.size, np.ceil(np.log2(data.size)) + n = data.size universal_thr = np.sqrt(2*np.log(n)) if thresh_est == "sure": @@ -117,8 +115,8 @@ def _thresh_select(data, thresh_est): if thresh_est == "universal": thr = universal_thr if thresh_est == "hybrid-sure": - eta = np.linalg.norm(data.flatten()) ** 2 /n - 1 - if eta < j ** (1.5) / np.sqrt(n): + eta = np.sum(data ** 2) /n - 1 + if eta < (np.log2(n) ** 1.5) / np.sqrt(n): thr = universal_thr else: test_th = _sure_est(data) @@ -150,7 +148,7 @@ def _wavelet_noise_estimate(wavelet_coefs, coeffs_shape, sigma_est): - On each band (eg LH, HL and HH band of each level) - On each level, using the HH band. - - On each level, using all the available coefficient (estimating jointly on LH, HL and HH) + - Only with the latest band (global) For the selected data band(s) the variance is estimated using the MAD estimator: @@ -282,10 +280,10 @@ def _op_method(self, input_data, extra_factor=1.0): Thresholded data """ - if ( - (self._update_period == 0 and self._n_op_calls == 0) - or (self._n_op_calls % self._update_period == 0) - ): + if self._update_period == 0 and self._n_op_calls == 0: + self.weights = self._auto_thresh(input_data) + if self._update_period != 0 and self._n_op_calls % self._update_period == 0: self.weights = self._auto_thresh(input_data) + self._n_op_calls += 1 return super()._op_method(input_data, extra_factor=extra_factor) From fdc657aba28eef6338512d03a47b0e7b5cd05ee5 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 3 Feb 2023 10:41:53 +0100 Subject: [PATCH 07/20] remove level-shared method. --- mri/operators/proximity/weighted.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index 177163ee..e6ad0921 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -179,14 +179,6 @@ def _wavelet_noise_estimate(wavelet_coefs, coeffs_shape, sigma_est): stop = start + scale_sz * band_per_level sigma_ret[1+i*(band_per_level):1+(i+1)*band_per_level] = _sigma_mad(wavelet_coefs[start:stop]) start = stop - if sigma_est == "level-shared": - start = np.prod(coeffs_shape[0]) - for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)): - scale_sz = np.prod(scale_shape) - band_per_level = np.sum(scale_shape == coeffs_shape) - stop = start + scale_sz * band_per_level - sigma_ret[i:i+band_per_level] = _sigma_mad(wavelet_coefs[start:stop]) - start = stop if sigma_est == "global": sigma_ret *= _sigma_mad(wavelet_coefs[-np.prod(coeffs_shape[-1]):]) sigma_ret[0] = np.NaN From dd6e24ed01134bbdbf5001443e3694d7e0e43b29 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 3 Feb 2023 10:42:08 +0100 Subject: [PATCH 08/20] add thresh range argument. 
--- mri/operators/proximity/weighted.py | 78 ++++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 13 deletions(-) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index e6ad0921..c39e401f 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -209,19 +209,27 @@ class AutoWeightedSparseThreshold(SparseThreshold): "hard" or "soft" thresholding. """ def __init__(self, coeffs_shape, linear=Identity(), update_period=0, - sigma_estimation="global", threshold_estimation="sure", **kwargs): + sigma_range="global", + thresh_range="global", + threshold_estimation="sure", + threshold_scaler=1.0, + **kwargs): self._n_op_calls = 0 self.cf_shape = coeffs_shape self._update_period = update_period - if sigma_estimation not in ["bands", "level", "global"]: + if thresh_range not in ["bands", "level", "global"]: + raise ValueError("Unsupported threshold range") + if sigma_range not in ["bands", "level", "global"]: raise ValueError("Unsupported sigma estimation method") - if threshold_estimation not in ["sure", "hybrid-sure", "universal"]: + if threshold_estimation not in ["sure", "hybrid-sure", "universal", "bayes"]: raise ValueError("Unsupported threshold estimation method.") - self._sigma_estimation = sigma_estimation + self._sigma_range = sigma_range + self._thresh_range = thresh_range self._thresh_estimation = threshold_estimation + self._thresh_scale = threshold_scaler weights_init = np.zeros(np.sum(np.prod(coeffs_shape, axis=-1))) @@ -232,25 +240,69 @@ def __init__(self, coeffs_shape, linear=Identity(), update_period=0, def _auto_thresh(self, input_data): """Compute the best weights for the input_data.""" + weights = np.ones(input_data.shape) + weights[:np.prod(self.cf_shape[0])] = 0 + # special case for bayes shrink + if self._thresh_estimation == "bayes": + sigma_noise = _sigma_mad(input_data[-np.prod(self.cf_shape[-1]):]) + start = np.prod(self.cf_shape[0]) + for i in range(1, len(self.cf_shape)): + stop = start + np.prod(self.cf_shape[i]) + band = input_data[start:stop] + sigma_y2 = np.mean(band ** 2) + denom = np.sqrt(np.max(sigma_y2 - sigma_noise, 0)) + if denom == 0: + thr = np.max(abs(band)) + else: + thr = sigma_noise ** 2 / denom + weights[start:stop] = thr + start = stop + return weights + # Estimate the noise std for each band. 
- sigma_bands = _wavelet_noise_estimate(input_data, self.cf_shape, self._sigma_estimation) - weights = np.zeros_like(input_data) + sigma_bands = _wavelet_noise_estimate(input_data, self.cf_shape, self._sigma_range) # compute the threshold for each subband start = np.prod(self.cf_shape[0]) stop = start ts = [] - for i in range(1, len(self.cf_shape)): - stop = start + np.prod(self.cf_shape[i]) - t = sigma_bands[i] * _thresh_select( - input_data[start:stop] / sigma_bands[i], + if self._thresh_range == "global": + weights =sigma_bands[-1] * _thresh_select( + input_data[-np.prod(self.cf_shape[-1]):] / sigma_bands[-1], self._thresh_estimation ) - ts.append(t) - weights[start:stop] = t - start = stop + elif self._thresh_range == "band": + for i in range(1, len(self.cf_shape)): + stop = start + np.prod(self.cf_shape[i]) + t = sigma_bands[i] * _thresh_select( + input_data[start:stop] / sigma_bands[i], + self._thresh_estimation + ) + ts.append(t) + weights[start:stop] = t + start = stop + elif self._thresh_range == "level": + start = np.prod(self.cf_shape[0]) + start_hh = start + for i, scale_shape in enumerate(np.unique(self.cf_shape[1:], axis=0)): + scale_sz = np.prod(scale_shape) + matched_bands = np.all(scale_shape == self.cf_shape[1:], axis=1) + band_per_level = np.sum(matched_bands) + start_hh = start + scale_sz * (band_per_level-1) + stop = start + scale_sz * band_per_level + t = sigma_bands[i+1] * _thresh_select( + input_data[start_hh:stop] / sigma_bands[i+1], + self._thresh_estimation + ) + ts.append(t) + weights[start:stop] = t + start = stop + if callable(self._thresh_scale): + weights = self._thresh_scale(weights, self._n_op_calls) + else: + weights *= self._thresh_scale return weights def _op_method(self, input_data, extra_factor=1.0): From 589cdab5366d954e5470e81f635a02653e13adfe Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 3 Feb 2023 10:42:18 +0100 Subject: [PATCH 09/20] someone left a bug here. --- mri/optimizers/utils/cost.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mri/optimizers/utils/cost.py b/mri/optimizers/utils/cost.py index 36a3d4e6..077a957e 100644 --- a/mri/optimizers/utils/cost.py +++ b/mri/optimizers/utils/cost.py @@ -152,7 +152,7 @@ def _calc_cost(self, x_new, *args, **kwargs): the cost function defined by the operators (gradient + prox_op). """ if self.optimizer_type == 'forward_backward': - if not hasattr(self.grad_op, 'linear_op') and self.linear_op is not None: + if not hasattr(self.gradient_op, 'linear_op') and self.linear_op is not None: y_new = self.linear_op.op(x_new) else: # synthesis case, y_new is already in the linear_op (sparse) domain. From 9450ec0ca09e8cddd819b67fe3735287fab83ecf Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 6 Feb 2023 16:56:03 +0100 Subject: [PATCH 10/20] refactor: externalize the computation of wavelet threshold. --- mri/operators/proximity/weighted.py | 184 +++++++++++++++++----------- 1 file changed, 112 insertions(+), 72 deletions(-) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index c39e401f..9d0e0560 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -123,15 +123,18 @@ def _thresh_select(data, thresh_est): thr = min(test_th, universal_thr) return thr -def _wavelet_noise_estimate(wavelet_coefs, coeffs_shape, sigma_est): +def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est): r"""Return an estimate of the noise variance in each band. 
Parameters ---------- - wavelet_bands: list - list of array + wavelet_coeffs: numpy.ndarray + flatten array of wavelet coefficient, typically returned by ``WaveletN.op`` + coeffs_shape: + list of tuple representing the shape of each subbands. + Typically accessible by WaveletN.coeffs_shape sigma_est: str - Estimation method, available are "band", "level", "level-shared", "global" + Estimation method, available are "band", "level", and "global" Returns ------- numpy.ndarray @@ -146,9 +149,9 @@ def _wavelet_noise_estimate(wavelet_coefs, coeffs_shape, sigma_est): The variance estimation can be done: - - On each band (eg LH, HL and HH band of each level) + - On each band - On each level, using the HH band. - - Only with the latest band (global) + - Only with the largest, most detailled HH band (global) For the selected data band(s) the variance is estimated using the MAD estimator: @@ -165,7 +168,7 @@ def _wavelet_noise_estimate(wavelet_coefs, coeffs_shape, sigma_est): if sigma_est == "band": for i in range(1, len(coeffs_shape)): stop += np.prod(coeffs_shape[i]) - sigma_ret[i] = _sigma_mad(wavelet_coefs[start:stop]) + sigma_ret[i] = _sigma_mad(wavelet_coeffs[start:stop]) start = stop if sigma_est == "level": # use the diagonal coefficient to estimate the variance of the level. @@ -174,16 +177,97 @@ def _wavelet_noise_estimate(wavelet_coefs, coeffs_shape, sigma_est): for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)): scale_sz = np.prod(scale_shape) matched_bands = np.all(scale_shape == coeffs_shape[1:], axis=1) - band_per_level = np.sum(matched_bands) - start = start + scale_sz * (band_per_level-1) - stop = start + scale_sz * band_per_level - sigma_ret[1+i*(band_per_level):1+(i+1)*band_per_level] = _sigma_mad(wavelet_coefs[start:stop]) + bpl = np.sum(matched_bands) + start = start + scale_sz * (bpl-1) + stop = start + scale_sz * bpl + sigma_ret[1+i*(bpl):1+(i+1)*bpl] = _sigma_mad(wavelet_coeffs[start:stop]) start = stop if sigma_est == "global": - sigma_ret *= _sigma_mad(wavelet_coefs[-np.prod(coeffs_shape[-1]):]) + sigma_ret *= _sigma_mad(wavelet_coeffs[-np.prod(coeffs_shape[-1]):]) sigma_ret[0] = np.NaN return sigma_ret + +def wavelet_threshold_estimate( + wavelet_coeffs, + coeffs_shape, + thresh_range="global", + sigma_range="global", + thresh_estimation="hybrid-sure" +): + """Estimate wavelet coefficients thresholds. + + Notes that no threshold will be estimate for the coarse scale. + Parameters + ---------- + wavelet_coeffs: numpy.ndarray + flatten array of wavelet coefficient, typically returned by ``WaveletN.op`` + coeffs_shape: list + List of tuple representing the shape of each subbands. + Typically accessible by WaveletN.coeffs_shape + thresh_range: str. default "global" + Defines on which data range to estimate thresholds. + Either "band", "level", or "global" + sigma_range: str, default "global" + Defines on which data range to estimate thresholds. + Either "band", "level", or "global" + thresh_estimation: str, default "hybrid-sure" + Name of the threshold estimation method. + Available are "sure", "hybrid-sure", "universal" + + Returns + ------- + numpy.ndarray + array of threshold for each wavelet coefficient. + """ + + weights = np.ones(wavelet_coeffs.shape) + weights[:np.prod(coeffs_shape[0])] = 0 + + # Estimate the noise std for each band. 
+ + sigma_bands = wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_range) + + # compute the threshold for each subband + + start = np.prod(coeffs_shape[0]) + stop = start + ts = [] + if thresh_range == "global": + weights =sigma_bands[-1] * _thresh_select( + wavelet_coeffs[-np.prod(coeffs_shape[-1]):] / sigma_bands[-1], + thresh_estimation + ) + elif thresh_range == "band": + for i in range(1, len(coeffs_shape)): + stop = start + np.prod(coeffs_shape[i]) + t = sigma_bands[i] * _thresh_select( + wavelet_coeffs[start:stop] / sigma_bands[i], + thresh_estimation + ) + ts.append(t) + weights[start:stop] = t + start = stop + elif thresh_range == "level": + start = np.prod(coeffs_shape[0]) + start_hh = start + for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)): + scale_sz = np.prod(scale_shape) + matched_bands = np.all(scale_shape == coeffs_shape[1:], axis=1) + band_per_level = np.sum(matched_bands) + start_hh = start + scale_sz * (band_per_level-1) + stop = start + scale_sz * band_per_level + t = sigma_bands[i+1] * _thresh_select( + wavelet_coeffs[start_hh:stop] / sigma_bands[i+1], + thresh_estimation + ) + ts.append(t) + weights[start:stop] = t + start = stop + return weights + + + class AutoWeightedSparseThreshold(SparseThreshold): """Automatic Weighting of Sparse coefficient. @@ -238,67 +322,23 @@ def __init__(self, coeffs_shape, linear=Identity(), update_period=0, **kwargs) def _auto_thresh(self, input_data): - """Compute the best weights for the input_data.""" - - weights = np.ones(input_data.shape) - weights[:np.prod(self.cf_shape[0])] = 0 - # special case for bayes shrink - if self._thresh_estimation == "bayes": - sigma_noise = _sigma_mad(input_data[-np.prod(self.cf_shape[-1]):]) - start = np.prod(self.cf_shape[0]) - for i in range(1, len(self.cf_shape)): - stop = start + np.prod(self.cf_shape[i]) - band = input_data[start:stop] - sigma_y2 = np.mean(band ** 2) - denom = np.sqrt(np.max(sigma_y2 - sigma_noise, 0)) - if denom == 0: - thr = np.max(abs(band)) - else: - thr = sigma_noise ** 2 / denom - weights[start:stop] = thr - start = stop - return weights - - # Estimate the noise std for each band. + """Compute the best weights for the input_data. - sigma_bands = _wavelet_noise_estimate(input_data, self.cf_shape, self._sigma_range) - - # compute the threshold for each subband - - start = np.prod(self.cf_shape[0]) - stop = start - ts = [] - if self._thresh_range == "global": - weights =sigma_bands[-1] * _thresh_select( - input_data[-np.prod(self.cf_shape[-1]):] / sigma_bands[-1], - self._thresh_estimation - ) - elif self._thresh_range == "band": - for i in range(1, len(self.cf_shape)): - stop = start + np.prod(self.cf_shape[i]) - t = sigma_bands[i] * _thresh_select( - input_data[start:stop] / sigma_bands[i], - self._thresh_estimation - ) - ts.append(t) - weights[start:stop] = t - start = stop - elif self._thresh_range == "level": - start = np.prod(self.cf_shape[0]) - start_hh = start - for i, scale_shape in enumerate(np.unique(self.cf_shape[1:], axis=0)): - scale_sz = np.prod(scale_shape) - matched_bands = np.all(scale_shape == self.cf_shape[1:], axis=1) - band_per_level = np.sum(matched_bands) - start_hh = start + scale_sz * (band_per_level-1) - stop = start + scale_sz * band_per_level - t = sigma_bands[i+1] * _thresh_select( - input_data[start_hh:stop] / sigma_bands[i+1], - self._thresh_estimation - ) - ts.append(t) - weights[start:stop] = t - start = stop + Parameters + ---------- + input_data: numpy.ndarray + Array of sparse coefficient. 
+ See Also + -------- + wavelet_threshold_estimate + """ + weights = wavelet_threshold_estimate( + input_data, + self.cf_shape, + thresh_range=self._thresh_range, + sigma_range=self._sigma_range, + thresh_estimation=self._thresh_estimation, + ) if callable(self._thresh_scale): weights = self._thresh_scale(weights, self._n_op_calls) else: From 13a7ab0f5a63ac07b9dfa7f22044d18eec579bcd Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 9 Feb 2023 11:24:33 +0100 Subject: [PATCH 11/20] improve docstrings. --- doc/refs.bib | 13 ++++ mri/operators/proximity/weighted.py | 104 ++++++++++++++++++++-------- 2 files changed, 87 insertions(+), 30 deletions(-) diff --git a/doc/refs.bib b/doc/refs.bib index da88c409..711c25e3 100644 --- a/doc/refs.bib +++ b/doc/refs.bib @@ -42,3 +42,16 @@ @article{Pruessmann1999 year={1999}, volume={42}, } +@inproceedings{Donoho1994, + address = {Baltimore, MD, USA}, + title = {Threshold selection for wavelet shrinkage of noisy data}, + ISBN = {978-0-7803-2050-5}, + url = {http://ieeexplore.ieee.org/document/412133/}, + DOI = {10.1109/IEMBS.1994.412133}, + booktitle = {Proceedings of 16th Annual International Conference of the + IEEE Engineering in Medicine and Biology Society}, + publisher = {IEEE}, + author = {Donoho, D.L. and Johnstone, I.M.}, + year = 1994, + pages = {A24–A25} +} diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index 9d0e0560..6fb4583a 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -80,16 +80,52 @@ def mu(self, w): self.weights = weights_init -def _sigma_mad(data): - """Return a robust estimation of the variance. +def _sigma_mad(data, centered=True): + """Return a robust estimation of the standard deviation. - It assums that is a sparse vector polluted by gaussian noise. + The standard deviation is computed using the following estimator, based on the + Median Absolute deviation of the data [#]_ + .. math:: + \hat{\sigma} = \frac{MAD}{\sqrt{2}\textrm{erf}^{-1}(1/2)} + + Parameters + ---------- + data: numpy.ndarray + the data on which the standard deviation will be estimated. + centered: bool, default True. + If true the median of the is assummed to be 0. + Returns + ------- + float: + The estimation of the standard deviation. + + References + ---------- + .. [#] https://en.m.wikipedia.org/wiki/Median_absolute_deviation """ - # return np.median(np.abs(data - np.median(data)))/0.6745 - return np.median(np.abs(data))/0.6745 + if centered: + return np.median(np.abs(data[:]))/0.6745 + return np.median(np.abs(data[:] - np.median(data[:])))/0.6745 def _sure_est(data): - """Return an estimation of the threshold computed using the SURE method.""" + """Return an estimation of the threshold computed using the SURE method. + + The computation of the estimator is based on the formulation of `cite:donoho1994` + and the efficient implementation of [#]_ + + Parameters + ---------- + data: numpy.array + Noisy Data with unit standard deviation. + Returns + ------- + float + Value of the threshold minimizing the SURE estimator. + + References + ---------- + .. [#] https://pyyawt.readthedocs.io/_modules/pyyawt/denoising.html#ValSUREThresh + """ dataf = data.flatten() n = dataf.size data_sorted = np.sort(np.abs(dataf))**2 @@ -103,9 +139,19 @@ def _sure_est(data): def _thresh_select(data, thresh_est): """ - Threshold selection for denoising. 
+ Threshold selection for denoising, implementing the methods proposed in `cite:donoho1994` - It assumes that data has a white noise of N(0,1) + Parameters + ---------- + data: numpy.ndarray + Noisy data on which a threshold will be estimated. It should only be corrupted by a + standard gaussian white noise N(0,1). + thresh_est: str + threshold estimation method. Available are "sure", "universal", "hybrid-sure". + Returns + ------- + float: + the threshold for the data provided. """ n = data.size universal_thr = np.sqrt(2*np.log(n)) @@ -124,40 +170,38 @@ def _thresh_select(data, thresh_est): return thr def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est): - r"""Return an estimate of the noise variance in each band. + r"""Return an estimate of the noise standard deviation in each subband. Parameters ---------- wavelet_coeffs: numpy.ndarray - flatten array of wavelet coefficient, typically returned by ``WaveletN.op`` + flatten array of wavelet coefficients, typically returned by ``WaveletN.op`` coeffs_shape: - list of tuple representing the shape of each subbands. + list of tuple representing the shape of each subband. Typically accessible by WaveletN.coeffs_shape sigma_est: str Estimation method, available are "band", "level", and "global" Returns ------- numpy.ndarray - Estimation of the variance for each wavelet bands. + Estimation of the variance for each wavelet subband. Notes ----- This methods makes several assumptions: - - The wavelet coefficient are ordered by scale, and the scale are ordered by size. + - The wavelet coefficients are ordered by scale, and the scales are ordered by size. - At each scale, the subbands should have the same shape. - The variance estimation can be done: - - - On each band - - On each level, using the HH band. - - Only with the largest, most detailled HH band (global) - - For the selected data band(s) the variance is estimated using the MAD estimator: + The variance estimation is either performed: - .. math:: - \hat{\sigma} = \textrm{median}(|x|) / 0.6745 + - On each subband (``sigma_est = "band"``) + - On each level, using the detailled HH subband. (``sigma_est = "level"``) + - Only with the largest, most detailled HH band (``sigma_est = "global"``) + See Also + -------- + _sigma_mad: function estimating the standard deviation. """ sigma_ret = np.ones(len(coeffs_shape)) sigma_ret[0] = np.NaN @@ -195,7 +239,7 @@ def wavelet_threshold_estimate( sigma_range="global", thresh_estimation="hybrid-sure" ): - """Estimate wavelet coefficients thresholds. + """Estimate wavelet coefficient thresholds. Notes that no threshold will be estimate for the coarse scale. Parameters @@ -224,11 +268,11 @@ def wavelet_threshold_estimate( weights = np.ones(wavelet_coeffs.shape) weights[:np.prod(coeffs_shape[0])] = 0 - # Estimate the noise std for each band. + # Estimate the noise std on the specific range. sigma_bands = wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_range) - # compute the threshold for each subband + # compute the threshold on each specific range. start = np.prod(coeffs_shape[0]) stop = start @@ -269,12 +313,12 @@ def wavelet_threshold_estimate( class AutoWeightedSparseThreshold(SparseThreshold): - """Automatic Weighting of Sparse coefficient. + """Automatic Weighting of sparse coefficients. This proximty automatically determines the threshold for Sparse (e.g. Wavelet based) coefficients. - The weight are computed on first call, and updated every ``update_period`` calls. 
+ The weight are computed on first call, and updated on every ``update_period`` call. Note that the coarse/approximation scale will not be thresholded. Parameters @@ -288,7 +332,7 @@ class AutoWeightedSparseThreshold(SparseThreshold): threshold_estimation: str threshold estimation method. Available are "sure", "hybrid-sure" and "universal" sigma_estimation: str - noise std estimation method. Available are "global", "level" and "level_shared" + noise std estimation method. Available are "global", "level" and "band" thresh_type: str "hard" or "soft" thresholding. """ @@ -304,9 +348,9 @@ def __init__(self, coeffs_shape, linear=Identity(), update_period=0, if thresh_range not in ["bands", "level", "global"]: - raise ValueError("Unsupported threshold range") + raise ValueError("Unsupported threshold range.") if sigma_range not in ["bands", "level", "global"]: - raise ValueError("Unsupported sigma estimation method") + raise ValueError("Unsupported sigma estimation method.") if threshold_estimation not in ["sure", "hybrid-sure", "universal", "bayes"]: raise ValueError("Unsupported threshold estimation method.") From 7c2e8fd48c29542c5941e660182936d196b14ee3 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 9 Feb 2023 11:24:50 +0100 Subject: [PATCH 12/20] use elif blocks. --- mri/operators/proximity/weighted.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index 6fb4583a..0381c42d 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -158,15 +158,20 @@ def _thresh_select(data, thresh_est): if thresh_est == "sure": thr = _sure_est(data) - if thresh_est == "universal": + elif thresh_est == "universal": thr = universal_thr - if thresh_est == "hybrid-sure": + elif thresh_est == "hybrid-sure": eta = np.sum(data ** 2) /n - 1 if eta < (np.log2(n) ** 1.5) / np.sqrt(n): thr = universal_thr else: test_th = _sure_est(data) thr = min(test_th, universal_thr) + else: + raise ValueError( + "Unsupported threshold method." + "Available are 'sure', 'universal' and 'hybrid-sure'" + ) return thr def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est): @@ -209,13 +214,13 @@ def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est): stop = 0 if sigma_est is None: return sigma_ret - if sigma_est == "band": + elif sigma_est == "band": for i in range(1, len(coeffs_shape)): stop += np.prod(coeffs_shape[i]) sigma_ret[i] = _sigma_mad(wavelet_coeffs[start:stop]) start = stop - if sigma_est == "level": - # use the diagonal coefficient to estimate the variance of the level. + elif sigma_est == "level": + # use the diagonal coefficients subband to estimate the variance of the level. # it assumes that the band of the same level have the same shape. 
start = np.prod(coeffs_shape[0]) for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)): @@ -226,7 +231,7 @@ def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est): stop = start + scale_sz * bpl sigma_ret[1+i*(bpl):1+(i+1)*bpl] = _sigma_mad(wavelet_coeffs[start:stop]) start = stop - if sigma_est == "global": + elif sigma_est == "global": sigma_ret *= _sigma_mad(wavelet_coeffs[-np.prod(coeffs_shape[-1]):]) sigma_ret[0] = np.NaN return sigma_ret From c135de8bcce02724bc26cc42ae226131eabada8b Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 9 Feb 2023 11:26:05 +0100 Subject: [PATCH 13/20] s/level/scale/g --- mri/operators/proximity/weighted.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index 0381c42d..7b69698e 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -185,7 +185,7 @@ def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est): list of tuple representing the shape of each subband. Typically accessible by WaveletN.coeffs_shape sigma_est: str - Estimation method, available are "band", "level", and "global" + Estimation method, available are "band", "scale", and "global" Returns ------- numpy.ndarray @@ -201,7 +201,7 @@ def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est): The variance estimation is either performed: - On each subband (``sigma_est = "band"``) - - On each level, using the detailled HH subband. (``sigma_est = "level"``) + - On each scale, using the detailled HH subband. (``sigma_est = "scale"``) - Only with the largest, most detailled HH band (``sigma_est = "global"``) See Also @@ -219,9 +219,9 @@ def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est): stop += np.prod(coeffs_shape[i]) sigma_ret[i] = _sigma_mad(wavelet_coeffs[start:stop]) start = stop - elif sigma_est == "level": - # use the diagonal coefficients subband to estimate the variance of the level. - # it assumes that the band of the same level have the same shape. + elif sigma_est == "scale": + # use the diagonal coefficients subband to estimate the variance of the scale. + # it assumes that the band of the same scale have the same shape. start = np.prod(coeffs_shape[0]) for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)): scale_sz = np.prod(scale_shape) @@ -256,10 +256,10 @@ def wavelet_threshold_estimate( Typically accessible by WaveletN.coeffs_shape thresh_range: str. default "global" Defines on which data range to estimate thresholds. - Either "band", "level", or "global" + Either "band", "scale", or "global" sigma_range: str, default "global" Defines on which data range to estimate thresholds. - Either "band", "level", or "global" + Either "band", "scale", or "global" thresh_estimation: str, default "hybrid-sure" Name of the threshold estimation method. 
Available are "sure", "hybrid-sure", "universal" @@ -297,15 +297,15 @@ def wavelet_threshold_estimate( ts.append(t) weights[start:stop] = t start = stop - elif thresh_range == "level": + elif thresh_range == "scale": start = np.prod(coeffs_shape[0]) start_hh = start for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)): scale_sz = np.prod(scale_shape) matched_bands = np.all(scale_shape == coeffs_shape[1:], axis=1) - band_per_level = np.sum(matched_bands) - start_hh = start + scale_sz * (band_per_level-1) - stop = start + scale_sz * band_per_level + band_per_scale = np.sum(matched_bands) + start_hh = start + scale_sz * (band_per_scale-1) + stop = start + scale_sz * band_per_scale t = sigma_bands[i+1] * _thresh_select( wavelet_coeffs[start_hh:stop] / sigma_bands[i+1], thresh_estimation @@ -337,7 +337,7 @@ class AutoWeightedSparseThreshold(SparseThreshold): threshold_estimation: str threshold estimation method. Available are "sure", "hybrid-sure" and "universal" sigma_estimation: str - noise std estimation method. Available are "global", "level" and "band" + noise std estimation method. Available are "global", "scale" and "band" thresh_type: str "hard" or "soft" thresholding. """ @@ -352,9 +352,9 @@ def __init__(self, coeffs_shape, linear=Identity(), update_period=0, self._update_period = update_period - if thresh_range not in ["bands", "level", "global"]: + if thresh_range not in ["bands", "scale", "global"]: raise ValueError("Unsupported threshold range.") - if sigma_range not in ["bands", "level", "global"]: + if sigma_range not in ["bands", "scale", "global"]: raise ValueError("Unsupported sigma estimation method.") if threshold_estimation not in ["sure", "hybrid-sure", "universal", "bayes"]: raise ValueError("Unsupported threshold estimation method.") From fb6d0ba3034d3b49943c8165cc70bbc698e22a35 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 9 Feb 2023 11:35:02 +0100 Subject: [PATCH 14/20] improve auto-threshold example. --- ...cartesian_reconstruction_auto_threshold.py | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/examples/cartesian_reconstruction_auto_threshold.py b/examples/cartesian_reconstruction_auto_threshold.py index b4a4d77b..fc836a19 100644 --- a/examples/cartesian_reconstruction_auto_threshold.py +++ b/examples/cartesian_reconstruction_auto_threshold.py @@ -1,18 +1,19 @@ """ -Neuroimaging cartesian reconstruction +Neuroimaging Cartesian reconstruction ===================================== Author: Pierre-Antoine Comby / Chaithya G R -In this tutorial we will reconstruct an MRI image from the sparse kspace +In this tutorial we will reconstruct an MR image from the sparse k-space measurements. -Moreover we will see the benefit of automatic tuning of the regularisation parameters. + +Moreover we will see the benefit of automating the tuning of the regularisation parameters. Import neuroimaging data ------------------------ We use the toy datasets available in pysap, more specifically a 2D brain slice -and the cartesian acquisition scheme. +and the Cartesian acquisition scheme. 
""" # Package import @@ -32,8 +33,8 @@ # Loading input data image = get_sample_data('2d-mri') -# Obtain K-Space Cartesian Mask -mask = get_sample_data("cartesian-mri-mask") +# Obtain k-space Cartesian Mask +mask = get_sample_data("Cartesian-mri-mask") # View Input # image.show() @@ -43,8 +44,8 @@ # Generate the kspace # ------------------- # -# From the 2D brain slice and the acquisition mask, we retrospectively -# undersample the k-space using a cartesian acquisition mask +# From the 2D brain slice and the sampling mask, we retrospectively +# undersample the k-space using a Cartesian acquisition mask # We then reconstruct the zero order solution as a baseline @@ -54,7 +55,7 @@ fourier_op = FFT(samples=kspace_loc, shape=image.shape) kspace_data = fourier_op.op(image) -# Zero order solution +# Zero filled solution image_rec0 = pysap.Image(data=fourier_op.adj_op(kspace_data), metadata=image.metadata) # image_rec0.show() @@ -67,19 +68,29 @@ # FISTA optimization # ------------------ # -# We now want to refine the zero order solution using a FISTA optimization. +# We now want to refine the zero order solution by computing the Compressed sensing one, +# using FISTA optimization. # The cost function is set to Proximity Cost + Gradient Cost # Setup the operators linear_op = WaveletN(wavelet_name="sym8", nb_scales=4) coeffs = linear_op.op(image_rec0) + +#%% +# the auto estimation of the threshold uses the methods of :cite:`donoho1994`. +# The noise standard deviation is estimated on the largest scale using the detail (HH) band. +# A single threshold is then also estimated for each scale. + regularizer_op = AutoWeightedSparseThreshold( coeffs.shape, linear=Identity(), update_period=5, - sigma_estimation="global", + sigma_range="global", + tresh_range="scale", threshold_estimation="sure", thresh_type="soft" ) +#%% The rest of the setup is similar to classical example. + # Setup Reconstructor reconstructor = SingleChannelReconstructor( fourier_op=fourier_op, @@ -88,7 +99,9 @@ gradient_formulation='synthesis', verbose=1, ) -# Start Reconstruction + +#%% +# With everythiing setup we can start Reconstruction x_final, costs, metrics = reconstructor.reconstruct( kspace_data=kspace_data, optimization_alg='fista', From 6b05ca846047db03964dd5775594b7056ba76e72 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 9 Feb 2023 11:35:18 +0100 Subject: [PATCH 15/20] update docstring parameters. --- mri/operators/proximity/weighted.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py index 7b69698e..95d2595c 100644 --- a/mri/operators/proximity/weighted.py +++ b/mri/operators/proximity/weighted.py @@ -336,8 +336,10 @@ class AutoWeightedSparseThreshold(SparseThreshold): Estimation of the weight update period. threshold_estimation: str threshold estimation method. Available are "sure", "hybrid-sure" and "universal" - sigma_estimation: str - noise std estimation method. Available are "global", "scale" and "band" + thresh_range: str + threshold range of estimation. Available are "global", "scale" and "band" + sigma_range: str + noise std range of estimation. Available are "global", "scale" and "band" thresh_type: str "hard" or "soft" thresholding. """ From 9cd6445072b9b120957ed431d91aecf0ce8d25f4 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 14 Feb 2023 14:46:06 +0100 Subject: [PATCH 16/20] add import in init modules. 
--- mri/operators/proximity/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mri/operators/proximity/__init__.py b/mri/operators/proximity/__init__.py index 129df1ce..51677bc7 100755 --- a/mri/operators/proximity/__init__.py +++ b/mri/operators/proximity/__init__.py @@ -6,3 +6,9 @@ # http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html # for details. ########################################################################## + +from .weighted import AutoWeightedSparseThreshold, WeightedSparseThreshold +from .ordered_weighted_l1_norm import OWL + + +__all__ = ['AutoWeightedSparseThreshold', 'WeightedSparseThreshold', 'OWL',] From ec28eaae8014b518626e6ff93d847a5297310c58 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Tue, 21 Mar 2023 14:26:51 +0100 Subject: [PATCH 17/20] Cartesian-> cartesian --- examples/cartesian_reconstruction_auto_threshold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cartesian_reconstruction_auto_threshold.py b/examples/cartesian_reconstruction_auto_threshold.py index fc836a19..4828afa8 100644 --- a/examples/cartesian_reconstruction_auto_threshold.py +++ b/examples/cartesian_reconstruction_auto_threshold.py @@ -34,7 +34,7 @@ image = get_sample_data('2d-mri') # Obtain k-space Cartesian Mask -mask = get_sample_data("Cartesian-mri-mask") +mask = get_sample_data("cartesian-mri-mask") # View Input # image.show() From db11616be95801f4262dae124508defda0ccf68c Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Tue, 21 Mar 2023 14:38:04 +0100 Subject: [PATCH 18/20] coeffs_shape --- examples/cartesian_reconstruction_auto_threshold.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/cartesian_reconstruction_auto_threshold.py b/examples/cartesian_reconstruction_auto_threshold.py index 4828afa8..d796cdd8 100644 --- a/examples/cartesian_reconstruction_auto_threshold.py +++ b/examples/cartesian_reconstruction_auto_threshold.py @@ -82,7 +82,8 @@ # A single threshold is then also estimated for each scale. regularizer_op = AutoWeightedSparseThreshold( - coeffs.shape, linear=Identity(), + coeffs_shape=coeffs.shape, + linear=Identity(), update_period=5, sigma_range="global", tresh_range="scale", From 55d0d02e038209c6a9d715bce3d8e6e706b33c86 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Tue, 21 Mar 2023 14:47:16 +0100 Subject: [PATCH 19/20] Fix typo --- examples/cartesian_reconstruction_auto_threshold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cartesian_reconstruction_auto_threshold.py b/examples/cartesian_reconstruction_auto_threshold.py index d796cdd8..febdb49f 100644 --- a/examples/cartesian_reconstruction_auto_threshold.py +++ b/examples/cartesian_reconstruction_auto_threshold.py @@ -86,7 +86,7 @@ linear=Identity(), update_period=5, sigma_range="global", - tresh_range="scale", + thresh_range="scale", threshold_estimation="sure", thresh_type="soft" ) From 53a0db7f211aad437a77371cf26e3fc27250ad8d Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 21 Mar 2023 17:55:48 +0100 Subject: [PATCH 20/20] feat: update the example for auto_threshold. 
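
The reworked example compares a hand-tuned `SparseThreshold` against the automatic estimator and reports SNR/SSIM for both reconstructions. To make the "universal" and "sure" options easier to follow, the sketch below walks through the textbook estimators they refer to (Donoho & Johnstone): a robust noise estimate from the median absolute deviation, the universal threshold sigma*sqrt(2*log(N)), and a threshold chosen by minimising SURE over a grid. The synthetic coefficients, the 0.6745 constant and the grid search are illustrative stand-ins, not the package's internal code:

    import numpy as np

    rng = np.random.default_rng(0)

    # Synthetic "detail" coefficients: a few large entries plus white Gaussian noise.
    clean = np.zeros(4096)
    clean[:50] = rng.uniform(1.0, 3.0, size=50)
    noisy = clean + 0.1 * rng.standard_normal(clean.size)

    # Robust noise estimate: for Gaussian noise, MAD ~= 0.6745 * sigma.
    sigma_est = np.median(np.abs(noisy - np.median(noisy))) / 0.6745

    # Universal threshold: sigma * sqrt(2 log N).
    n = noisy.size
    t_universal = sigma_est * np.sqrt(2 * np.log(n))

    # SURE risk of soft thresholding at level t, for unit-variance data x:
    # SURE(t) = N - 2 * #{|x_i| <= t} + sum_i min(x_i**2, t**2)
    x = noisy / sigma_est
    def sure_risk(t):
        return n - 2 * np.sum(np.abs(x) <= t) + np.sum(np.minimum(x ** 2, t ** 2))

    grid = np.linspace(0.0, np.sqrt(2 * np.log(n)), 200)
    t_sure = sigma_est * grid[np.argmin([sure_risk(t) for t in grid])]

    print(f"sigma={sigma_est:.3f}  universal={t_universal:.3f}  sure={t_sure:.3f}")

For weakly sparse coefficients the SURE threshold is typically smaller, hence less aggressive, than the universal one; the usual motivation for the "hybrid-sure" option is to fall back to the universal rule when the coefficients look too sparse for SURE to be reliable.
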
--- ...cartesian_reconstruction_auto_threshold.py | 254 +++++++++++++----- 1 file changed, 182 insertions(+), 72 deletions(-) diff --git a/examples/cartesian_reconstruction_auto_threshold.py b/examples/cartesian_reconstruction_auto_threshold.py index febdb49f..0962a2c8 100644 --- a/examples/cartesian_reconstruction_auto_threshold.py +++ b/examples/cartesian_reconstruction_auto_threshold.py @@ -1,97 +1,137 @@ -""" -Neuroimaging Cartesian reconstruction -===================================== - -Author: Pierre-Antoine Comby / Chaithya G R - -In this tutorial we will reconstruct an MR image from the sparse k-space -measurements. - -Moreover we will see the benefit of automating the tuning of the regularisation parameters. - -Import neuroimaging data ------------------------- - -We use the toy datasets available in pysap, more specifically a 2D brain slice -and the Cartesian acquisition scheme. -""" - -# Package import +#!/usr/bin/env python +# coding: utf-8 + +# +# Neuroimaging cartesian reconstruction +# ===================================== +# +# Author: Chaithya G R / Pierre-Antoine Comby +# +# In this tutorial we will reconstruct an MRI image from the sparse kspace +# measurements. +# +# Import neuroimaging data +# ------------------------ +# +# We use the toy datasets available in pysap, more specifically a 2D brain slice +# and the cartesian acquisition scheme. +# + +# In[1]: + + +import matplotlib.pyplot as plt +import numpy as np +from modopt.math.metrics import snr, ssim +from modopt.opt.linear import Identity +# Third party import +from modopt.opt.proximity import SparseThreshold from mri.operators import FFT, WaveletN +from mri.operators.proximity.weighted import AutoWeightedSparseThreshold from mri.operators.utils import convert_mask_to_locations from mri.reconstructors import SingleChannelReconstructor -from mri.operators.proximity.weighted import AutoWeightedSparseThreshold -import pysap from pysap.data import get_sample_data -# Third party import -from modopt.opt.proximity import SparseThreshold -from modopt.opt.linear import Identity -from modopt.math.metrics import ssim -import numpy as np - -# Loading input data image = get_sample_data('2d-mri') - -# Obtain k-space Cartesian Mask +print(image.data.min(), image.data.max()) +image = image.data +image /= np.max(image) mask = get_sample_data("cartesian-mri-mask") -# View Input -# image.show() -# mask.show() - -#%% -# Generate the kspace -# ------------------- -# -# From the 2D brain slice and the sampling mask, we retrospectively -# undersample the k-space using a Cartesian acquisition mask -# We then reconstruct the zero order solution as a baseline - # Get the locations of the kspace samples kspace_loc = convert_mask_to_locations(mask.data) # Generate the subsampled kspace -fourier_op = FFT(samples=kspace_loc, shape=image.shape) +fourier_op = FFT(mask=mask, shape=image.shape) kspace_data = fourier_op.op(image) -# Zero filled solution -image_rec0 = pysap.Image(data=fourier_op.adj_op(kspace_data), - metadata=image.metadata) -# image_rec0.show() +# Zero order solution +image_rec0 = np.abs(fourier_op.adj_op(kspace_data)) # Calculate SSIM base_ssim = ssim(image_rec0, image) print(base_ssim) #%% -# FISTA optimization +# POGM optimization # ------------------ -# -# We now want to refine the zero order solution by computing the Compressed sensing one, -# using FISTA optimization. +# We now want to refine the zero order solution using an accelerated Proximal Gradient +# Descent algorithm (FISTA or POGM). 
# The cost function is set to Proximity Cost + Gradient Cost +# In[4]: + + # Setup the operators -linear_op = WaveletN(wavelet_name="sym8", nb_scales=4) -coeffs = linear_op.op(image_rec0) +linear_op = WaveletN(wavelet_name="sym8", nb_scales=3) + +# Manual tweak of the regularisation parameter +regularizer_op = SparseThreshold(Identity(), 2e-3, thresh_type="soft") +# Setup Reconstructor +reconstructor = SingleChannelReconstructor( + fourier_op=fourier_op, + linear_op=linear_op, + regularizer_op=regularizer_op, + gradient_formulation='synthesis', + verbose=1, +) +# Start Reconstruction +x_final, costs, metrics = reconstructor.reconstruct( + kspace_data=kspace_data, + optimization_alg='pogm', + num_iterations=100, + cost_op_kwargs={"cost_interval":None}, + metric_call_period=1, + metrics = { + "snr":{ + "metric": snr, + "mapping": {"x_new":"test"}, + "cst_kwargs": {"ref": image}, + "early_stopping":False, + }, + "ssim":{ + "metric": ssim, + "mapping": {"x_new":"test"}, + "cst_kwargs": {"ref": image}, + "early_stopping": False, + } + } +) + +image_rec = np.abs(x_final) +# image_rec.show() +# Calculate SSIM +recon_ssim = ssim(image_rec, image) +recon_snr= snr(image_rec, image) + +print('The Reconstruction SSIM is : ' + str(recon_ssim)) +print('The Reconstruction SNR is : ' + str(recon_snr)) #%% -# the auto estimation of the threshold uses the methods of :cite:`donoho1994`. -# The noise standard deviation is estimated on the largest scale using the detail (HH) band. -# A single threshold is then also estimated for each scale. +# Threshold estimation using SURE +# ------------------------------- + +_w = None + +def static_weight(w, idx): + print(np.unique(w)) + return w + +# Setup the operators +linear_op = WaveletN(wavelet_name="sym8", nb_scale=3,padding_mode="periodization") +coeffs = linear_op.op(image_rec0) +print(linear_op.coeffs_shape) + +# Here we don't manually setup the regularisation weights, but use statistics on the wavelet details coefficients regularizer_op = AutoWeightedSparseThreshold( - coeffs_shape=coeffs.shape, - linear=Identity(), - update_period=5, + linear_op.coeffs_shape, linear=Identity(), + update_period=0, # the weight is updated only once. sigma_range="global", - thresh_range="scale", + thresh_range="global", threshold_estimation="sure", - thresh_type="soft" + thresh_type="soft", ) -#%% The rest of the setup is similar to classical example. 
- # Setup Reconstructor reconstructor = SingleChannelReconstructor( fourier_op=fourier_op, @@ -100,16 +140,86 @@ gradient_formulation='synthesis', verbose=1, ) - -#%% -# With everythiing setup we can start Reconstruction -x_final, costs, metrics = reconstructor.reconstruct( +# Start Reconstruction +x_final, costs, metrics2 = reconstructor.reconstruct( kspace_data=kspace_data, - optimization_alg='fista', - num_iterations=200, + optimization_alg='pogm', + num_iterations=100, + metric_call_period=1, + cost_op_kwargs={"cost_interval":None}, + metrics = { + "snr":{ + "metric": snr, + "mapping": {"x_new":"test"}, + "cst_kwargs": {"ref": image}, + "early_stopping":False, + }, + "ssim":{ + "metric": ssim, + "mapping": {"x_new":"test"}, + "cst_kwargs": {"ref": image}, + "early_stopping": False, + }, + "cost_grad":{ + "metric": lambda x: reconstructor.gradient_op.cost(linear_op.op(x)), + "mapping": {"x_new":"x"}, + "cst_kwargs": {}, + "early_stopping": False, + }, + "cost_prox":{ + "metric": lambda x: reconstructor.prox_op.cost(linear_op.op(x)), + "mapping": {"x_new":"x"}, + "cst_kwargs": {}, + "early_stopping": False, + } + } ) -image_rec = pysap.Image(data=np.abs(x_final)) +image_rec2 = np.abs(x_final) # image_rec.show() # Calculate SSIM -recon_ssim = ssim(image_rec, image) -print('The Reconstruction SSIM is : ' + str(recon_ssim)) +recon_ssim2 = ssim(image_rec2, image) +recon_snr2 = snr(image_rec2, image) + +print('The Reconstruction SSIM is : ' + str(recon_ssim2)) +print('The Reconstruction SNR is : ' + str(recon_snr2)) + +plt.subplot(121) +plt.plot(metrics["snr"]["time"], metrics["snr"]["values"], label="pogm classic") +plt.plot(metrics2["snr"]["time"], metrics2["snr"]["values"], label="pogm sure global") +plt.ylabel("snr") +plt.xlabel("time") +plt.legend() +plt.subplot(122) +plt.plot(metrics["ssim"]["time"], metrics["ssim"]["values"]) +plt.plot(metrics2["ssim"]["time"], metrics2["ssim"]["values"]) +plt.ylabel("ssim") +plt.xlabel("time") +plt.figure() +plt.subplot(121) +plt.plot(metrics["snr"]["index"], metrics["snr"]["values"]) +plt.plot(metrics2["snr"]["index"], metrics2["snr"]["values"]) +plt.ylabel("snr") +plt.subplot(122) +plt.plot(metrics["ssim"]["index"], metrics["ssim"]["values"]) +plt.plot(metrics2["ssim"]["index"], metrics2["ssim"]["values"]) + + +#%% +# Qualitative results +# ------------------- +# +def my_imshow(ax, img, title): + ax.imshow(img, cmap="gray") + ax.set_title(title) + ax.axis("off") + + + +fig, axs = plt.subplots(2,2) + +my_imshow(axs[0,0], image, "Ground Truth") +my_imshow(axs[0,1], abs(image_rec0), f"Zero Order \n SSIM={base_ssim:.4f}") +my_imshow(axs[1,0], abs(image_rec), f"Fista Classic \n SSIM={recon_ssim:.4f}") +my_imshow(axs[1,1], abs(image_rec2), f"Fista Sure \n SSIM={recon_ssim2:.4f}") + +fig.tight_layout()
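
The final example estimates a single global threshold (`thresh_range="global"`, `sigma_range="global"`). The class docstring earlier in this series also documents per-scale and per-band estimation, and switching to them only changes how the proximity operator is built. Below is a minimal sketch of the per-scale variant, reusing the names already defined in the example (`linear_op`, `fourier_op`, `Identity`) and keeping every other keyword identical; it illustrates the documented option and is not something exercised by these patches:

    regularizer_op_scale = AutoWeightedSparseThreshold(
        linear_op.coeffs_shape,
        linear=Identity(),
        update_period=0,
        sigma_range="global",
        thresh_range="scale",        # one threshold per decomposition scale
        threshold_estimation="sure",
        thresh_type="soft",
    )
    # The reconstructor is set up exactly as before, only the regularizer changes:
    # reconstructor = SingleChannelReconstructor(
    #     fourier_op=fourier_op, linear_op=linear_op,
    #     regularizer_op=regularizer_op_scale,
    #     gradient_formulation="synthesis", verbose=1,
    # )

When the example is run as a plain script rather than through sphinx-gallery, a final `plt.show()` may also be needed to display the metric curves and the comparison figure.
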