diff --git a/doc/refs.bib b/doc/refs.bib
index da88c409..711c25e3 100644
--- a/doc/refs.bib
+++ b/doc/refs.bib
@@ -42,3 +42,16 @@ @article{Pruessmann1999
 year={1999},
 volume={42},
 }
+@inproceedings{Donoho1994,
+  address = {Baltimore, MD, USA},
+  title = {Threshold selection for wavelet shrinkage of noisy data},
+  ISBN = {978-0-7803-2050-5},
+  url = {http://ieeexplore.ieee.org/document/412133/},
+  DOI = {10.1109/IEMBS.1994.412133},
+  booktitle = {Proceedings of 16th Annual International Conference of the
+               IEEE Engineering in Medicine and Biology Society},
+  publisher = {IEEE},
+  author = {Donoho, D.L. and Johnstone, I.M.},
+  year = 1994,
+  pages = {A24--A25}
+}
diff --git a/examples/cartesian_reconstruction_auto_threshold.py b/examples/cartesian_reconstruction_auto_threshold.py
new file mode 100644
index 00000000..0962a2c8
--- /dev/null
+++ b/examples/cartesian_reconstruction_auto_threshold.py
@@ -0,0 +1,225 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+#
+# Neuroimaging Cartesian reconstruction
+# =====================================
+#
+# Author: Chaithya G R / Pierre-Antoine Comby
+#
+# In this tutorial we will reconstruct an MRI image from sparse k-space
+# measurements.
+#
+# Import neuroimaging data
+# ------------------------
+#
+# We use the toy datasets available in pysap, more specifically a 2D brain
+# slice and a Cartesian acquisition scheme.
+
+#%%
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+# Third party import
+from modopt.math.metrics import snr, ssim
+from modopt.opt.linear import Identity
+from modopt.opt.proximity import SparseThreshold
+from mri.operators import FFT, WaveletN
+from mri.operators.proximity.weighted import AutoWeightedSparseThreshold
+from mri.operators.utils import convert_mask_to_locations
+from mri.reconstructors import SingleChannelReconstructor
+from pysap.data import get_sample_data
+
+# Load a 2D brain slice and normalise it to [0, 1]
+image = get_sample_data('2d-mri')
+image = image.data
+image /= np.max(image)
+mask = get_sample_data("cartesian-mri-mask")
+
+# Get the locations of the k-space samples
+kspace_loc = convert_mask_to_locations(mask.data)
+# Generate the subsampled k-space data
+fourier_op = FFT(mask=mask, shape=image.shape)
+kspace_data = fourier_op.op(image)
+
+# Zero-order solution
+image_rec0 = np.abs(fourier_op.adj_op(kspace_data))
+
+# Calculate SSIM
+base_ssim = ssim(image_rec0, image)
+print('The Base SSIM is : ' + str(base_ssim))
+
+#%%
+# POGM optimization
+# -----------------
+# We now want to refine the zero-order solution using an accelerated proximal
+# gradient descent algorithm (FISTA or POGM).
+# The cost function is the sum of a data-fidelity (gradient) term and a
+# sparsity-promoting proximity term.
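+#
+# In the synthesis formulation used below, this reads (illustrative
+# notation, not part of the original tutorial text):
+#
+# .. math::
+#     \hat{\alpha} = \arg\min_{\alpha} \frac{1}{2}
+#     \| F \Psi^{*} \alpha - y \|_2^2 + \lambda \| \alpha \|_1
+#
+# where :math:`F` is the subsampled Fourier operator, :math:`\Psi^{*}` the
+# wavelet synthesis operator, :math:`y` the measured k-space data and
+# :math:`\lambda` the regularisation weight(s).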
+
+#%%
+
+# Setup the operators
+linear_op = WaveletN(wavelet_name="sym8", nb_scale=3)
+
+# Manual tweak of the regularisation parameter
+regularizer_op = SparseThreshold(Identity(), 2e-3, thresh_type="soft")
+# Setup Reconstructor
+reconstructor = SingleChannelReconstructor(
+    fourier_op=fourier_op,
+    linear_op=linear_op,
+    regularizer_op=regularizer_op,
+    gradient_formulation='synthesis',
+    verbose=1,
+)
+# Start Reconstruction
+x_final, costs, metrics = reconstructor.reconstruct(
+    kspace_data=kspace_data,
+    optimization_alg='pogm',
+    num_iterations=100,
+    cost_op_kwargs={"cost_interval": None},
+    metric_call_period=1,
+    metrics={
+        "snr": {
+            "metric": snr,
+            "mapping": {"x_new": "test"},
+            "cst_kwargs": {"ref": image},
+            "early_stopping": False,
+        },
+        "ssim": {
+            "metric": ssim,
+            "mapping": {"x_new": "test"},
+            "cst_kwargs": {"ref": image},
+            "early_stopping": False,
+        },
+    },
+)
+
+image_rec = np.abs(x_final)
+# Calculate SSIM and SNR
+recon_ssim = ssim(image_rec, image)
+recon_snr = snr(image_rec, image)
+
+print('The Reconstruction SSIM is : ' + str(recon_ssim))
+print('The Reconstruction SNR is : ' + str(recon_snr))
+
+#%%
+# Threshold estimation using SURE
+# -------------------------------
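+#
+# Instead of hand-tuning the threshold, we estimate it from the data with
+# SURE (Stein Unbiased Risk Estimate) (Donoho & Johnstone, 1994). As a
+# rough sketch (our notation), for unit-variance coefficients
+# :math:`a_1, \dots, a_n` the soft-thresholding risk estimate
+#
+# .. math::
+#     \mathrm{SURE}(t) = n - 2\,\#\{i : |a_i| \le t\}
+#     + \sum_{i=1}^{n} \min(|a_i|, t)^2
+#
+# is minimised over :math:`t`. The noise standard deviation is estimated
+# beforehand from the median absolute deviation (MAD) of the detail
+# coefficients.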
+
+#%%
+
+# Setup the operators.
+# The transform is run once on the zero-order solution so that
+# ``coeffs_shape`` is populated.
+linear_op = WaveletN(wavelet_name="sym8", nb_scale=3,
+                     padding_mode="periodization")
+coeffs = linear_op.op(image_rec0)
+print(linear_op.coeffs_shape)
+
+# Here we do not set the regularisation weights manually: they are estimated
+# from statistics of the wavelet detail coefficients.
+regularizer_op = AutoWeightedSparseThreshold(
+    linear_op.coeffs_shape, linear=Identity(),
+    update_period=0,  # estimate the weights once, at the first call
+    sigma_range="global",
+    thresh_range="global",
+    threshold_estimation="sure",
+    thresh_type="soft",
+)
+# Setup Reconstructor
+reconstructor = SingleChannelReconstructor(
+    fourier_op=fourier_op,
+    linear_op=linear_op,
+    regularizer_op=regularizer_op,
+    gradient_formulation='synthesis',
+    verbose=1,
+)
+# Start Reconstruction
+x_final, costs, metrics2 = reconstructor.reconstruct(
+    kspace_data=kspace_data,
+    optimization_alg='pogm',
+    num_iterations=100,
+    metric_call_period=1,
+    cost_op_kwargs={"cost_interval": None},
+    metrics={
+        "snr": {
+            "metric": snr,
+            "mapping": {"x_new": "test"},
+            "cst_kwargs": {"ref": image},
+            "early_stopping": False,
+        },
+        "ssim": {
+            "metric": ssim,
+            "mapping": {"x_new": "test"},
+            "cst_kwargs": {"ref": image},
+            "early_stopping": False,
+        },
+        "cost_grad": {
+            "metric": lambda x: reconstructor.gradient_op.cost(linear_op.op(x)),
+            "mapping": {"x_new": "x"},
+            "cst_kwargs": {},
+            "early_stopping": False,
+        },
+        "cost_prox": {
+            "metric": lambda x: reconstructor.prox_op.cost(linear_op.op(x)),
+            "mapping": {"x_new": "x"},
+            "cst_kwargs": {},
+            "early_stopping": False,
+        },
+    },
+)
+image_rec2 = np.abs(x_final)
+# Calculate SSIM and SNR
+recon_ssim2 = ssim(image_rec2, image)
+recon_snr2 = snr(image_rec2, image)
+
+print('The Reconstruction SSIM is : ' + str(recon_ssim2))
+print('The Reconstruction SNR is : ' + str(recon_snr2))
+
+#%%
+# Convergence behaviour
+# ---------------------
+# Compare the two reconstructions, first against wall-clock time, then
+# against iteration index.
+plt.subplot(121)
+plt.plot(metrics["snr"]["time"], metrics["snr"]["values"], label="pogm classic")
+plt.plot(metrics2["snr"]["time"], metrics2["snr"]["values"],
+         label="pogm sure global")
+plt.ylabel("snr")
+plt.xlabel("time")
+plt.legend()
+plt.subplot(122)
+plt.plot(metrics["ssim"]["time"], metrics["ssim"]["values"])
+plt.plot(metrics2["ssim"]["time"], metrics2["ssim"]["values"])
+plt.ylabel("ssim")
+plt.xlabel("time")
+
+plt.figure()
+plt.subplot(121)
+plt.plot(metrics["snr"]["index"], metrics["snr"]["values"])
+plt.plot(metrics2["snr"]["index"], metrics2["snr"]["values"])
+plt.ylabel("snr")
+plt.xlabel("iteration")
+plt.subplot(122)
+plt.plot(metrics["ssim"]["index"], metrics["ssim"]["values"])
+plt.plot(metrics2["ssim"]["index"], metrics2["ssim"]["values"])
+plt.ylabel("ssim")
+plt.xlabel("iteration")
+
+#%%
+# Qualitative results
+# -------------------
+
+def my_imshow(ax, img, title):
+    ax.imshow(img, cmap="gray")
+    ax.set_title(title)
+    ax.axis("off")
+
+
+fig, axs = plt.subplots(2, 2)
+my_imshow(axs[0, 0], image, "Ground Truth")
+my_imshow(axs[0, 1], abs(image_rec0), f"Zero Order \n SSIM={base_ssim:.4f}")
+my_imshow(axs[1, 0], abs(image_rec), f"POGM Classic \n SSIM={recon_ssim:.4f}")
+my_imshow(axs[1, 1], abs(image_rec2), f"POGM SURE \n SSIM={recon_ssim2:.4f}")
+
+fig.tight_layout()
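+
+# Display the figures when running as a plain script (assumes a GUI
+# matplotlib backend; sphinx-gallery captures the figures without this).
+plt.show()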
diff --git a/mri/operators/proximity/__init__.py b/mri/operators/proximity/__init__.py
index 129df1ce..51677bc7 100755
--- a/mri/operators/proximity/__init__.py
+++ b/mri/operators/proximity/__init__.py
@@ -6,3 +6,9 @@
 # http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
 # for details.
 ##########################################################################
+
+from .ordered_weighted_l1_norm import OWL
+from .weighted import AutoWeightedSparseThreshold, WeightedSparseThreshold
+
+
+__all__ = ['AutoWeightedSparseThreshold', 'WeightedSparseThreshold', 'OWL']
diff --git a/mri/operators/proximity/weighted.py b/mri/operators/proximity/weighted.py
index a84cc5d2..95d2595c 100644
--- a/mri/operators/proximity/weighted.py
+++ b/mri/operators/proximity/weighted.py
@@ -78,3 +78,347 @@ def mu(self, w):
         if self.zero_weight_coarse:
             weights_init[:np.prod(self.cf_shape[0])] = 0
         self.weights = weights_init
+
+
+def _sigma_mad(data, centered=True):
+    r"""Return a robust estimate of the standard deviation.
+
+    The standard deviation is estimated from the median absolute deviation
+    (MAD) of the data [#]_:
+
+    .. math::
+        \hat{\sigma} = \frac{\mathrm{MAD}}{\sqrt{2}\,\mathrm{erf}^{-1}(1/2)}
+        \approx \frac{\mathrm{MAD}}{0.6745}
+
+    Parameters
+    ----------
+    data: numpy.ndarray
+        The data on which the standard deviation is estimated.
+    centered: bool, default True
+        If True, the median of the data is assumed to be 0.
+
+    Returns
+    -------
+    float
+        The estimated standard deviation.
+
+    References
+    ----------
+    .. [#] https://en.wikipedia.org/wiki/Median_absolute_deviation
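+
+    Examples
+    --------
+    A quick sanity check on synthetic unit-variance Gaussian noise
+    (illustrative, not part of the original patch):
+
+    >>> rng = np.random.default_rng(0)
+    >>> print(round(_sigma_mad(rng.standard_normal(100000)), 1))
+    1.0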
+    """
+    if centered:
+        return np.median(np.abs(data[:])) / 0.6745
+    return np.median(np.abs(data[:] - np.median(data[:]))) / 0.6745
+
+
+def _sure_est(data):
+    """Return a threshold estimated with the SURE method.
+
+    The computation of the estimator is based on the formulation of
+    :cite:`Donoho1994` and the efficient implementation of [#]_.
+
+    Parameters
+    ----------
+    data: numpy.ndarray
+        Noisy data with unit standard deviation.
+
+    Returns
+    -------
+    float
+        Value of the threshold minimizing the SURE estimator.
+
+    References
+    ----------
+    .. [#] https://pyyawt.readthedocs.io/_modules/pyyawt/denoising.html#ValSUREThresh
+    """
+    dataf = data.flatten()
+    n = dataf.size
+    data_sorted = np.sort(np.abs(dataf)) ** 2
+    idx = np.arange(n - 1, -1, -1)
+    tmp = np.cumsum(data_sorted) + idx * data_sorted
+
+    # SURE risk when thresholding at each sorted coefficient magnitude.
+    risk = (n - (2 * np.arange(n)) + tmp) / n
+    ibest = np.argmin(risk)
+
+    return np.sqrt(data_sorted[ibest])
+
+
+def _thresh_select(data, thresh_est):
+    """Select a threshold for denoising, following :cite:`Donoho1994`.
+
+    Parameters
+    ----------
+    data: numpy.ndarray
+        Noisy data on which the threshold is estimated. It should only be
+        corrupted by standard Gaussian white noise N(0,1).
+    thresh_est: str
+        Threshold estimation method.
+        Available are "sure", "universal" and "hybrid-sure".
+
+    Returns
+    -------
+    float
+        The threshold for the provided data.
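+
+    Examples
+    --------
+    The universal threshold only depends on the number of samples
+    (illustrative, not part of the original patch):
+
+    >>> print(round(_thresh_select(np.ones(100), "universal"), 3))
+    3.035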
+    """
+    n = data.size
+    universal_thr = np.sqrt(2 * np.log(n))
+
+    if thresh_est == "sure":
+        thr = _sure_est(data)
+    elif thresh_est == "universal":
+        thr = universal_thr
+    elif thresh_est == "hybrid-sure":
+        # Fall back to the universal threshold when the data is too sparse
+        # for the SURE estimate to be reliable.
+        eta = np.sum(data ** 2) / n - 1
+        if eta < (np.log2(n) ** 1.5) / np.sqrt(n):
+            thr = universal_thr
+        else:
+            test_th = _sure_est(data)
+            thr = min(test_th, universal_thr)
+    else:
+        raise ValueError(
+            "Unsupported threshold method. "
+            "Available are 'sure', 'universal' and 'hybrid-sure'."
+        )
+    return thr
+
+
+def wavelet_noise_estimate(wavelet_coeffs, coeffs_shape, sigma_est):
+    r"""Estimate the noise standard deviation in each wavelet subband.
+
+    Parameters
+    ----------
+    wavelet_coeffs: numpy.ndarray
+        Flattened array of wavelet coefficients, typically returned by
+        ``WaveletN.op``.
+    coeffs_shape: list
+        List of tuples representing the shape of each subband.
+        Typically accessible as ``WaveletN.coeffs_shape``.
+    sigma_est: str
+        Estimation method; available are "band", "scale" and "global".
+
+    Returns
+    -------
+    numpy.ndarray
+        Estimated noise standard deviation for each wavelet subband.
+
+    Notes
+    -----
+    This method makes several assumptions:
+
+    - The wavelet coefficients are ordered by scale, and the scales are
+      ordered by size.
+    - At each scale, the subbands have the same shape.
+
+    The noise standard deviation is estimated:
+
+    - on each subband (``sigma_est="band"``),
+    - on each scale, using the diagonal (HH) subband of that scale
+      (``sigma_est="scale"``),
+    - only on the finest (largest) HH subband (``sigma_est="global"``).
+
+    See Also
+    --------
+    _sigma_mad: robust estimator of the standard deviation.
+    """
+    sigma_ret = np.ones(len(coeffs_shape))
+    # No estimate for the coarse/approximation subband.
+    sigma_ret[0] = np.NaN
+    if sigma_est is None:
+        return sigma_ret
+    if sigma_est == "band":
+        # The first detail subband starts after the coarse one.
+        start = np.prod(coeffs_shape[0])
+        stop = start
+        for i in range(1, len(coeffs_shape)):
+            stop += np.prod(coeffs_shape[i])
+            sigma_ret[i] = _sigma_mad(wavelet_coeffs[start:stop])
+            start = stop
+    elif sigma_est == "scale":
+        # Use the diagonal (HH) subband to estimate the noise std of each
+        # scale; subbands of the same scale are assumed to share a shape.
+        start = np.prod(coeffs_shape[0])
+        for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)):
+            scale_sz = np.prod(scale_shape)
+            matched_bands = np.all(scale_shape == coeffs_shape[1:], axis=1)
+            bpl = np.sum(matched_bands)
+            start_hh = start + scale_sz * (bpl - 1)
+            stop = start + scale_sz * bpl
+            sigma_ret[1 + i * bpl:1 + (i + 1) * bpl] = _sigma_mad(
+                wavelet_coeffs[start_hh:stop])
+            start = stop
+    elif sigma_est == "global":
+        # Use only the finest (largest) HH subband.
+        sigma_ret *= _sigma_mad(wavelet_coeffs[-np.prod(coeffs_shape[-1]):])
+    return sigma_ret
+
+
+def wavelet_threshold_estimate(
+        wavelet_coeffs,
+        coeffs_shape,
+        thresh_range="global",
+        sigma_range="global",
+        thresh_estimation="hybrid-sure"):
+    """Estimate thresholds for the wavelet coefficients.
+
+    Note that no threshold is estimated for the coarse scale; its weights
+    are set to zero.
+
+    Parameters
+    ----------
+    wavelet_coeffs: numpy.ndarray
+        Flattened array of wavelet coefficients, typically returned by
+        ``WaveletN.op``.
+    coeffs_shape: list
+        List of tuples representing the shape of each subband.
+        Typically accessible as ``WaveletN.coeffs_shape``.
+    thresh_range: str, default "global"
+        Data range over which the thresholds are estimated.
+        Either "band", "scale" or "global".
+    sigma_range: str, default "global"
+        Data range over which the noise standard deviation is estimated.
+        Either "band", "scale" or "global".
+    thresh_estimation: str, default "hybrid-sure"
+        Threshold estimation method.
+        Available are "sure", "hybrid-sure" and "universal".
+
+    Returns
+    -------
+    numpy.ndarray
+        Array of thresholds, one per wavelet coefficient.
+    """
+    weights = np.ones(wavelet_coeffs.shape)
+    # The coarse scale is never thresholded.
+    weights[:np.prod(coeffs_shape[0])] = 0
+
+    # Estimate the noise std on the requested range.
+    sigma_bands = wavelet_noise_estimate(wavelet_coeffs, coeffs_shape,
+                                         sigma_range)
+
+    # Compute the threshold on the requested range.
+    start = np.prod(coeffs_shape[0])
+    if thresh_range == "global":
+        # A single threshold, estimated on the finest HH subband, applied
+        # to every detail coefficient.
+        weights[start:] = sigma_bands[-1] * _thresh_select(
+            wavelet_coeffs[-np.prod(coeffs_shape[-1]):] / sigma_bands[-1],
+            thresh_estimation,
+        )
+    elif thresh_range == "band":
+        for i in range(1, len(coeffs_shape)):
+            stop = start + np.prod(coeffs_shape[i])
+            weights[start:stop] = sigma_bands[i] * _thresh_select(
+                wavelet_coeffs[start:stop] / sigma_bands[i],
+                thresh_estimation,
+            )
+            start = stop
+    elif thresh_range == "scale":
+        # Use the HH subband of each scale to estimate the threshold of the
+        # whole scale.
+        for i, scale_shape in enumerate(np.unique(coeffs_shape[1:], axis=0)):
+            scale_sz = np.prod(scale_shape)
+            matched_bands = np.all(scale_shape == coeffs_shape[1:], axis=1)
+            band_per_scale = np.sum(matched_bands)
+            start_hh = start + scale_sz * (band_per_scale - 1)
+            stop = start + scale_sz * band_per_scale
+            weights[start:stop] = sigma_bands[i + 1] * _thresh_select(
+                wavelet_coeffs[start_hh:stop] / sigma_bands[i + 1],
+                thresh_estimation,
+            )
+            start = stop
+    return weights
+
+
+class AutoWeightedSparseThreshold(SparseThreshold):
+    """Automatic weighting of sparse coefficients.
+
+    This proximity operator automatically determines the threshold for
+    sparse (e.g. wavelet-based) coefficients.
+
+    The weights are computed on the first call, and updated every
+    ``update_period`` calls. Note that the coarse/approximation scale is
+    never thresholded.
+
+    Parameters
+    ----------
+    coeffs_shape: list of tuple
+        List of shapes of the subbands.
+    linear: LinearOperator
+        Required for cost estimation.
+    update_period: int
+        Update period of the weight estimation. If 0, the weights are
+        estimated once, at the first call.
+    sigma_range: str
+        Noise std estimation range. Available are "global", "scale" and
+        "band".
+    thresh_range: str
+        Threshold estimation range. Available are "global", "scale" and
+        "band".
+    threshold_estimation: str
+        Threshold estimation method. Available are "sure", "hybrid-sure"
+        and "universal".
+    threshold_scaler: float or callable
+        Scaling of the estimated thresholds. If callable, it is called as
+        ``threshold_scaler(weights, n_op_calls)``.
+    thresh_type: str
+        "hard" or "soft" thresholding.
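+
+    Examples
+    --------
+    A sketch of typical usage (``linear_op`` is assumed to be an already
+    configured ``WaveletN`` whose ``op`` method has been called once, so
+    that ``coeffs_shape`` is populated)::
+
+        regularizer_op = AutoWeightedSparseThreshold(
+            linear_op.coeffs_shape,
+            linear=Identity(),
+            update_period=0,
+            threshold_estimation="sure",
+            thresh_type="soft",
+        )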
+    """
+
+    def __init__(self, coeffs_shape, linear=Identity(), update_period=0,
+                 sigma_range="global",
+                 thresh_range="global",
+                 threshold_estimation="sure",
+                 threshold_scaler=1.0,
+                 **kwargs):
+        self._n_op_calls = 0
+        self.cf_shape = coeffs_shape
+        self._update_period = update_period
+
+        if thresh_range not in ["band", "scale", "global"]:
+            raise ValueError("Unsupported threshold range.")
+        if sigma_range not in ["band", "scale", "global"]:
+            raise ValueError("Unsupported sigma estimation range.")
+        if threshold_estimation not in ["sure", "hybrid-sure", "universal"]:
+            raise ValueError("Unsupported threshold estimation method.")
+
+        self._sigma_range = sigma_range
+        self._thresh_range = thresh_range
+        self._thresh_estimation = threshold_estimation
+        self._thresh_scale = threshold_scaler
+
+        weights_init = np.zeros(np.sum(np.prod(coeffs_shape, axis=-1)))
+        super().__init__(weights=weights_init,
+                         linear=linear,
+                         **kwargs)
+
+    def _auto_thresh(self, input_data):
+        """Compute the best weights for the input_data.
+
+        Parameters
+        ----------
+        input_data: numpy.ndarray
+            Array of sparse coefficients.
+
+        See Also
+        --------
+        wavelet_threshold_estimate
+        """
+        weights = wavelet_threshold_estimate(
+            input_data,
+            self.cf_shape,
+            thresh_range=self._thresh_range,
+            sigma_range=self._sigma_range,
+            thresh_estimation=self._thresh_estimation,
+        )
+        if callable(self._thresh_scale):
+            weights = self._thresh_scale(weights, self._n_op_calls)
+        else:
+            weights *= self._thresh_scale
+        return weights
+
+    def _op_method(self, input_data, extra_factor=1.0):
+        """Operator.
+
+        This method returns the input data thresholded by the weights.
+        The weights are estimated from the data at the first call, and
+        re-estimated every ``update_period`` calls.
+
+        Parameters
+        ----------
+        input_data : numpy.ndarray
+            Input data array.
+        extra_factor : float
+            Additional multiplication factor (default is ``1.0``).
+
+        Returns
+        -------
+        numpy.ndarray
+            Thresholded data.
+        """
+        if self._update_period == 0 and self._n_op_calls == 0:
+            # One-shot estimation at the first call.
+            self.weights = self._auto_thresh(input_data)
+        elif self._update_period != 0 \
+                and self._n_op_calls % self._update_period == 0:
+            self.weights = self._auto_thresh(input_data)
+
+        self._n_op_calls += 1
+        return super()._op_method(input_data, extra_factor=extra_factor)
diff --git a/mri/optimizers/forward_backward.py b/mri/optimizers/forward_backward.py
index 9f9df45b..9b002201 100644
--- a/mri/optimizers/forward_backward.py
+++ b/mri/optimizers/forward_backward.py
@@ -69,7 +69,7 @@ def fista(gradient_op, linear_op, prox_op, cost_op, kspace_generator=None, estim
     if x_init is None:
         x_init = np.squeeze(np.zeros((gradient_op.linear_op.n_coils,
                                       *gradient_op.fourier_op.shape),
-                                     dtype=np.complex))
+                                     dtype=np.complex64))
     alpha_init = linear_op.op(x_init)

     # Welcome message
diff --git a/mri/optimizers/primal_dual.py b/mri/optimizers/primal_dual.py
index cd7ec2f9..d980a439 100644
--- a/mri/optimizers/primal_dual.py
+++ b/mri/optimizers/primal_dual.py
@@ -101,7 +101,7 @@ def condatvu(gradient_op, linear_op, dual_regularizer, cost_op, kspace_generator
     if x_init is None:
         x_init = np.squeeze(np.zeros((linear_op.n_coils,
                                       *gradient_op.fourier_op.shape),
-                                     dtype=np.complex))
+                                     dtype=np.complex64))
     primal = x_init
     dual = linear_op.op(primal)
     weights = dual
diff --git a/mri/optimizers/utils/cost.py b/mri/optimizers/utils/cost.py
index a69a8253..814458ac 100644
--- a/mri/optimizers/utils/cost.py
+++ b/mri/optimizers/utils/cost.py
@@ -152,7 +152,7 @@ def _calc_cost(self, x_new, *args, **kwargs):
         the cost function defined by the operators (gradient + prox_op).
         """
         if self.optimizer_type == 'forward_backward':
-            if not hasattr(self.grad_op, 'linear_op') and self.linear_op is not None:
+            if not hasattr(self.gradient_op, 'linear_op') and self.linear_op is not None:
                 y_new = self.linear_op.op(x_new)
             else:
                 # synthesis case, y_new is already in the linear_op (sparse) domain.