diff --git a/tests/test_coreg/test_affine.py b/tests/test_coreg/test_affine.py index 185398c1..8843f516 100644 --- a/tests/test_coreg/test_affine.py +++ b/tests/test_coreg/test_affine.py @@ -17,7 +17,7 @@ from scipy.ndimage import binary_dilation from xdem import coreg, examples -from xdem.coreg.affine import AffineCoreg, _reproject_horizontal_shift_samecrs +from xdem.coreg.affine import AffineCoreg, NuthKaab, _reproject_horizontal_shift_samecrs def load_examples(crop: bool = True) -> tuple[RasterType, RasterType, Vector]: @@ -444,3 +444,24 @@ def test_coreg_rigid__example( fit_shifts_rotations = tuple(np.concatenate((fit_shifts, fit_rotations))) assert fit_shifts_rotations == pytest.approx(expected_shifts_rots, abs=10e-6) + + def test_nuthkaab_no_vertical_shift(self) -> None: + ref, tba = load_examples(crop=False)[0:2] + + # Compare Nuth and Kaab method with and without applying vertical shift + coreg_method1 = NuthKaab(vertical_shift=True) + coreg_method2 = NuthKaab(vertical_shift=False) + + coreg_method1.fit(ref, tba, random_state=42) + coreg_method2.fit(ref, tba, random_state=42) + + # Recover the shifts computed by coregistration in matrix form + matrix1 = coreg_method1.to_matrix() + matrix2 = coreg_method2.to_matrix() + + # Assert vertical shift is 0 for the 2nd coreg method + assert matrix2[2, 3] == 0 + + # Assert horizontal shifts are the same + matrix2[2, 3] = matrix1[2, 3] + assert np.array_equal(matrix1, matrix2) diff --git a/xdem/coreg/affine.py b/xdem/coreg/affine.py index fd69eecb..4b50a5ec 100644 --- a/xdem/coreg/affine.py +++ b/xdem/coreg/affine.py @@ -1287,6 +1287,7 @@ def __init__( bin_sizes: int | dict[str, int | Iterable[float]] = 72, bin_statistic: Callable[[NDArrayf], np.floating[Any]] = np.nanmedian, subsample: int | float = 5e5, + vertical_shift: bool = True, ) -> None: """ Instantiate a new Nuth and Kääb (2011) coregistration object. @@ -1299,15 +1300,22 @@ def __init__( :param bin_sizes: Size (if integer) or edges (if iterable) for binning variables later passed in .fit(). :param bin_statistic: Statistic of central tendency (e.g., mean) to apply during the binning. :param subsample: Subsample the input for speed-up. <1 is parsed as a fraction. >1 is a pixel count. + :param vertical_shift: Whether to apply the vertical shift or not (default is True). """ + self.vertical_shift = vertical_shift + # Input checks _check_inputs_bin_before_fit( bin_before_fit=bin_before_fit, fit_optimizer=fit_optimizer, bin_sizes=bin_sizes, bin_statistic=bin_statistic ) - # Define iterative parameters - meta_input_iterative = {"max_iterations": max_iterations, "tolerance": offset_threshold} + # Define iterative parameters and vertical shift + meta_input_iterative = { + "max_iterations": max_iterations, + "tolerance": offset_threshold, + "apply_vshift": vertical_shift, + } # Define parameters exactly as in BiasCorr, but with only "fit" or "bin_and_fit" as option, so a bin_before_fit # boolean, no bin apply option, and fit_func is predefined @@ -1394,7 +1402,9 @@ def _fit_rst_pts( # Write output to class # (Mypy does not pass with normal dict, requires "OutAffineDict" here for some reason...) - output_affine = OutAffineDict(shift_x=-easting_offset, shift_y=-northing_offset, shift_z=vertical_offset) + output_affine = OutAffineDict( + shift_x=-easting_offset, shift_y=-northing_offset, shift_z=vertical_offset * self.vertical_shift + ) self._meta["outputs"]["affine"] = output_affine self._meta["outputs"]["random"] = {"subsample_final": subsample_final} diff --git a/xdem/coreg/base.py b/xdem/coreg/base.py index 8e502de3..79f3c890 100644 --- a/xdem/coreg/base.py +++ b/xdem/coreg/base.py @@ -116,6 +116,7 @@ "best_poly_order": "Best polynomial order", "best_nb_sin_freq": "Best number of sinusoid frequencies", "vshift_reduc_func": "Reduction function used to remove vertical shift", + "apply_vshift": "Vertical shift activated", "centroid": "Centroid found for affine rotation", "shift_x": "Eastward shift estimated (georeferenced unit)", "shift_y": "Northward shift estimated (georeferenced unit)", @@ -1578,6 +1579,8 @@ class InAffineDict(TypedDict, total=False): # Vertical shift reduction function for methods focusing on translation coregistration vshift_reduc_func: Callable[[NDArrayf], np.floating[Any]] + # Vertical shift activated + apply_vshift: bool class OutAffineDict(TypedDict, total=False):