From e2ca5afaeb04f7f3a07a4cd3f70bfc538300936e Mon Sep 17 00:00:00 2001 From: "Alan D. Snow" Date: Fri, 8 Nov 2024 09:00:29 -0600 Subject: [PATCH] BUG:merge: Ensure dims and coords match input array (#828) --- docs/history.rst | 1 + rioxarray/merge.py | 52 +++++++++++++--------- test/integration/test_integration_merge.py | 33 ++++++++++---- 3 files changed, 57 insertions(+), 29 deletions(-) diff --git a/docs/history.rst b/docs/history.rst index de8d460c..1a565702 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -4,6 +4,7 @@ History Latest ------- - DEP: Pin rasterio>=1.3.7 (pull #826) +- BUG:merge: Ensure dims and coords match input array (pull #828) 0.18.0 ------ diff --git a/rioxarray/merge.py b/rioxarray/merge.py index 25cfbf4e..5020503e 100644 --- a/rioxarray/merge.py +++ b/rioxarray/merge.py @@ -12,7 +12,7 @@ from xarray import DataArray, Dataset from rioxarray._io import open_rasterio -from rioxarray.rioxarray import _get_nonspatial_coords, _make_coords +from rioxarray.rioxarray import _get_nonspatial_coords class RasterioDatasetDuck: @@ -167,16 +167,35 @@ def merge_arrays( memfile.name, parse_coordinates=parse_coordinates, mask_and_scale=rioduckarrays[0]._mask_and_scale, - ) as xda: - xda = xda.load() - xda.coords.update( - { - coord: value - for coord, value in _get_nonspatial_coords(representative_array).items() - if coord not in xda.coords - } - ) - return xda # type: ignore + ) as merged_data: + merged_data = merged_data.load() + + # make sure old & new coordinate names match & dimensions are correct + rename_map = {} + original_extra_dim = representative_array.rio._check_dimensions() + new_extra_dim = merged_data.rio._check_dimensions() + # make sure the output merged data shape is 2D if the + # original data was 2D. this can happen if the + # xarray DataArray was squeezed. 
+ if len(merged_data.shape) == 3 and len(representative_array.shape) == 2: + merged_data = merged_data.squeeze( + dim=new_extra_dim, drop=original_extra_dim is None + ) + new_extra_dim = merged_data.rio._check_dimensions() + if ( + original_extra_dim is not None + and new_extra_dim is not None + and original_extra_dim != new_extra_dim + ): + rename_map[new_extra_dim] = original_extra_dim + if representative_array.rio.x_dim != merged_data.rio.x_dim: + rename_map[merged_data.rio.x_dim] = representative_array.rio.x_dim + if representative_array.rio.y_dim != merged_data.rio.y_dim: + rename_map[merged_data.rio.y_dim] = representative_array.rio.y_dim + if rename_map: + merged_data = merged_data.rename(rename_map) + merged_data.coords.update(_get_nonspatial_coords(representative_array)) + return merged_data # type: ignore def merge_datasets( @@ -227,7 +246,7 @@ def merge_datasets( representative_ds = datasets[0] merged_data = {} - for data_var in representative_ds.data_vars: + for iii, data_var in enumerate(representative_ds.data_vars): merged_data[data_var] = merge_arrays( [dataset[data_var] for dataset in datasets], bounds=bounds, @@ -236,18 +255,11 @@ def merge_datasets( precision=precision, method=method, crs=crs, - parse_coordinates=False, + parse_coordinates=iii == 0, ) data_var = list(representative_ds.data_vars)[0] xds = Dataset( merged_data, - coords=_make_coords( - src_data_array=merged_data[data_var], - dst_affine=merged_data[data_var].rio.transform(), - dst_width=merged_data[data_var].shape[-1], - dst_height=merged_data[data_var].shape[-2], - force_generate=True, - ), attrs=representative_ds.attrs, ) xds.rio.write_crs(merged_data[data_var].rio.crs, inplace=True) diff --git a/test/integration/test_integration_merge.py b/test/integration/test_integration_merge.py index 53100ef1..e135e386 100644 --- a/test/integration/test_integration_merge.py +++ b/test/integration/test_integration_merge.py @@ -1,6 +1,7 @@ import os import pytest +import xarray from numpy import 
nansum from numpy.testing import assert_almost_equal @@ -48,8 +49,8 @@ def test_merge_arrays(squeeze): ), ) assert merged.rio._cached_transform() == merged.rio.transform() + assert sorted(merged.coords) == sorted(rds.coords) assert merged.coords["band"].values == [1] - assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"] assert merged.rio.crs == rds.rio.crs assert merged.attrs == { "AREA_OR_POINT": "Area", @@ -63,10 +64,12 @@ def test_merge_arrays(squeeze): @pytest.mark.parametrize("dataset", [True, False]) def test_merge__different_crs(dataset): dem_test = os.path.join(TEST_INPUT_DATA_DIR, "MODIS_ARRAY.nc") - with open_rasterio(dem_test) as rds: + with ( + xarray.open_dataset(dem_test, mask_and_scale=False) + if dataset + else xarray.open_dataarray(dem_test, mask_and_scale=False) + ) as rds: crs = rds.rio.crs - if dataset: - rds = rds.to_dataset() arrays = [ rds.isel(x=slice(100), y=slice(100)).rio.reproject("EPSG:3857"), rds.isel(x=slice(100, 200), y=slice(100, 200)), @@ -106,8 +109,7 @@ def test_merge__different_crs(dataset): 1.0, ), ) - assert merged.coords["band"].values == [1] - assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"] + assert sorted(merged.coords) == sorted(list(rds.coords) + ["spatial_ref"]) assert merged.rio.crs == rds.rio.crs if not dataset: assert merged.attrs == { @@ -151,8 +153,8 @@ def test_merge_arrays__res(): ), ) assert merged.rio._cached_transform() == merged.rio.transform() + assert sorted(merged.coords) == sorted(rds.coords) assert merged.coords["band"].values == [1] - assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"] assert merged.rio.crs == rds.rio.crs assert_almost_equal(merged.rio.nodata, rds.rio.nodata) assert_almost_equal(merged.rio.encoded_nodata, rds.rio.encoded_nodata) @@ -207,8 +209,8 @@ def test_merge_datasets(): 1.0, ), ) + assert sorted(merged.coords) == sorted(rds.coords) assert merged.coords["band"].values == [1] - assert sorted(merged.coords) == ["band", "spatial_ref", "x", 
"y"] assert merged.rio.crs == rds.rio.crs assert merged.attrs == rds.attrs assert merged.encoding["grid_mapping"] == "spatial_ref" @@ -263,7 +265,7 @@ def test_merge_datasets__res(): ) assert merged.rio.shape == (1112, 1112) assert merged.coords["band"].values == [1] - assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"] + assert sorted(merged.coords) == sorted(rds.coords) assert merged.rio.crs == rds.rio.crs assert merged.attrs == rds.attrs assert merged.encoding["grid_mapping"] == "spatial_ref" @@ -282,8 +284,21 @@ def test_merge_datasets__mask_and_scale(mask_and_scale): rds.isel(x=slice(100, None), y=slice(100)), ] merged = merge_datasets(datasets) + assert sorted(merged.coords) == sorted(list(rds.coords) + ["spatial_ref"]) total = merged.air_temperature.sum() if mask_and_scale: assert_almost_equal(total, 133376696) else: assert_almost_equal(total, 10981781386) + + +def test_merge_datasets__preserve_dimension_names(): + sentinel_2_geographic = os.path.join( + TEST_INPUT_DATA_DIR, "sentinel_2_L1C_geographic.nc" + ) + with xarray.open_dataset(sentinel_2_geographic) as mda: + merged = merge_datasets([mda]) + assert sorted(merged.coords) == sorted(mda.coords) + for data_var in mda.data_vars: + assert_almost_equal(merged[data_var].sum(), mda[data_var].sum()) + assert merged.rio.crs == mda.rio.crs