From 0cd22cda11c44426af5f3d54812654fd418ce4af Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alice=20de=20Bardonn=C3=A8che-Richard?=
Date: Tue, 4 Mar 2025 15:19:08 +0100
Subject: [PATCH 1/2] feat: move blockwise coregistration

---
 doc/source/api.md                         |   2 +-
 doc/source/biascorr.md                    |   2 +-
 doc/source/coregistration.md              |   8 +-
 examples/advanced/plot_blockwise_coreg.py |   6 +-
 tests/test_coreg/test_base.py             | 274 +--------
 tests/test_coreg/test_blockwise.py        | 307 ++++++++++
 xdem/coreg/__init__.py                    |   9 +-
 xdem/coreg/base.py                        | 624 +-------------------
 xdem/coreg/blockwise.py                   | 663 ++++++++++++++++++++++
 9 files changed, 984 insertions(+), 911 deletions(-)
 create mode 100644 tests/test_coreg/test_blockwise.py
 create mode 100644 xdem/coreg/blockwise.py

diff --git a/doc/source/api.md b/doc/source/api.md
index 085378ba..25c4a6fa 100644
--- a/doc/source/api.md
+++ b/doc/source/api.md
@@ -207,7 +207,7 @@ To build and pass your coregistration pipeline to {func}`~xdem.DEM.coregister_3d
 
     coreg.Coreg
     coreg.CoregPipeline
-    coreg.BlockwiseCoreg
+    coreg.blockwise.BlockwiseCoreg
 ```
 
 #### Fitting and applying transforms
diff --git a/doc/source/biascorr.md b/doc/source/biascorr.md
index 6d6ff160..3c65753d 100644
--- a/doc/source/biascorr.md
+++ b/doc/source/biascorr.md
@@ -108,7 +108,7 @@ to {func}`~xdem.coreg.Coreg.fit` and {func}`~xdem.coreg.Coreg.apply`. See {ref}
 
 Each bias-correction method in xDEM inherits their interface from the {class}`~xdem.coreg.Coreg` class (see {ref}`coreg_object`).
 This implies that bias-correction methods can be combined in a {class}`~xdem.coreg.CoregPipeline` with any other methods, or
-applied in a block-wise manner through {class}`~xdem.coreg.BlockwiseCoreg`.
+applied in a block-wise manner through {class}`~xdem.coreg.blockwise.BlockwiseCoreg`.
 
 **Inheritance diagram of co-registration and bias corrections:**
diff --git a/doc/source/coregistration.md b/doc/source/coregistration.md
index 218fe7f6..b2346065 100644
--- a/doc/source/coregistration.md
+++ b/doc/source/coregistration.md
@@ -496,18 +496,18 @@ for {class}`xdem.coreg.DirectionalBias`, an input `angle` to define the angle at
 
 ## Dividing coregistration in blocks
 
-### The {class}`~xdem.coreg.BlockwiseCoreg` object
+### The {class}`~xdem.coreg.blockwise.BlockwiseCoreg` object
 
 ```{caution}
-The {class}`~xdem.coreg.BlockwiseCoreg` feature is still experimental: it might not support all coregistration
+The {class}`~xdem.coreg.blockwise.BlockwiseCoreg` feature is still experimental: it might not support all coregistration
 methods, and create edge artefacts.
 ```
 
 Sometimes, we want to split a coregistration across different spatial subsets of an elevation dataset, running that
-method independently in each subset. A {class}`~xdem.coreg.BlockwiseCoreg` can be constructed for this:
+method independently in each subset. A {class}`~xdem.coreg.blockwise.BlockwiseCoreg` can be constructed for this:
 
 ```{code-cell} ipython3
-blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=16)
+blockwise = xdem.coreg.blockwise.BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=16)
 ```
 
 The subdivision corresponds to an equal-length block division across the extent of the elevation dataset. It needs
diff --git a/examples/advanced/plot_blockwise_coreg.py b/examples/advanced/plot_blockwise_coreg.py
index eb4d4192..da036471 100644
--- a/examples/advanced/plot_blockwise_coreg.py
+++ b/examples/advanced/plot_blockwise_coreg.py
@@ -5,7 +5,7 @@
 Often, biases are spatially variable, and a "global" shift may not be enough to coregister a DEM properly.
In the :ref:`sphx_glr_basic_examples_plot_nuth_kaab.py` example, we saw that the method improved the alignment significantly, but there were still possibly nonlinear artefacts in the result. Clearly, nonlinear coregistration approaches are needed. -One solution is :class:`xdem.coreg.BlockwiseCoreg`, a helper to run any ``Coreg`` class over an arbitrarily small grid, and then "puppet warp" the DEM to fit the reference best. +One solution is :class:`xdem.coreg.blockwise.BlockwiseCoreg`, a helper to run any ``Coreg`` class over an arbitrarily small grid, and then "puppet warp" the DEM to fit the reference best. The ``BlockwiseCoreg`` class runs in five steps: @@ -56,7 +56,7 @@ # Horizontal and vertical shifts can be estimated using :class:`xdem.coreg.NuthKaab`. # Let's prepare a coregistration class that calculates 64 offsets, evenly spread over the DEM. -blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=64) +blockwise = xdem.coreg.blockwise.BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=64) # %% @@ -76,7 +76,7 @@ # %% # The estimated shifts can be visualized by applying the coregistration to a completely flat surface. # This shows the estimated shifts that would be applied in elevation; additional horizontal shifts will also be applied if the method supports it. -# The :func:`xdem.coreg.BlockwiseCoreg.stats` method can be used to annotate each block with its associated Z shift. +# The :func:`xdem.coreg.blockwise.BlockwiseCoreg.stats` method can be used to annotate each block with its associated Z shift. z_correction = blockwise.apply( np.zeros_like(dem_to_be_aligned.data), transform=dem_to_be_aligned.transform, crs=dem_to_be_aligned.crs diff --git a/tests/test_coreg/test_base.py b/tests/test_coreg/test_base.py index 7ed65f07..72a64ce5 100644 --- a/tests/test_coreg/test_base.py +++ b/tests/test_coreg/test_base.py @@ -19,9 +19,8 @@ from scipy.ndimage import binary_dilation import xdem -from xdem import coreg, examples, misc, spatialstats +from xdem import coreg, examples from xdem._typing import NDArrayf -from xdem.coreg import BlockwiseCoreg from xdem.coreg.base import Coreg, apply_matrix, dict_key_to_str @@ -847,143 +846,6 @@ def test_pipeline_consistency(self) -> None: assert np.allclose(nk_vshift.to_matrix(), vshift_nk.to_matrix(), atol=10e-1) -class TestBlockwiseCoreg: - ref, tba, outlines = load_examples() # Load example reference, to-be-aligned and mask. - inlier_mask = ~outlines.create_mask(ref) - - fit_params = dict( - reference_elev=ref.data, - to_be_aligned_elev=tba.data, - inlier_mask=inlier_mask, - transform=ref.transform, - crs=ref.crs, - ) - # Create some 3D coordinates with Z coordinates being 0 to try the apply functions. 
-    points_arr = np.array([[1, 2, 3, 4], [1, 2, 3, 4], [0, 0, 0, 0]], dtype="float64").T
-    points = gpd.GeoDataFrame(
-        geometry=gpd.points_from_xy(x=points_arr[:, 0], y=points_arr[:, 1], crs=ref.crs), data={"z": points_arr[:, 2]}
-    )
-
-    @pytest.mark.parametrize(
-        "pipeline", [coreg.VerticalShift(), coreg.VerticalShift() + coreg.NuthKaab()]
-    )  # type: ignore
-    @pytest.mark.parametrize("subdivision", [4, 10])  # type: ignore
-    def test_blockwise_coreg(self, pipeline: Coreg, subdivision: int) -> None:
-
-        blockwise = coreg.BlockwiseCoreg(step=pipeline, subdivision=subdivision)
-
-        # Results can not yet be extracted (since fit has not been called) and should raise an error
-        with pytest.raises(AssertionError, match="No coreg results exist.*"):
-            blockwise.to_points()
-
-        blockwise.fit(**self.fit_params)
-        points = blockwise.to_points()
-
-        # Validate that the number of points is equal to the amount of subdivisions.
-        assert points.shape[0] == subdivision
-
-        # Validate that the points do not represent only the same location.
-        assert np.sum(np.linalg.norm(points[:, :, 0] - points[:, :, 1], axis=1)) != 0.0
-
-        z_diff = points[:, 2, 1] - points[:, 2, 0]
-
-        # Validate that all values are different
-        assert np.unique(z_diff).size == z_diff.size, "Each coreg cell should have different results."
-
-        # Validate that the BlockwiseCoreg doesn't accept uninstantiated Coreg classes
-        with pytest.raises(ValueError, match="instantiated Coreg subclass"):
-            coreg.BlockwiseCoreg(step=coreg.VerticalShift, subdivision=1)  # type: ignore
-
-        # Metadata copying has been an issue. Validate that all chunks have unique ids
-        chunk_numbers = [m["i"] for m in blockwise.meta["step_meta"]]
-        assert np.unique(chunk_numbers).shape[0] == len(chunk_numbers)
-
-        transformed_dem = blockwise.apply(self.tba)
-
-        ddem_pre = (self.ref - self.tba)[~self.inlier_mask]
-        ddem_post = (self.ref - transformed_dem)[~self.inlier_mask]
-
-        # Check that the periglacial difference is lower after coregistration.
-        assert abs(np.ma.median(ddem_post)) < abs(np.ma.median(ddem_pre))
-
-        stats = blockwise.stats()
-
-        # Check that nans don't exist (if they do, something has gone very wrong)
-        assert np.all(np.isfinite(stats["nmad"]))
-        # Check that offsets were actually calculated.
-        assert np.sum(np.abs(np.linalg.norm(stats[["x_off", "y_off", "z_off"]], axis=0))) > 0
-
-    def test_blockwise_coreg_large_gaps(self) -> None:
-        """Test BlockwiseCoreg when large gaps are encountered, e.g. around the frame of a rotated DEM."""
-        reference_dem = self.ref.reproject(crs="EPSG:3413", res=self.ref.res, resampling="bilinear")
-        dem_to_be_aligned = self.tba.reproject(ref=reference_dem, resampling="bilinear")
-
-        blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), 64, warn_failures=False)
-
-        # This should not fail or trigger warnings as warn_failures is False
-        blockwise.fit(reference_dem, dem_to_be_aligned)
-
-        stats = blockwise.stats()
-
-        # We expect holes in the blockwise coregistration, but not in stats due to nan padding for failing chunks
-        assert stats.shape[0] == 64
-
-        # Copy the TBA DEM and set a square portion to nodata
-        tba = self.tba.copy()
-        mask = np.zeros(np.shape(tba.data), dtype=bool)
-        mask[450:500, 450:500] = True
-        tba.set_mask(mask=mask)
-
-        blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), 8, warn_failures=False)
-
-        # Align the DEM and apply blockwise to a zero-array (to get the z_shift)
-        aligned = blockwise.fit(self.ref, tba).apply(tba)
-        zshift, _ = blockwise.apply(np.zeros_like(tba.data), transform=tba.transform, crs=tba.crs)
-
-        # Validate that the zshift is not something crazy high and that no negative values exist in the data.
-        assert np.nanmax(np.abs(zshift)) < 50
-        assert np.count_nonzero(aligned.data.compressed() < -50) == 0
-
-        # Check that coregistration improved the alignment
-        ddem_post = (aligned - self.ref).data.compressed()
-        ddem_pre = (tba - self.ref).data.compressed()
-        assert abs(np.nanmedian(ddem_pre)) > abs(np.nanmedian(ddem_post))
-        # assert np.nanstd(ddem_pre) > np.nanstd(ddem_post)
-
-    def test_failed_chunks_return_nan(self) -> None:
-        blockwise = BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=4)
-        blockwise.fit(**self.fit_params)
-        # Missing chunk 1 to simulate failure
-        blockwise._meta["step_meta"] = [meta for meta in blockwise._meta["step_meta"] if meta.get("i") != 1]
-
-        result_df = blockwise.stats()
-
-        # Check that chunk 1 (index 1) has NaN values for the statistics
-        assert np.isnan(result_df.loc[1, "inlier_count"])
-        assert np.isnan(result_df.loc[1, "nmad"])
-        assert np.isnan(result_df.loc[1, "median"])
-        assert isinstance(result_df.loc[1, "center_x"], float)
-        assert isinstance(result_df.loc[1, "center_y"], float)
-        assert np.isnan(result_df.loc[1, "center_z"])
-        assert np.isnan(result_df.loc[1, "x_off"])
-        assert np.isnan(result_df.loc[1, "y_off"])
-        assert np.isnan(result_df.loc[1, "z_off"])
-
-    def test_successful_chunks_return_values(self) -> None:
-        blockwise = BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=2)
-        blockwise.fit(**self.fit_params)
-        result_df = blockwise.stats()
-
-        # Check that the correct statistics are returned for successful chunks
-        assert result_df.loc[0, "inlier_count"] == blockwise._meta["step_meta"][0]["inlier_count"]
-        assert result_df.loc[0, "nmad"] == blockwise._meta["step_meta"][0]["nmad"]
-        assert result_df.loc[0, "median"] == blockwise._meta["step_meta"][0]["median"]
-
-        assert result_df.loc[1, "inlier_count"] == blockwise._meta["step_meta"][1]["inlier_count"]
-        assert result_df.loc[1, "nmad"] == blockwise._meta["step_meta"][1]["nmad"]
-        assert result_df.loc[1, "median"] == blockwise._meta["step_meta"][1]["median"]
-
-
 class TestAffineManipulation:
     ref, tba, outlines = load_examples()  # Load example reference, to-be-aligned and mask.
@@ -1147,137 +1009,3 @@ def test_apply_matrix__raster_realdata(self) -> None: diff_it_gd = z_points_gd[valids] - z_points_it[valids] assert np.percentile(np.abs(diff_it_gd), 99) < 1 # 99% of values are within a meter (instead of 90%) assert np.percentile(np.abs(diff_it_gd), 50) < 0.02 # 10 times more precise than above - - -def test_warp_dem() -> None: - """Test that the warp_dem function works expectedly.""" - - small_dem = np.zeros((5, 10), dtype="float32") - small_transform = rio.transform.from_origin(0, 5, 1, 1) - - source_coords = np.array([[0, 0, 0], [0, 5, 0], [10, 0, 0], [10, 5, 0]]).astype(small_dem.dtype) - - dest_coords = source_coords.copy() - dest_coords[0, 0] = -1e-5 - - warped_dem = coreg.base.warp_dem( - dem=small_dem, - transform=small_transform, - source_coords=source_coords, - destination_coords=dest_coords, - resampling="linear", - trim_border=False, - ) - assert np.nansum(np.abs(warped_dem - small_dem)) < 1e-6 - - elev_shift = 5.0 - dest_coords[1, 2] = elev_shift - warped_dem = coreg.base.warp_dem( - dem=small_dem, - transform=small_transform, - source_coords=source_coords, - destination_coords=dest_coords, - resampling="linear", - ) - - # The warped DEM should have the value 'elev_shift' in the upper left corner. - assert warped_dem[0, 0] == -elev_shift - # The corner should be zero, so the corner pixel (represents the corner minus resolution / 2) should be close. - # We select the pixel before the corner (-2 in X-axis) to avoid the NaN propagation on the bottom row. - assert warped_dem[-2, -1] < 1 - - # Synthesise some X/Y/Z coordinates on the DEM. - source_coords = np.array( - [ - [0, 0, 200], - [480, 20, 200], - [460, 480, 200], - [10, 460, 200], - [250, 250, 200], - ] - ) - - # Copy the source coordinates and apply some shifts - dest_coords = source_coords.copy() - # Apply in the X direction - dest_coords[0, 0] += 20 - dest_coords[1, 0] += 7 - dest_coords[2, 0] += 10 - dest_coords[3, 0] += 5 - - # Apply in the Y direction - dest_coords[4, 1] += 5 - - # Apply in the Z direction - dest_coords[3, 2] += 5 - test_shift = 6 # This shift will be validated below - dest_coords[4, 2] += test_shift - - # Generate a semi-random DEM - transform = rio.transform.from_origin(0, 500, 1, 1) - shape = (500, 550) - dem = misc.generate_random_field(shape, 100) * 200 + misc.generate_random_field(shape, 10) * 50 - - # Warp the DEM using the source-destination coordinates. - transformed_dem = coreg.base.warp_dem( - dem=dem, transform=transform, source_coords=source_coords, destination_coords=dest_coords, resampling="linear" - ) - - # Try to undo the warp by reversing the source-destination coordinates. - untransformed_dem = coreg.base.warp_dem( - dem=transformed_dem, - transform=transform, - source_coords=dest_coords, - destination_coords=source_coords, - resampling="linear", - ) - # Validate that the DEM is now more or less the same as the original. - # Due to the randomness, the threshold is quite high, but would be something like 10+ if it was incorrect. 
- assert spatialstats.nmad(dem - untransformed_dem) < 0.5 - - # Test with Z-correction disabled - transformed_dem_no_z = coreg.base.warp_dem( - dem=dem, - transform=transform, - source_coords=source_coords, - destination_coords=dest_coords, - resampling="linear", - apply_z_correction=False, - ) - - # Try to undo the warp by reversing the source-destination coordinates with Z-correction disabled - untransformed_dem_no_z = coreg.base.warp_dem( - dem=transformed_dem_no_z, - transform=transform, - source_coords=dest_coords, - destination_coords=source_coords, - resampling="linear", - apply_z_correction=False, - ) - - # Validate that the DEM is now more or less the same as the original, with Z-correction disabled. - # The result should be similar to the original, but with no Z-shift applied. - assert spatialstats.nmad(dem - untransformed_dem_no_z) < 0.5 - - # The difference between the two DEMs should be the vertical shift. - # We expect the difference to be approximately equal to the average vertical shift. - expected_vshift = np.mean(dest_coords[:, 2] - source_coords[:, 2]) - - # Check that the mean difference between the DEMs matches the expected vertical shift. - assert np.nanmean(transformed_dem_no_z - transformed_dem) == pytest.approx(expected_vshift, rel=0.3) - - if False: - import matplotlib.pyplot as plt - - plt.figure(dpi=200) - plt.subplot(141) - - plt.imshow(dem, vmin=0, vmax=300) - plt.subplot(142) - plt.imshow(transformed_dem, vmin=0, vmax=300) - plt.subplot(143) - plt.imshow(untransformed_dem, vmin=0, vmax=300) - - plt.subplot(144) - plt.imshow(dem - untransformed_dem, cmap="coolwarm_r", vmin=-10, vmax=10) - plt.show() diff --git a/tests/test_coreg/test_blockwise.py b/tests/test_coreg/test_blockwise.py new file mode 100644 index 00000000..a8ad94f0 --- /dev/null +++ b/tests/test_coreg/test_blockwise.py @@ -0,0 +1,307 @@ +"""Functions to test the coregistration blockwise classes.""" + +from __future__ import annotations + +import geopandas as gpd +import numpy as np +import pytest +import rasterio as rio +from geoutils import Raster, Vector +from geoutils.raster import RasterType + +import xdem +from xdem import coreg, examples, misc, spatialstats +from xdem.coreg.base import Coreg +from xdem.coreg.blockwise import BlockwiseCoreg + + +def load_examples() -> tuple[RasterType, RasterType, Vector]: + """Load example files to try coregistration methods with.""" + + reference_dem = Raster(examples.get_path("longyearbyen_ref_dem")) + to_be_aligned_dem = Raster(examples.get_path("longyearbyen_tba_dem")) + glacier_mask = Vector(examples.get_path("longyearbyen_glacier_outlines")) + + # Crop to smaller extents for test speed + res = reference_dem.res + crop_geom = ( + reference_dem.bounds.left, + reference_dem.bounds.bottom, + reference_dem.bounds.left + res[0] * 300, + reference_dem.bounds.bottom + res[1] * 300, + ) + reference_dem = reference_dem.crop(crop_geom) + to_be_aligned_dem = to_be_aligned_dem.crop(crop_geom) + + return reference_dem, to_be_aligned_dem, glacier_mask + + +class TestBlockwiseCoreg: + ref, tba, outlines = load_examples() # Load example reference, to-be-aligned and mask. + inlier_mask = ~outlines.create_mask(ref) + + fit_params = dict( + reference_elev=ref.data, + to_be_aligned_elev=tba.data, + inlier_mask=inlier_mask, + transform=ref.transform, + crs=ref.crs, + ) + # Create some 3D coordinates with Z coordinates being 0 to try the apply functions. 
+    points_arr = np.array([[1, 2, 3, 4], [1, 2, 3, 4], [0, 0, 0, 0]], dtype="float64").T
+    points = gpd.GeoDataFrame(
+        geometry=gpd.points_from_xy(x=points_arr[:, 0], y=points_arr[:, 1], crs=ref.crs), data={"z": points_arr[:, 2]}
+    )
+
+    @pytest.mark.parametrize(
+        "pipeline", [coreg.VerticalShift(), coreg.VerticalShift() + coreg.NuthKaab()]
+    )  # type: ignore
+    @pytest.mark.parametrize("subdivision", [4, 10])  # type: ignore
+    def test_blockwise_coreg(self, pipeline: Coreg, subdivision: int) -> None:
+
+        blockwise = coreg.BlockwiseCoreg(step=pipeline, subdivision=subdivision)
+
+        # Results cannot yet be extracted (since fit has not been called) and should raise an error
+        with pytest.raises(AssertionError, match="No coreg results exist.*"):
+            blockwise.to_points()
+
+        blockwise.fit(**self.fit_params)
+        points = blockwise.to_points()
+
+        # Validate that the number of points is equal to the number of subdivisions.
+        assert points.shape[0] == subdivision
+
+        # Validate that the points do not represent only the same location.
+        assert np.sum(np.linalg.norm(points[:, :, 0] - points[:, :, 1], axis=1)) != 0.0
+
+        z_diff = points[:, 2, 1] - points[:, 2, 0]
+
+        # Validate that all values are different
+        assert np.unique(z_diff).size == z_diff.size, "Each coreg cell should have different results."
+
+        # Validate that the BlockwiseCoreg doesn't accept uninstantiated Coreg classes
+        with pytest.raises(ValueError, match="instantiated Coreg subclass"):
+            coreg.BlockwiseCoreg(step=coreg.VerticalShift, subdivision=1)  # type: ignore
+
+        # Metadata copying has been an issue. Validate that all chunks have unique ids
+        chunk_numbers = [m["i"] for m in blockwise.meta["step_meta"]]
+        assert np.unique(chunk_numbers).shape[0] == len(chunk_numbers)
+
+        transformed_dem = blockwise.apply(self.tba)
+
+        ddem_pre = (self.ref - self.tba)[~self.inlier_mask]
+        ddem_post = (self.ref - transformed_dem)[~self.inlier_mask]
+
+        # Check that the periglacial difference is lower after coregistration.
+        assert abs(np.ma.median(ddem_post)) < abs(np.ma.median(ddem_pre))
+
+        stats = blockwise.stats()
+
+        # Check that nans don't exist (if they do, something has gone very wrong)
+        assert np.all(np.isfinite(stats["nmad"]))
+        # Check that offsets were actually calculated.
+        assert np.sum(np.abs(np.linalg.norm(stats[["x_off", "y_off", "z_off"]], axis=0))) > 0
+
+    def test_blockwise_coreg_large_gaps(self) -> None:
+        """Test BlockwiseCoreg when large gaps are encountered, e.g. around the frame of a rotated DEM."""
+        reference_dem = self.ref.reproject(crs="EPSG:3413", res=self.ref.res, resampling="bilinear")
+        dem_to_be_aligned = self.tba.reproject(ref=reference_dem, resampling="bilinear")
+
+        blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), 64, warn_failures=False)
+
+        # This should not fail or trigger warnings as warn_failures is False
+        blockwise.fit(reference_dem, dem_to_be_aligned)
+
+        stats = blockwise.stats()
+
+        # We expect holes in the blockwise coregistration, but not in stats due to nan padding for failing chunks
+        assert stats.shape[0] == 64
+
+        # Copy the TBA DEM and set a square portion to nodata
+        tba = self.tba.copy()
+        mask = np.zeros(np.shape(tba.data), dtype=bool)
+        mask[450:500, 450:500] = True
+        tba.set_mask(mask=mask)
+
+        blockwise = xdem.coreg.BlockwiseCoreg(xdem.coreg.NuthKaab(), 8, warn_failures=False)
+
+        # Align the DEM and apply blockwise to a zero-array (to get the z_shift)
+        aligned = blockwise.fit(self.ref, tba).apply(tba)
+        zshift, _ = blockwise.apply(np.zeros_like(tba.data), transform=tba.transform, crs=tba.crs)
+
+        # Validate that the zshift is not something crazy high and that no negative values exist in the data.
+        assert np.nanmax(np.abs(zshift)) < 50
+        assert np.count_nonzero(aligned.data.compressed() < -50) == 0
+
+        # Check that coregistration improved the alignment
+        ddem_post = (aligned - self.ref).data.compressed()
+        ddem_pre = (tba - self.ref).data.compressed()
+        assert abs(np.nanmedian(ddem_pre)) > abs(np.nanmedian(ddem_post))
+        # assert np.nanstd(ddem_pre) > np.nanstd(ddem_post)
+
+    def test_failed_chunks_return_nan(self) -> None:
+        blockwise = BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=4)
+        blockwise.fit(**self.fit_params)
+        # Missing chunk 1 to simulate failure
+        blockwise._meta["step_meta"] = [meta for meta in blockwise._meta["step_meta"] if meta.get("i") != 1]
+
+        result_df = blockwise.stats()
+
+        # Check that chunk 1 (index 1) has NaN values for the statistics
+        assert np.isnan(result_df.loc[1, "inlier_count"])
+        assert np.isnan(result_df.loc[1, "nmad"])
+        assert np.isnan(result_df.loc[1, "median"])
+        assert isinstance(result_df.loc[1, "center_x"], float)
+        assert isinstance(result_df.loc[1, "center_y"], float)
+        assert np.isnan(result_df.loc[1, "center_z"])
+        assert np.isnan(result_df.loc[1, "x_off"])
+        assert np.isnan(result_df.loc[1, "y_off"])
+        assert np.isnan(result_df.loc[1, "z_off"])
+
+    def test_successful_chunks_return_values(self) -> None:
+        blockwise = BlockwiseCoreg(xdem.coreg.NuthKaab(), subdivision=2)
+        blockwise.fit(**self.fit_params)
+        result_df = blockwise.stats()
+
+        # Check that the correct statistics are returned for successful chunks
+        assert result_df.loc[0, "inlier_count"] == blockwise._meta["step_meta"][0]["inlier_count"]
+        assert result_df.loc[0, "nmad"] == blockwise._meta["step_meta"][0]["nmad"]
+        assert result_df.loc[0, "median"] == blockwise._meta["step_meta"][0]["median"]
+
+        assert result_df.loc[1, "inlier_count"] == blockwise._meta["step_meta"][1]["inlier_count"]
+        assert result_df.loc[1, "nmad"] == blockwise._meta["step_meta"][1]["nmad"]
+        assert result_df.loc[1, "median"] == blockwise._meta["step_meta"][1]["median"]
+
+
+def test_warp_dem() -> None:
+    """Test that the warp_dem function works as expected."""
+
+    small_dem = np.zeros((5, 10), dtype="float32")
+    small_transform = rio.transform.from_origin(0, 5, 1, 1)
+
+    source_coords = np.array([[0, 0, 0], [0, 5, 0], [10, 0, 0], [10, 5, 0]]).astype(small_dem.dtype)
+
+    dest_coords = source_coords.copy()
+    dest_coords[0, 0] = -1e-5
+
+    warped_dem = coreg.blockwise.warp_dem(
+        dem=small_dem,
+        transform=small_transform,
+        source_coords=source_coords,
+        destination_coords=dest_coords,
+        resampling="linear",
+        trim_border=False,
+    )
+    assert np.nansum(np.abs(warped_dem - small_dem)) < 1e-6
+
+    elev_shift = 5.0
+    dest_coords[1, 2] = elev_shift
+    warped_dem = coreg.blockwise.warp_dem(
+        dem=small_dem,
+        transform=small_transform,
+        source_coords=source_coords,
+        destination_coords=dest_coords,
+        resampling="linear",
+    )
+
+    # The warped DEM should have the value 'elev_shift' in the upper left corner.
+    assert warped_dem[0, 0] == -elev_shift
+    # The corner should be zero, so the corner pixel (represents the corner minus resolution / 2) should be close.
+    # We select the pixel before the corner (-2 in X-axis) to avoid the NaN propagation on the bottom row.
+    assert warped_dem[-2, -1] < 1
+
+    # Synthesise some X/Y/Z coordinates on the DEM.
+    source_coords = np.array(
+        [
+            [0, 0, 200],
+            [480, 20, 200],
+            [460, 480, 200],
+            [10, 460, 200],
+            [250, 250, 200],
+        ]
+    )
+
+    # Copy the source coordinates and apply some shifts
+    dest_coords = source_coords.copy()
+    # Apply in the X direction
+    dest_coords[0, 0] += 20
+    dest_coords[1, 0] += 7
+    dest_coords[2, 0] += 10
+    dest_coords[3, 0] += 5
+
+    # Apply in the Y direction
+    dest_coords[4, 1] += 5
+
+    # Apply in the Z direction
+    dest_coords[3, 2] += 5
+    test_shift = 6  # This shift will be validated below
+    dest_coords[4, 2] += test_shift
+
+    # Generate a semi-random DEM
+    transform = rio.transform.from_origin(0, 500, 1, 1)
+    shape = (500, 550)
+    dem = misc.generate_random_field(shape, 100) * 200 + misc.generate_random_field(shape, 10) * 50
+
+    # Warp the DEM using the source-destination coordinates.
+    transformed_dem = coreg.blockwise.warp_dem(
+        dem=dem, transform=transform, source_coords=source_coords, destination_coords=dest_coords, resampling="linear"
+    )
+
+    # Try to undo the warp by reversing the source-destination coordinates.
+    untransformed_dem = coreg.blockwise.warp_dem(
+        dem=transformed_dem,
+        transform=transform,
+        source_coords=dest_coords,
+        destination_coords=source_coords,
+        resampling="linear",
+    )
+    # Validate that the DEM is now more or less the same as the original.
+    # Due to the randomness, the threshold is quite high, but would be something like 10+ if it was incorrect.
+    assert spatialstats.nmad(dem - untransformed_dem) < 0.5
+
+    # Test with Z-correction disabled
+    transformed_dem_no_z = coreg.blockwise.warp_dem(
+        dem=dem,
+        transform=transform,
+        source_coords=source_coords,
+        destination_coords=dest_coords,
+        resampling="linear",
+        apply_z_correction=False,
+    )
+
+    # Try to undo the warp by reversing the source-destination coordinates with Z-correction disabled
+    untransformed_dem_no_z = coreg.blockwise.warp_dem(
+        dem=transformed_dem_no_z,
+        transform=transform,
+        source_coords=dest_coords,
+        destination_coords=source_coords,
+        resampling="linear",
+        apply_z_correction=False,
+    )
+
+    # Validate that the DEM is now more or less the same as the original, with Z-correction disabled.
+    # The result should be similar to the original, but with no Z-shift applied.
+    assert spatialstats.nmad(dem - untransformed_dem_no_z) < 0.5
+
+    # The difference between the two DEMs should be the vertical shift.
+    # We expect the difference to be approximately equal to the average vertical shift.
+    expected_vshift = np.mean(dest_coords[:, 2] - source_coords[:, 2])
+
+    # Check that the mean difference between the DEMs matches the expected vertical shift.
+ assert np.nanmean(transformed_dem_no_z - transformed_dem) == pytest.approx(expected_vshift, rel=0.3) + + if False: + import matplotlib.pyplot as plt + + plt.figure(dpi=200) + plt.subplot(141) + + plt.imshow(dem, vmin=0, vmax=300) + plt.subplot(142) + plt.imshow(transformed_dem, vmin=0, vmax=300) + plt.subplot(143) + plt.imshow(untransformed_dem, vmin=0, vmax=300) + + plt.subplot(144) + plt.imshow(dem - untransformed_dem, cmap="coolwarm_r", vmin=-10, vmax=10) + plt.show() diff --git a/xdem/coreg/__init__.py b/xdem/coreg/__init__.py index a20d8699..9fa352e3 100644 --- a/xdem/coreg/__init__.py +++ b/xdem/coreg/__init__.py @@ -27,12 +27,7 @@ NuthKaab, VerticalShift, ) -from xdem.coreg.base import ( # noqa - BlockwiseCoreg, - Coreg, - CoregPipeline, - apply_matrix, - invert_matrix, -) +from xdem.coreg.base import Coreg, CoregPipeline, apply_matrix, invert_matrix # noqa from xdem.coreg.biascorr import BiasCorr, Deramp, DirectionalBias, TerrainBias # noqa +from xdem.coreg.blockwise import BlockwiseCoreg # noqa from xdem.coreg.workflows import dem_coregistration # noqa diff --git a/xdem/coreg/base.py b/xdem/coreg/base.py index 79f3c890..41bc9040 100644 --- a/xdem/coreg/base.py +++ b/xdem/coreg/base.py @@ -20,7 +20,6 @@ from __future__ import annotations -import concurrent.futures import copy import inspect import logging @@ -48,20 +47,13 @@ import scipy.interpolate import scipy.ndimage import scipy.optimize -import skimage.transform from geoutils._typing import Number from geoutils.interface.gridding import _grid_pointcloud from geoutils.interface.interpolate import _interp_points -from geoutils.raster import Mask, RasterType, raster, subdivide_array +from geoutils.raster import Mask, RasterType, raster from geoutils.raster.array import get_array_and_mask -from geoutils.raster.georeferencing import ( - _bounds, - _cast_pixel_interpretation, - _coords, - _res, -) +from geoutils.raster.georeferencing import _cast_pixel_interpretation, _coords from geoutils.raster.geotransformations import _resampling_method_from_str, _translate -from tqdm import tqdm from xdem._typing import MArrayf, NDArrayb, NDArrayf from xdem.fit import ( @@ -2995,615 +2987,3 @@ def _to_matrix_func(self) -> NDArrayf: transform_mgr.add_transform(i, i + 1, new_matrix) return transform_mgr.get_transform(0, len(self.pipeline)) - - -class BlockwiseCoreg(Coreg): - """ - Block-wise co-registration processing class to run a step in segmented parts of the grid. - - A processing class of choice is run on an arbitrary subdivision of the raster. When later applying the step - the optimal warping is interpolated based on X/Y/Z shifts from the coreg algorithm at the grid points. - - For instance: a subdivision of 4 triggers a division of the DEM in four equally sized parts. These parts are then - processed separately, with 4 .fit() results. If the subdivision is not divisible by the raster shape, - subdivision is made as good as possible to have approximately equal pixel counts. - """ - - def __init__( - self, - step: Coreg | CoregPipeline, - subdivision: int, - success_threshold: float = 0.8, - n_threads: int | None = None, - warn_failures: bool = False, - apply_z_correction: bool = True, - ) -> None: - """ - Instantiate a blockwise processing object. - - :param step: An instantiated co-registration step object to fit in the subdivided DEMs. - :param subdivision: The number of chunks to divide the DEMs in. E.g. 4 means four different transforms. 
- :param success_threshold: Raise an error if fewer chunks than the fraction failed for any reason. - :param n_threads: The maximum amount of threads to use. Default=auto - :param warn_failures: Trigger or ignore warnings for each exception/warning in each block. - :param apply_z_correction: Boolean to toggle whether the Z-offset correction is applied or not (default True). - """ - if isinstance(step, type): - raise ValueError( - "The 'step' argument must be an instantiated Coreg subclass. " "Hint: write e.g. ICP() instead of ICP" - ) - self.procstep = step - self.subdivision = subdivision - self.success_threshold = success_threshold - self.n_threads = n_threads - self.warn_failures = warn_failures - self.apply_z_correction = apply_z_correction - - super().__init__() - - self._meta: CoregDict = {"step_meta": []} - self._groups: NDArrayf = np.array([]) - - def fit( - self: CoregType, - reference_elev: NDArrayf | MArrayf | RasterType, - to_be_aligned_elev: NDArrayf | MArrayf | RasterType, - inlier_mask: NDArrayb | Mask | None = None, - bias_vars: dict[str, NDArrayf | MArrayf | RasterType] | None = None, - weights: NDArrayf | None = None, - subsample: float | int | None = None, - transform: rio.transform.Affine | None = None, - crs: rio.crs.CRS | None = None, - area_or_point: Literal["Area", "Point"] | None = None, - z_name: str = "z", - random_state: int | np.random.Generator | None = None, - **kwargs: Any, - ) -> CoregType: - - if isinstance(reference_elev, gpd.GeoDataFrame) and isinstance(to_be_aligned_elev, gpd.GeoDataFrame): - raise NotImplementedError("Blockwise coregistration does not yet support two elevation point cloud inputs.") - - # Check if subsample arguments are different from their default value for any of the coreg steps: - # get default value in argument spec and "subsample" stored in meta, and compare both are consistent - if not isinstance(self.procstep, CoregPipeline): - steps = [self.procstep] - else: - steps = list(self.procstep.pipeline) - argspec = [inspect.getfullargspec(s.__class__) for s in steps] - sub_meta = [s._meta["inputs"]["random"]["subsample"] for s in steps] - sub_is_default = [ - argspec[i].defaults[argspec[i].args.index("subsample") - 1] == sub_meta[i] # type: ignore - for i in range(len(argspec)) - ] - if subsample is not None and not all(sub_is_default): - warnings.warn( - "Subsample argument passed to fit() will override non-default subsample values defined in the" - " step within the blockwise method. To silence this warning: only define 'subsample' in " - "either fit(subsample=...) or instantiation e.g., VerticalShift(subsample=...)." 
- ) - - # Pre-process the inputs, by reprojecting and subsampling, without any subsampling (done in each step) - ref_dem, tba_dem, inlier_mask, transform, crs, area_or_point = _preprocess_coreg_fit( - reference_elev=reference_elev, - to_be_aligned_elev=to_be_aligned_elev, - inlier_mask=inlier_mask, - transform=transform, - crs=crs, - area_or_point=area_or_point, - ) - - # Define inlier mask if None, before indexing subdivided array in process function below - if inlier_mask is None: - mask = np.ones(tba_dem.shape, dtype=bool) - else: - mask = inlier_mask - - self._groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape) - - indices = np.unique(self._groups) - - progress_bar = tqdm( - total=indices.size, desc="Processing chunks", disable=logging.getLogger().getEffectiveLevel() > logging.INFO - ) - - def process(i: int) -> dict[str, Any] | BaseException | None: - """ - Process a chunk in a thread-safe way. - - :returns: - * If it succeeds: A dictionary of the fitting metadata. - * If it fails: The associated exception. - * If the block is empty: None - """ - group_mask = self._groups == i - - # Find the corresponding slice of the inlier_mask to subset the data - rows, cols = np.where(group_mask) - arrayslice = np.s_[rows.min() : rows.max() + 1, cols.min() : cols.max() + 1] - - # Copy a subset of the two DEMs, the mask, the coreg instance, and make a new subset transform - ref_subset = ref_dem[arrayslice].copy() - tba_subset = tba_dem[arrayslice].copy() - - if any(np.all(~np.isfinite(dem)) for dem in (ref_subset, tba_subset)): - return None - mask_subset = mask[arrayslice].copy() - west, top = rio.transform.xy(transform, min(rows), min(cols), offset="ul") - transform_subset = rio.transform.from_origin(west, top, transform.a, -transform.e) # type: ignore - procstep = self.procstep.copy() - - # Try to run the coregistration. If it fails for any reason, skip it and save the exception. - try: - procstep.fit( - reference_elev=ref_subset, - to_be_aligned_elev=tba_subset, - transform=transform_subset, - inlier_mask=mask_subset, - bias_vars=bias_vars, - weights=weights, - crs=crs, - area_or_point=area_or_point, - z_name=z_name, - subsample=subsample, - random_state=random_state, - ) - nmad, median = procstep.error( - reference_elev=ref_subset, - to_be_aligned_elev=tba_subset, - error_type=["nmad", "median"], - inlier_mask=mask_subset, - transform=transform_subset, - crs=crs, - ) - except Exception as exception: - return exception - - meta: dict[str, Any] = { - "i": i, - "transform": transform_subset, - "inlier_count": np.count_nonzero(mask_subset & np.isfinite(ref_subset) & np.isfinite(tba_subset)), - "nmad": nmad, - "median": median, - } - # Find the center of the inliers. - inlier_positions = np.argwhere(mask_subset) - mid_row = np.mean(inlier_positions[:, 0]).astype(int) - mid_col = np.mean(inlier_positions[:, 1]).astype(int) - - # Find the indices of all finites within the mask - finites = np.argwhere(np.isfinite(tba_subset) & mask_subset) - # Calculate the distance between the approximate center and all finite indices - distances = np.linalg.norm(finites - np.array([mid_row, mid_col]), axis=1) - # Find the index representing the closest finite value to the center. 
-            closest = np.argwhere(distances == distances.min())
-
-            # Assign the closest finite value as the representative point
-            representative_row, representative_col = finites[closest][0][0]
-            meta["representative_x"], meta["representative_y"] = rio.transform.xy(
-                transform_subset, representative_row, representative_col
-            )
-
-            repr_val = ref_subset[representative_row, representative_col]
-            if ~np.isfinite(repr_val):
-                repr_val = 0
-            meta["representative_val"] = repr_val
-
-            # If the coreg is a pipeline, copy its metadatas to the output meta
-            if hasattr(procstep, "pipeline"):
-                meta["pipeline"] = [step.meta.copy() for step in procstep.pipeline]
-
-            # Copy all current metadata (except for the already existing keys like "i", "min_row", etc, and the
-            # "coreg_meta" key)
-            # This can then be iteratively restored when the apply function should be called.
-            meta.update(
-                {key: value for key, value in procstep.meta.items() if key not in ["step_meta"] + list(meta.keys())}
-            )
-
-            progress_bar.update()
-
-            return meta.copy()
-
-        # Catch warnings; only show them if
-        exceptions: list[BaseException | warnings.WarningMessage] = []
-        with warnings.catch_warnings(record=True) as caught_warnings:
-            warnings.simplefilter("default")
-            with concurrent.futures.ThreadPoolExecutor(max_workers=None) as executor:
-                results = executor.map(process, indices)
-
-            exceptions += list(caught_warnings)
-
-        empty_blocks = 0
-        for result in results:
-            if isinstance(result, BaseException):
-                exceptions.append(result)
-            elif result is None:
-                empty_blocks += 1
-                continue
-            else:
-                self._meta["step_meta"].append(result)
-
-        progress_bar.close()
-
-        # Stop if the success rate was below the threshold
-        if ((len(self._meta["step_meta"]) + empty_blocks) / self.subdivision) <= self.success_threshold:
-            raise ValueError(
-                f"Fitting failed for {len(exceptions)} chunks:\n"
-                + "\n".join(map(str, exceptions[:5]))
-                + f"\n... and {len(exceptions) - 5} more"
-                if len(exceptions) > 5
-                else ""
-            )
-
-        if self.warn_failures:
-            for exception in exceptions:
-                warnings.warn(str(exception))
-
-        # Set the _fit_called parameters (only identical copies of self.coreg have actually been called)
-        self.procstep._fit_called = True
-        if isinstance(self.procstep, CoregPipeline):
-            for step in self.procstep.pipeline:
-                step._fit_called = True
-
-        # Flag that the fitting function has been called.
-        self._fit_called = True
-
-        return self
-
-    def _restore_metadata(self, meta: CoregDict) -> None:
-        """
-        Given some metadata, set it in the right place.
-
-        :param meta: A metadata file to update self._meta
-        """
-        self.procstep._meta.update(meta)
-
-        if isinstance(self.procstep, CoregPipeline) and "pipeline" in meta:
-            for i, step in enumerate(self.procstep.pipeline):
-                step._meta.update(meta["pipeline"][i])
-
-    def to_points(self) -> NDArrayf:
-        """
-        Convert the blockwise coregistration matrices to 3D (source -> destination) points.
-
-        The returned shape is (N, 3, 2) where the dimensions represent:
-        0. The point index where N is equal to the amount of subdivisions.
-        1. The X/Y/Z coordinate of the point.
-        2. The old/new position of the point.
-
-        To acquire the first point's original position: points[0, :, 0]
-        To acquire the first point's new position: points[0, :, 1]
-        To acquire the first point's Z difference: points[0, 2, 1] - points[0, 2, 0]
-
-        :returns: An array of 3D source -> destination points.
-        """
-        if len(self._meta["step_meta"]) == 0:
-            raise AssertionError("No coreg results exist. Has '.fit()' been called?")
-        points = np.empty(shape=(0, 3, 2))
-
-        for i in range(self.subdivision):
-            # Try to restore the metadata for this chunk (if it succeeded)
-            chunk_meta = next((meta for meta in self._meta["step_meta"] if meta["i"] == i), None)
-
-            if chunk_meta is not None:
-                # Successful chunk: Retrieve the representative X, Y, Z coordinates
-                self._restore_metadata(chunk_meta)
-                x_coord, y_coord = chunk_meta["representative_x"], chunk_meta["representative_y"]
-                repr_val = chunk_meta["representative_val"]
-            else:
-                # Failed chunk: Calculate the approximate center using the group's bounds
-                rows, cols = np.where(self._groups == i)
-                center_row = (rows.min() + rows.max()) // 2
-                center_col = (cols.min() + cols.max()) // 2
-
-                transform = self._meta["step_meta"][0]["transform"]  # Assuming all chunks share a transform
-                x_coord, y_coord = rio.transform.xy(transform, center_row, center_col)
-                repr_val = np.nan  # No valid Z value for failed chunks
-
-            # Old position based on the calculated or retrieved coordinates
-            old_pos_arr = np.reshape([x_coord, y_coord, repr_val], (1, 3))
-            old_position = gpd.GeoDataFrame(
-                geometry=gpd.points_from_xy(x=old_pos_arr[:, 0], y=old_pos_arr[:, 1], crs=None),
-                data={"z": old_pos_arr[:, 2]},
-            )
-
-            if chunk_meta is not None:
-                # Successful chunk: Apply the transformation
-                new_position = self.procstep.apply(old_position)
-                new_pos_arr = np.reshape(
-                    [new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3)
-                )
-            else:
-                # Failed chunk: Keep the new position the same as the old position (no transformation)
-                new_pos_arr = old_pos_arr.copy()
-
-            # Append the result
-            points = np.append(points, np.dstack((old_pos_arr, new_pos_arr)), axis=0)
-
-        return points
-
-    def stats(self) -> pd.DataFrame:
-        """
-        Return statistics for each chunk in the blockwise coregistration.
-
-        * center_{x,y,z}: The center coordinate of the chunk in georeferenced units.
-        * {x,y,z}_off: The calculated offset in georeferenced units.
-        * inlier_count: The number of pixels that were inliers in the chunk.
-        * nmad: The NMAD of elevation differences (robust dispersion) after coregistration.
-        * median: The median of elevation differences (vertical shift) after coregistration.
-
-        :raises ValueError: If no coregistration results exist yet.
-
-        :returns: A dataframe of statistics for each chunk.
-        If a chunk fails (not present in `chunk_meta`), the statistics will be returned as `NaN`.
- """ - points = self.to_points() - - chunk_meta = {meta["i"]: meta for meta in self.meta["step_meta"]} - - statistics: list[dict[str, Any]] = [] - for i in range(points.shape[0]): - if i not in chunk_meta: - # For missing chunks, return NaN for all stats - statistics.append( - { - "center_x": points[i, 0, 0], - "center_y": points[i, 1, 0], - "center_z": points[i, 2, 0], - "x_off": np.nan, - "y_off": np.nan, - "z_off": np.nan, - "inlier_count": np.nan, - "nmad": np.nan, - "median": np.nan, - } - ) - else: - statistics.append( - { - "center_x": points[i, 0, 0], - "center_y": points[i, 1, 0], - "center_z": points[i, 2, 0], - "x_off": points[i, 0, 1] - points[i, 0, 0], - "y_off": points[i, 1, 1] - points[i, 1, 0], - "z_off": points[i, 2, 1] - points[i, 2, 0], - "inlier_count": chunk_meta[i]["inlier_count"], - "nmad": chunk_meta[i]["nmad"], - "median": chunk_meta[i]["median"], - } - ) - - stats_df = pd.DataFrame(statistics) - stats_df.index.name = "chunk" - - return stats_df - - def subdivide_array(self, shape: tuple[int, ...]) -> NDArrayf: - """ - Return the grid subdivision for a given DEM shape. - - :param shape: The shape of the input DEM. - - :returns: An array of shape 'shape' with 'self.subdivision' unique indices. - """ - if len(shape) == 3 and shape[0] == 1: # Account for (1, row, col) shapes - shape = (shape[1], shape[2]) - return subdivide_array(shape, count=self.subdivision) - - def _apply_rst( - self, - elev: NDArrayf, - transform: rio.transform.Affine, - crs: rio.crs.CRS, - bias_vars: dict[str, NDArrayf] | None = None, - **kwargs: Any, - ) -> tuple[NDArrayf, rio.transform.Affine]: - - if np.count_nonzero(np.isfinite(elev)) == 0: - return elev, transform - - # Other option than resample=True is not implemented for this case - if "resample" in kwargs and kwargs["resample"] is not True: - raise NotImplementedError("Option `resample=False` not supported for coreg method BlockwiseCoreg.") - - points = self.to_points() - # Check for NaN values across both the old and new positions for each point - mask = ~np.isnan(points).any(axis=(1, 2)) - - # Filter out points where there are no NaN values - points = points[mask] - - bounds = _bounds(transform=transform, shape=elev.shape) - resolution = _res(transform) - - representative_height = np.nanmean(elev) - edges_source_arr = np.array( - [ - [bounds.left + resolution[0] / 2, bounds.top - resolution[1] / 2, representative_height], - [bounds.right - resolution[0] / 2, bounds.top - resolution[1] / 2, representative_height], - [bounds.left + resolution[0] / 2, bounds.bottom + resolution[1] / 2, representative_height], - [bounds.right - resolution[0] / 2, bounds.bottom + resolution[1] / 2, representative_height], - ] - ) - edges_source = gpd.GeoDataFrame( - geometry=gpd.points_from_xy(x=edges_source_arr[:, 0], y=edges_source_arr[:, 1], crs=None), - data={"z": edges_source_arr[:, 2]}, - ) - - edges_dest = self.apply(edges_source) - edges_dest_arr = np.array( - [edges_dest.geometry.x.values, edges_dest.geometry.y.values, edges_dest["z"].values] - ).T - edges = np.dstack((edges_source_arr, edges_dest_arr)) - - all_points = np.append(points, edges, axis=0) - - warped_dem = warp_dem( - dem=elev, - transform=transform, - source_coords=all_points[:, :, 1], - destination_coords=all_points[:, :, 0], - resampling="linear", - apply_z_correction=self.apply_z_correction, - ) - - return warped_dem, transform - - def _apply_pts( - self, elev: gpd.GeoDataFrame, z_name: str = "z", bias_vars: dict[str, NDArrayf] | None = None, **kwargs: Any - ) -> gpd.GeoDataFrame: 
- """Apply the scaling model to a set of points.""" - points = self.to_points() - - # Check for NaN values across both the old and new positions for each point - mask = ~np.isnan(points).any(axis=(1, 2)) - - # Filter out points where there are no NaN values - points = points[mask] - - new_coords = np.array([elev.geometry.x.values, elev.geometry.y.values, elev["z"].values]).T - - for dim in range(0, 3): - with warnings.catch_warnings(): - # ZeroDivisionErrors may happen when the transformation is empty (which is fine) - warnings.filterwarnings("ignore", message="ZeroDivisionError") - model = scipy.interpolate.Rbf( - points[:, 0, 0], - points[:, 1, 0], - points[:, dim, 1] - points[:, dim, 0], - function="linear", - ) - - new_coords[:, dim] += model(elev.geometry.x.values, elev.geometry.y.values) - - gdf_new_coords = gpd.GeoDataFrame( - geometry=gpd.points_from_xy(x=new_coords[:, 0], y=new_coords[:, 1], crs=None), data={"z": new_coords[:, 2]} - ) - - return gdf_new_coords - - -def warp_dem( - dem: NDArrayf, - transform: rio.transform.Affine, - source_coords: NDArrayf, - destination_coords: NDArrayf, - resampling: str = "cubic", - trim_border: bool = True, - dilate_mask: bool = True, - apply_z_correction: bool = True, -) -> NDArrayf: - """ - (22/08/24: Method currently used only for blockwise coregistration) - Warp a DEM using a set of source-destination 2D or 3D coordinates. - - :param dem: The DEM to warp. Allowed shapes are (1, row, col) or (row, col) - :param transform: The Affine transform of the DEM. - :param source_coords: The source 2D or 3D points. must be X/Y/(Z) coords of shape (N, 2) or (N, 3). - :param destination_coords: The destination 2D or 3D points. Must have the exact same shape as 'source_coords' - :param resampling: The resampling order to use. Choices: ['nearest', 'linear', 'cubic']. - :param trim_border: Remove values outside of the interpolation regime (True) or leave them unmodified (False). - :param dilate_mask: Dilate the nan mask to exclude edge pixels that could be wrong. - :param apply_z_correction: Boolean to toggle whether the Z-offset correction is applied or not (default True). - - :raises ValueError: If the inputs are poorly formatted. - :raises AssertionError: For unexpected outputs. - - :returns: A warped DEM with the same shape as the input. - """ - if source_coords.shape != destination_coords.shape: - raise ValueError( - f"Incompatible shapes: source_coords '({source_coords.shape})' and " - f"destination_coords '({destination_coords.shape})' shapes must be the same" - ) - if (len(source_coords.shape) > 2) or (source_coords.shape[1] < 2) or (source_coords.shape[1] > 3): - raise ValueError( - "Invalid coordinate shape. Expected 2D or 3D coordinates of shape (N, 2) or (N, 3). " - f"Got '{source_coords.shape}'" - ) - allowed_resampling_strs = ["nearest", "linear", "cubic"] - if resampling not in allowed_resampling_strs: - raise ValueError(f"Resampling type '{resampling}' not understood. Choices: {allowed_resampling_strs}") - - dem_arr, dem_mask = get_array_and_mask(dem) - - bounds = _bounds(transform=transform, shape=dem_arr.shape) - - no_horizontal = np.sum(np.linalg.norm(destination_coords[:, :2] - source_coords[:, :2], axis=1)) < 1e-6 - no_vertical = source_coords.shape[1] > 2 and np.sum(np.abs(destination_coords[:, 2] - source_coords[:, 2])) < 1e-6 - - if no_horizontal and no_vertical: - warnings.warn("No difference between source and destination coordinates. 
Returning self.") - return dem - - source_coords_scaled = source_coords.copy() - destination_coords_scaled = destination_coords.copy() - # Scale the coordinates to index-space - for coords in (source_coords_scaled, destination_coords_scaled): - coords[:, 0] = dem_arr.shape[1] * (coords[:, 0] - bounds.left) / (bounds.right - bounds.left) - coords[:, 1] = dem_arr.shape[0] * (1 - (coords[:, 1] - bounds.bottom) / (bounds.top - bounds.bottom)) - - # Generate a grid of x and y index coordinates. - grid_y, grid_x = np.mgrid[0 : dem_arr.shape[0], 0 : dem_arr.shape[1]] - - if no_horizontal: - warped = dem_arr.copy() - else: - # Interpolate the sparse source-destination points to a grid. - # (row, col, 0) represents the destination y-coordinates of the pixels. - # (row, col, 1) represents the destination x-coordinates of the pixels. - new_indices = scipy.interpolate.griddata( - source_coords_scaled[:, [1, 0]], - destination_coords_scaled[:, [1, 0]], # Coordinates should be in y/x (not x/y) for some reason.. - (grid_y, grid_x), - method="linear", - ) - - # If the border should not be trimmed, just assign the original indices to the missing values. - if not trim_border: - missing_ys = np.isnan(new_indices[:, :, 0]) - missing_xs = np.isnan(new_indices[:, :, 1]) - new_indices[:, :, 0][missing_ys] = grid_y[missing_ys] - new_indices[:, :, 1][missing_xs] = grid_x[missing_xs] - - order = {"nearest": 0, "linear": 1, "cubic": 3} - - with warnings.catch_warnings(): - # A skimage warning that will hopefully be fixed soon. (2021-06-08) - warnings.filterwarnings("ignore", message="Passing `np.nan` to mean no clipping in np.clip") - warped = skimage.transform.warp( - image=np.where(dem_mask, np.nan, dem_arr), - inverse_map=np.moveaxis(new_indices, 2, 0), - output_shape=dem_arr.shape, - preserve_range=True, - order=order[resampling], - cval=np.nan, - ) - new_mask = ( - skimage.transform.warp( - image=dem_mask, inverse_map=np.moveaxis(new_indices, 2, 0), output_shape=dem_arr.shape, cval=False - ) - > 0 - ) - - if dilate_mask: - new_mask = scipy.ndimage.binary_dilation(new_mask, iterations=order[resampling]).astype(new_mask.dtype) - - warped[new_mask] = np.nan - - # Apply the Z-correction if apply_z_correction is True and if the coordinates are 3D (N, 3) - if not no_vertical and apply_z_correction: - grid_offsets = scipy.interpolate.griddata( - points=destination_coords_scaled[:, :2], - values=source_coords_scaled[:, 2] - destination_coords_scaled[:, 2], - xi=(grid_x, grid_y), - method=resampling, - fill_value=np.nan, - ) - if not trim_border: - grid_offsets[np.isnan(grid_offsets)] = np.nanmean(grid_offsets) - - warped += grid_offsets - - assert not np.all(np.isnan(warped)), "All-NaN output." - - return warped.reshape(dem.shape) diff --git a/xdem/coreg/blockwise.py b/xdem/coreg/blockwise.py new file mode 100644 index 00000000..8aac5f38 --- /dev/null +++ b/xdem/coreg/blockwise.py @@ -0,0 +1,663 @@ +# Copyright (c) 2024 xDEM developers +# Copyright (c) 2025 Centre National d'Etudes Spatiales (CNES). +# +# This file is part of the xDEM project: +# https://github.com/glaciohack/xdem +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+#
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Blockwise coregistration class to run a coregistration step on segmented parts of the grid."""
+
+from __future__ import annotations
+
+import concurrent.futures
+import inspect
+import logging
+import warnings
+from typing import Any, Literal
+
+import geopandas as gpd
+import numpy as np
+import pandas as pd
+import rasterio as rio
+import scipy
+import scipy.interpolate
+import scipy.ndimage
+import scipy.optimize
+import skimage.transform
+from geoutils.raster import Mask, RasterType, subdivide_array
+from geoutils.raster.array import get_array_and_mask
+from geoutils.raster.georeferencing import _bounds, _res
+from tqdm import tqdm
+
+from xdem._typing import MArrayf, NDArrayb, NDArrayf
+from xdem.coreg.base import (
+    Coreg,
+    CoregDict,
+    CoregPipeline,
+    CoregType,
+    _preprocess_coreg_fit,
+)
+
+
+class BlockwiseCoreg(Coreg):
+    """
+    Block-wise co-registration processing class to run a step in segmented parts of the grid.
+
+    A processing class of choice is run on an arbitrary subdivision of the raster. When later applying the step
+    the optimal warping is interpolated based on X/Y/Z shifts from the coreg algorithm at the grid points.
+
+    For instance: a subdivision of 4 triggers a division of the DEM in four equally sized parts. These parts are then
+    processed separately, with 4 .fit() results. If the subdivision is not divisible by the raster shape,
+    subdivision is made as good as possible to have approximately equal pixel counts.
+    """
+
+    def __init__(
+        self,
+        step: Coreg | CoregPipeline,
+        subdivision: int,
+        success_threshold: float = 0.8,
+        n_threads: int | None = None,
+        warn_failures: bool = False,
+        apply_z_correction: bool = True,
+    ) -> None:
+        """
+        Instantiate a blockwise processing object.
+
+        :param step: An instantiated co-registration step object to fit in the subdivided DEMs.
+        :param subdivision: The number of chunks to divide the DEMs in. E.g. 4 means four different transforms.
+        :param success_threshold: Raise an error if the fraction of successfully fitted chunks is below this threshold.
+        :param n_threads: The maximum number of threads to use. Default=auto
+        :param warn_failures: Trigger or ignore warnings for each exception/warning in each block.
+        :param apply_z_correction: Boolean to toggle whether the Z-offset correction is applied or not (default True).
+        """
+        if isinstance(step, type):
+            raise ValueError(
+                "The 'step' argument must be an instantiated Coreg subclass. Hint: write e.g. ICP() instead of ICP"
+            )
+        self.procstep = step
+        self.subdivision = subdivision
+        self.success_threshold = success_threshold
+        self.n_threads = n_threads
+        self.warn_failures = warn_failures
+        self.apply_z_correction = apply_z_correction
+
+        super().__init__()
+
+        self._meta: CoregDict = {"step_meta": []}
+        self._groups: NDArrayf = np.array([])
+
+    def fit(
+        self: CoregType,
+        reference_elev: NDArrayf | MArrayf | RasterType,
+        to_be_aligned_elev: NDArrayf | MArrayf | RasterType,
+        inlier_mask: NDArrayb | Mask | None = None,
+        bias_vars: dict[str, NDArrayf | MArrayf | RasterType] | None = None,
+        weights: NDArrayf | None = None,
+        subsample: float | int | None = None,
+        transform: rio.transform.Affine | None = None,
+        crs: rio.crs.CRS | None = None,
+        area_or_point: Literal["Area", "Point"] | None = None,
+        z_name: str = "z",
+        random_state: int | np.random.Generator | None = None,
+        **kwargs: Any,
+    ) -> CoregType:
+
+        if isinstance(reference_elev, gpd.GeoDataFrame) and isinstance(to_be_aligned_elev, gpd.GeoDataFrame):
+            raise NotImplementedError("Blockwise coregistration does not yet support two elevation point cloud inputs.")
+
+        # Check if subsample arguments are different from their default value for any of the coreg steps:
+        # get default value in argument spec and "subsample" stored in meta, and compare both are consistent
+        if not isinstance(self.procstep, CoregPipeline):
+            steps = [self.procstep]
+        else:
+            steps = list(self.procstep.pipeline)
+        argspec = [inspect.getfullargspec(s.__class__) for s in steps]
+        sub_meta = [s._meta["inputs"]["random"]["subsample"] for s in steps]
+        sub_is_default = [
+            argspec[i].defaults[argspec[i].args.index("subsample") - 1] == sub_meta[i]  # type: ignore
+            for i in range(len(argspec))
+        ]
+        if subsample is not None and not all(sub_is_default):
+            warnings.warn(
+                "Subsample argument passed to fit() will override non-default subsample values defined in the"
+                " step within the blockwise method. To silence this warning: only define 'subsample' in "
+                "either fit(subsample=...) or instantiation e.g., VerticalShift(subsample=...)."
+            )
+
+        # Pre-process the inputs by reprojecting, without any subsampling (subsampling is done in each step)
+        ref_dem, tba_dem, inlier_mask, transform, crs, area_or_point = _preprocess_coreg_fit(
+            reference_elev=reference_elev,
+            to_be_aligned_elev=to_be_aligned_elev,
+            inlier_mask=inlier_mask,
+            transform=transform,
+            crs=crs,
+            area_or_point=area_or_point,
+        )
+
+        # Define inlier mask if None, before indexing subdivided array in process function below
+        if inlier_mask is None:
+            mask = np.ones(tba_dem.shape, dtype=bool)
+        else:
+            mask = inlier_mask
+
+        self._groups = self.subdivide_array(tba_dem.shape if isinstance(tba_dem, np.ndarray) else ref_dem.shape)
+
+        indices = np.unique(self._groups)
+
+        progress_bar = tqdm(
+            total=indices.size, desc="Processing chunks", disable=logging.getLogger().getEffectiveLevel() > logging.INFO
+        )
+
+        def process(i: int) -> dict[str, Any] | BaseException | None:
+            """
+            Process a chunk in a thread-safe way.
+
+            :returns:
+                * If it succeeds: A dictionary of the fitting metadata.
+                * If it fails: The associated exception.
+
+        def process(i: int) -> dict[str, Any] | BaseException | None:
+            """
+            Process a chunk in a thread-safe way.
+
+            :returns:
+                * If it succeeds: A dictionary of the fitting metadata.
+                * If it fails: The associated exception.
+                * If the block is empty: None
+            """
+            group_mask = self._groups == i
+
+            # Find the corresponding slice of the inlier_mask to subset the data
+            rows, cols = np.where(group_mask)
+            arrayslice = np.s_[rows.min() : rows.max() + 1, cols.min() : cols.max() + 1]
+
+            # Copy a subset of the two DEMs, the mask, the coreg instance, and make a new subset transform
+            ref_subset = ref_dem[arrayslice].copy()
+            tba_subset = tba_dem[arrayslice].copy()
+
+            # Skip the block if either DEM subset contains no finite values
+            if any(np.all(~np.isfinite(dem)) for dem in (ref_subset, tba_subset)):
+                return None
+            mask_subset = mask[arrayslice].copy()
+            west, top = rio.transform.xy(transform, min(rows), min(cols), offset="ul")
+            transform_subset = rio.transform.from_origin(west, top, transform.a, -transform.e)  # type: ignore
+            procstep = self.procstep.copy()
+
+            # Try to run the coregistration. If it fails for any reason, skip it and save the exception.
+            try:
+                procstep.fit(
+                    reference_elev=ref_subset,
+                    to_be_aligned_elev=tba_subset,
+                    transform=transform_subset,
+                    inlier_mask=mask_subset,
+                    bias_vars=bias_vars,
+                    weights=weights,
+                    crs=crs,
+                    area_or_point=area_or_point,
+                    z_name=z_name,
+                    subsample=subsample,
+                    random_state=random_state,
+                )
+                nmad, median = procstep.error(
+                    reference_elev=ref_subset,
+                    to_be_aligned_elev=tba_subset,
+                    error_type=["nmad", "median"],
+                    inlier_mask=mask_subset,
+                    transform=transform_subset,
+                    crs=crs,
+                )
+            except Exception as exception:
+                return exception
+
+            meta: dict[str, Any] = {
+                "i": i,
+                "transform": transform_subset,
+                "inlier_count": np.count_nonzero(mask_subset & np.isfinite(ref_subset) & np.isfinite(tba_subset)),
+                "nmad": nmad,
+                "median": median,
+            }
+            # Find the center of the inliers.
+            inlier_positions = np.argwhere(mask_subset)
+            mid_row = np.mean(inlier_positions[:, 0]).astype(int)
+            mid_col = np.mean(inlier_positions[:, 1]).astype(int)
+
+            # Find the indices of all finite values within the mask
+            finites = np.argwhere(np.isfinite(tba_subset) & mask_subset)
+            # Calculate the distance between the approximate center and all finite indices
+            distances = np.linalg.norm(finites - np.array([mid_row, mid_col]), axis=1)
+            # Find the index representing the closest finite value to the center.
+            closest = np.argwhere(distances == distances.min())
+
+            # Assign the closest finite value as the representative point
+            representative_row, representative_col = finites[closest][0][0]
+            meta["representative_x"], meta["representative_y"] = rio.transform.xy(
+                transform_subset, representative_row, representative_col
+            )
+
+            repr_val = ref_subset[representative_row, representative_col]
+            if not np.isfinite(repr_val):
+                repr_val = 0
+            meta["representative_val"] = repr_val
+
+            # If the coreg is a pipeline, copy its metadata to the output meta
+            if hasattr(procstep, "pipeline"):
+                meta["pipeline"] = [step.meta.copy() for step in procstep.pipeline]
+
+            # Copy all current metadata (except for the keys that already exist in 'meta', such as "i" or
+            # "transform", and the "step_meta" key).
+            # This can then be iteratively restored when the apply function is called.
+            meta.update(
+                {key: value for key, value in procstep.meta.items() if key not in ["step_meta"] + list(meta.keys())}
+            )
+
+            progress_bar.update()
+
+            return meta.copy()
+
+        # Catch warnings raised in the worker threads; they are only surfaced later if 'warn_failures' is True.
+        exceptions: list[BaseException | warnings.WarningMessage] = []
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            warnings.simplefilter("default")
+            with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_threads) as executor:
+                results = executor.map(process, indices)
+
+            exceptions += list(caught_warnings)
+
+        empty_blocks = 0
+        for result in results:
+            if isinstance(result, BaseException):
+                exceptions.append(result)
+            elif result is None:
+                empty_blocks += 1
+                continue
+            else:
+                self._meta["step_meta"].append(result)
+
+        progress_bar.close()
+
+        # Stop if the success rate was at or below the threshold
+        if ((len(self._meta["step_meta"]) + empty_blocks) / self.subdivision) <= self.success_threshold:
+            raise ValueError(
+                f"Fitting failed for {len(exceptions)} chunks:\n"
+                + "\n".join(map(str, exceptions[:5]))
+                + (f"\n... and {len(exceptions) - 5} more" if len(exceptions) > 5 else "")
+            )
+
+        if self.warn_failures:
+            for exception in exceptions:
+                warnings.warn(str(exception))
+
+        # Set the _fit_called parameters (only identical copies of self.procstep have actually been called)
+        self.procstep._fit_called = True
+        if isinstance(self.procstep, CoregPipeline):
+            for step in self.procstep.pipeline:
+                step._fit_called = True
+
+        # Flag that the fitting function has been called.
+        self._fit_called = True
+
+        return self
+
+    def _restore_metadata(self, meta: CoregDict) -> None:
+        """
+        Given some metadata, set it in the right place.
+
+        :param meta: A metadata dictionary to update self._meta with.
+        """
+        self.procstep._meta.update(meta)
+
+        if isinstance(self.procstep, CoregPipeline) and "pipeline" in meta:
+            for i, step in enumerate(self.procstep.pipeline):
+                step._meta.update(meta["pipeline"][i])
+
+    def to_points(self) -> NDArrayf:
+        """
+        Convert the blockwise coregistration matrices to 3D (source -> destination) points.
+
+        The returned shape is (N, 3, 2) where the dimensions represent:
+        0. The point index, where N is equal to the number of subdivisions.
+        1. The X/Y/Z coordinate of the point.
+        2. The old/new position of the point.
+
+        To acquire the first point's original position: points[0, :, 0]
+        To acquire the first point's new position: points[0, :, 1]
+        To acquire the first point's Z difference: points[0, 2, 1] - points[0, 2, 0]
+
+        :returns: An array of 3D source -> destination points.
+        """
+        if len(self._meta["step_meta"]) == 0:
+            raise AssertionError("No coreg results exist. Has '.fit()' been called?")
+        points = np.empty(shape=(0, 3, 2))
+
+        for i in range(self.subdivision):
+            # Try to restore the metadata for this chunk (if it succeeded)
+            chunk_meta = next((meta for meta in self._meta["step_meta"] if meta["i"] == i), None)
+
+            if chunk_meta is not None:
+                # Successful chunk: retrieve the representative X, Y, Z coordinates
+                self._restore_metadata(chunk_meta)
+                x_coord, y_coord = chunk_meta["representative_x"], chunk_meta["representative_y"]
+                repr_val = chunk_meta["representative_val"]
+            else:
+                # Failed chunk: calculate the approximate center using the group's bounds
+                rows, cols = np.where(self._groups == i)
+                center_row = (rows.min() + rows.max()) // 2
+                center_col = (cols.min() + cols.max()) // 2
+
+                transform = self._meta["step_meta"][0]["transform"]  # Assuming all chunks share a transform
+                x_coord, y_coord = rio.transform.xy(transform, center_row, center_col)
+                repr_val = np.nan  # No valid Z value for failed chunks
+
+            # Old position based on the calculated or retrieved coordinates
+            old_pos_arr = np.reshape([x_coord, y_coord, repr_val], (1, 3))
+            old_position = gpd.GeoDataFrame(
+                geometry=gpd.points_from_xy(x=old_pos_arr[:, 0], y=old_pos_arr[:, 1], crs=None),
+                data={"z": old_pos_arr[:, 2]},
+            )
+
+            if chunk_meta is not None:
+                # Successful chunk: apply the transformation
+                new_position = self.procstep.apply(old_position)
+                new_pos_arr = np.reshape(
+                    [new_position.geometry.x.values, new_position.geometry.y.values, new_position["z"].values], (1, 3)
+                )
+            else:
+                # Failed chunk: keep the new position the same as the old position (no transformation)
+                new_pos_arr = old_pos_arr.copy()
+
+            # Append the result
+            points = np.append(points, np.dstack((old_pos_arr, new_pos_arr)), axis=0)
+
+        return points
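+
+    # Illustrative note (not part of the original patch): the per-chunk offsets reported by
+    # 'stats' below follow directly from 'to_points' with plain array arithmetic:
+    #
+    #     points = blockwise.to_points()               # shape (N, 3, 2)
+    #     offsets = points[:, :, 1] - points[:, :, 0]  # per-chunk X/Y/Z shifts, shape (N, 3)
+    #     x_off, y_off, z_off = offsets[:, 0], offsets[:, 1], offsets[:, 2]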
+ """ + points = self.to_points() + + chunk_meta = {meta["i"]: meta for meta in self.meta["step_meta"]} + + statistics: list[dict[str, Any]] = [] + for i in range(points.shape[0]): + if i not in chunk_meta: + # For missing chunks, return NaN for all stats + statistics.append( + { + "center_x": points[i, 0, 0], + "center_y": points[i, 1, 0], + "center_z": points[i, 2, 0], + "x_off": np.nan, + "y_off": np.nan, + "z_off": np.nan, + "inlier_count": np.nan, + "nmad": np.nan, + "median": np.nan, + } + ) + else: + statistics.append( + { + "center_x": points[i, 0, 0], + "center_y": points[i, 1, 0], + "center_z": points[i, 2, 0], + "x_off": points[i, 0, 1] - points[i, 0, 0], + "y_off": points[i, 1, 1] - points[i, 1, 0], + "z_off": points[i, 2, 1] - points[i, 2, 0], + "inlier_count": chunk_meta[i]["inlier_count"], + "nmad": chunk_meta[i]["nmad"], + "median": chunk_meta[i]["median"], + } + ) + + stats_df = pd.DataFrame(statistics) + stats_df.index.name = "chunk" + + return stats_df + + def subdivide_array(self, shape: tuple[int, ...]) -> NDArrayf: + """ + Return the grid subdivision for a given DEM shape. + + :param shape: The shape of the input DEM. + + :returns: An array of shape 'shape' with 'self.subdivision' unique indices. + """ + if len(shape) == 3 and shape[0] == 1: # Account for (1, row, col) shapes + shape = (shape[1], shape[2]) + return subdivide_array(shape, count=self.subdivision) + + def _apply_rst( + self, + elev: NDArrayf, + transform: rio.transform.Affine, + crs: rio.crs.CRS, + bias_vars: dict[str, NDArrayf] | None = None, + **kwargs: Any, + ) -> tuple[NDArrayf, rio.transform.Affine]: + + if np.count_nonzero(np.isfinite(elev)) == 0: + return elev, transform + + # Other option than resample=True is not implemented for this case + if "resample" in kwargs and kwargs["resample"] is not True: + raise NotImplementedError("Option `resample=False` not supported for coreg method BlockwiseCoreg.") + + points = self.to_points() + # Check for NaN values across both the old and new positions for each point + mask = ~np.isnan(points).any(axis=(1, 2)) + + # Filter out points where there are no NaN values + points = points[mask] + + bounds = _bounds(transform=transform, shape=elev.shape) + resolution = _res(transform) + + representative_height = np.nanmean(elev) + edges_source_arr = np.array( + [ + [bounds.left + resolution[0] / 2, bounds.top - resolution[1] / 2, representative_height], + [bounds.right - resolution[0] / 2, bounds.top - resolution[1] / 2, representative_height], + [bounds.left + resolution[0] / 2, bounds.bottom + resolution[1] / 2, representative_height], + [bounds.right - resolution[0] / 2, bounds.bottom + resolution[1] / 2, representative_height], + ] + ) + edges_source = gpd.GeoDataFrame( + geometry=gpd.points_from_xy(x=edges_source_arr[:, 0], y=edges_source_arr[:, 1], crs=None), + data={"z": edges_source_arr[:, 2]}, + ) + + edges_dest = self.apply(edges_source) + edges_dest_arr = np.array( + [edges_dest.geometry.x.values, edges_dest.geometry.y.values, edges_dest["z"].values] + ).T + edges = np.dstack((edges_source_arr, edges_dest_arr)) + + all_points = np.append(points, edges, axis=0) + + warped_dem = warp_dem( + dem=elev, + transform=transform, + source_coords=all_points[:, :, 1], + destination_coords=all_points[:, :, 0], + resampling="linear", + apply_z_correction=self.apply_z_correction, + ) + + return warped_dem, transform + + def _apply_pts( + self, elev: gpd.GeoDataFrame, z_name: str = "z", bias_vars: dict[str, NDArrayf] | None = None, **kwargs: Any + ) -> gpd.GeoDataFrame: 
+ """Apply the scaling model to a set of points.""" + points = self.to_points() + + # Check for NaN values across both the old and new positions for each point + mask = ~np.isnan(points).any(axis=(1, 2)) + + # Filter out points where there are no NaN values + points = points[mask] + + new_coords = np.array([elev.geometry.x.values, elev.geometry.y.values, elev["z"].values]).T + + for dim in range(0, 3): + with warnings.catch_warnings(): + # ZeroDivisionErrors may happen when the transformation is empty (which is fine) + warnings.filterwarnings("ignore", message="ZeroDivisionError") + model = scipy.interpolate.Rbf( + points[:, 0, 0], + points[:, 1, 0], + points[:, dim, 1] - points[:, dim, 0], + function="linear", + ) + + new_coords[:, dim] += model(elev.geometry.x.values, elev.geometry.y.values) + + gdf_new_coords = gpd.GeoDataFrame( + geometry=gpd.points_from_xy(x=new_coords[:, 0], y=new_coords[:, 1], crs=None), data={"z": new_coords[:, 2]} + ) + + return gdf_new_coords + + +def warp_dem( + dem: NDArrayf, + transform: rio.transform.Affine, + source_coords: NDArrayf, + destination_coords: NDArrayf, + resampling: str = "cubic", + trim_border: bool = True, + dilate_mask: bool = True, + apply_z_correction: bool = True, +) -> NDArrayf: + """ + (22/08/24: Method currently used only for blockwise coregistration) + Warp a DEM using a set of source-destination 2D or 3D coordinates. + + :param dem: The DEM to warp. Allowed shapes are (1, row, col) or (row, col) + :param transform: The Affine transform of the DEM. + :param source_coords: The source 2D or 3D points. must be X/Y/(Z) coords of shape (N, 2) or (N, 3). + :param destination_coords: The destination 2D or 3D points. Must have the exact same shape as 'source_coords' + :param resampling: The resampling order to use. Choices: ['nearest', 'linear', 'cubic']. + :param trim_border: Remove values outside of the interpolation regime (True) or leave them unmodified (False). + :param dilate_mask: Dilate the nan mask to exclude edge pixels that could be wrong. + :param apply_z_correction: Boolean to toggle whether the Z-offset correction is applied or not (default True). + + :raises ValueError: If the inputs are poorly formatted. + :raises AssertionError: For unexpected outputs. + + :returns: A warped DEM with the same shape as the input. + """ + if source_coords.shape != destination_coords.shape: + raise ValueError( + f"Incompatible shapes: source_coords '({source_coords.shape})' and " + f"destination_coords '({destination_coords.shape})' shapes must be the same" + ) + if (len(source_coords.shape) > 2) or (source_coords.shape[1] < 2) or (source_coords.shape[1] > 3): + raise ValueError( + "Invalid coordinate shape. Expected 2D or 3D coordinates of shape (N, 2) or (N, 3). " + f"Got '{source_coords.shape}'" + ) + allowed_resampling_strs = ["nearest", "linear", "cubic"] + if resampling not in allowed_resampling_strs: + raise ValueError(f"Resampling type '{resampling}' not understood. Choices: {allowed_resampling_strs}") + + dem_arr, dem_mask = get_array_and_mask(dem) + + bounds = _bounds(transform=transform, shape=dem_arr.shape) + + no_horizontal = np.sum(np.linalg.norm(destination_coords[:, :2] - source_coords[:, :2], axis=1)) < 1e-6 + no_vertical = source_coords.shape[1] > 2 and np.sum(np.abs(destination_coords[:, 2] - source_coords[:, 2])) < 1e-6 + + if no_horizontal and no_vertical: + warnings.warn("No difference between source and destination coordinates. 
Returning self.") + return dem + + source_coords_scaled = source_coords.copy() + destination_coords_scaled = destination_coords.copy() + # Scale the coordinates to index-space + for coords in (source_coords_scaled, destination_coords_scaled): + coords[:, 0] = dem_arr.shape[1] * (coords[:, 0] - bounds.left) / (bounds.right - bounds.left) + coords[:, 1] = dem_arr.shape[0] * (1 - (coords[:, 1] - bounds.bottom) / (bounds.top - bounds.bottom)) + + # Generate a grid of x and y index coordinates. + grid_y, grid_x = np.mgrid[0 : dem_arr.shape[0], 0 : dem_arr.shape[1]] + + if no_horizontal: + warped = dem_arr.copy() + else: + # Interpolate the sparse source-destination points to a grid. + # (row, col, 0) represents the destination y-coordinates of the pixels. + # (row, col, 1) represents the destination x-coordinates of the pixels. + new_indices = scipy.interpolate.griddata( + source_coords_scaled[:, [1, 0]], + destination_coords_scaled[:, [1, 0]], # Coordinates should be in y/x (not x/y) for some reason.. + (grid_y, grid_x), + method="linear", + ) + + # If the border should not be trimmed, just assign the original indices to the missing values. + if not trim_border: + missing_ys = np.isnan(new_indices[:, :, 0]) + missing_xs = np.isnan(new_indices[:, :, 1]) + new_indices[:, :, 0][missing_ys] = grid_y[missing_ys] + new_indices[:, :, 1][missing_xs] = grid_x[missing_xs] + + order = {"nearest": 0, "linear": 1, "cubic": 3} + + with warnings.catch_warnings(): + # A skimage warning that will hopefully be fixed soon. (2021-06-08) + warnings.filterwarnings("ignore", message="Passing `np.nan` to mean no clipping in np.clip") + warped = skimage.transform.warp( + image=np.where(dem_mask, np.nan, dem_arr), + inverse_map=np.moveaxis(new_indices, 2, 0), + output_shape=dem_arr.shape, + preserve_range=True, + order=order[resampling], + cval=np.nan, + ) + new_mask = ( + skimage.transform.warp( + image=dem_mask, inverse_map=np.moveaxis(new_indices, 2, 0), output_shape=dem_arr.shape, cval=False + ) + > 0 + ) + + if dilate_mask: + new_mask = scipy.ndimage.binary_dilation(new_mask, iterations=order[resampling]).astype(new_mask.dtype) + + warped[new_mask] = np.nan + + # Apply the Z-correction if apply_z_correction is True and if the coordinates are 3D (N, 3) + if not no_vertical and apply_z_correction: + grid_offsets = scipy.interpolate.griddata( + points=destination_coords_scaled[:, :2], + values=source_coords_scaled[:, 2] - destination_coords_scaled[:, 2], + xi=(grid_x, grid_y), + method=resampling, + fill_value=np.nan, + ) + if not trim_border: + grid_offsets[np.isnan(grid_offsets)] = np.nanmean(grid_offsets) + + warped += grid_offsets + + assert not np.all(np.isnan(warped)), "All-NaN output." 

From e698260a5e54cfbb0a2201eece7e880fa53675f5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alice=20de=20Bardonn=C3=A8che-Richard?=
Date: Wed, 5 Mar 2025 09:15:04 +0100
Subject: [PATCH 2/2] test: fix forgotten warp_dem references

---
 tests/test_coreg/test_blockwise.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/test_coreg/test_blockwise.py b/tests/test_coreg/test_blockwise.py
index a8ad94f0..116cd813 100644
--- a/tests/test_coreg/test_blockwise.py
+++ b/tests/test_coreg/test_blockwise.py
@@ -184,7 +184,7 @@ def test_warp_dem() -> None:
     dest_coords = source_coords.copy()
     dest_coords[0, 0] = -1e-5
 
-    warped_dem = coreg.base.warp_dem(
+    warped_dem = coreg.blockwise.warp_dem(
         dem=small_dem,
         transform=small_transform,
         source_coords=source_coords,
@@ -196,7 +196,7 @@ def test_warp_dem() -> None:
     elev_shift = 5.0
     dest_coords[1, 2] = elev_shift
 
-    warped_dem = coreg.base.warp_dem(
+    warped_dem = coreg.blockwise.warp_dem(
         dem=small_dem,
         transform=small_transform,
         source_coords=source_coords,
@@ -243,12 +243,12 @@ def test_warp_dem() -> None:
     dem = misc.generate_random_field(shape, 100) * 200 + misc.generate_random_field(shape, 10) * 50
 
     # Warp the DEM using the source-destination coordinates.
-    transformed_dem = coreg.base.warp_dem(
+    transformed_dem = coreg.blockwise.warp_dem(
         dem=dem, transform=transform, source_coords=source_coords, destination_coords=dest_coords, resampling="linear"
     )
 
     # Try to undo the warp by reversing the source-destination coordinates.
-    untransformed_dem = coreg.base.warp_dem(
+    untransformed_dem = coreg.blockwise.warp_dem(
         dem=transformed_dem,
         transform=transform,
         source_coords=dest_coords,
@@ -260,7 +260,7 @@ def test_warp_dem() -> None:
     assert spatialstats.nmad(dem - untransformed_dem) < 0.5
 
     # Test with Z-correction disabled
-    transformed_dem_no_z = coreg.base.warp_dem(
+    transformed_dem_no_z = coreg.blockwise.warp_dem(
         dem=dem,
         transform=transform,
         source_coords=source_coords,
@@ -270,7 +270,7 @@ def test_warp_dem() -> None:
     # Try to undo the warp by reversing the source-destination coordinates with Z-correction disabled
-    untransformed_dem_no_z = coreg.base.warp_dem(
+    untransformed_dem_no_z = coreg.blockwise.warp_dem(
         dem=transformed_dem_no_z,
         transform=transform,
         source_coords=dest_coords,