Skip to content

Commit

Permalink
BUG:merge: Ensure dims and coords match input array
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Nov 8, 2024
1 parent 87103e1 commit c85613e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 19 deletions.
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History
Latest
-------
- DEP: Pin rasterio>=1.3.7 (pull #826)
- BUG:merge: Ensure dims and coords match input array (pull #827)

0.18.0
------
Expand Down
40 changes: 30 additions & 10 deletions rioxarray/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,36 @@ def merge_arrays(
memfile.name,
parse_coordinates=parse_coordinates,
mask_and_scale=rioduckarrays[0]._mask_and_scale,
) as xda:
xda = xda.load()
xda.coords.update(
{
coord: value
for coord, value in _get_nonspatial_coords(representative_array).items()
if coord not in xda.coords
}
)
return xda # type: ignore
) as merged_data:
merged_data = merged_data.load()

# make sure old & new coorinate names match & dimensions are correct
rename_map = {}
original_extra_dim = representative_array.rio._check_dimensions()
new_extra_dim = merged_data.rio._check_dimensions()
# make sure the output merged data shape is 2D if the
# original data was 2D. this can happen if the
# xarray datasarray was squeezed.
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
merged_data = merged_data.squeeze(
dim=new_extra_dim, drop=original_extra_dim is None
)
new_extra_dim = merged_data.rio._check_dimensions()
if (
original_extra_dim is not None
and new_extra_dim is not None
and original_extra_dim != new_extra_dim
):
rename_map[new_extra_dim] = original_extra_dim
if representative_array.rio.x_dim != merged_data.rio.x_dim:
rename_map[merged_data.rio.x_dim] = representative_array.rio.x_dim
if representative_array.rio.y_dim != merged_data.rio.y_dim:
rename_map[merged_data.rio.y_dim] = representative_array.rio.y_dim
if rename_map:
merged_data = merged_data.rename(rename_map)

merged_data.coords.update(_get_nonspatial_coords(representative_array))
return merged_data # type: ignore


def merge_datasets(
Expand Down
21 changes: 12 additions & 9 deletions test/integration/test_integration_merge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import pytest
import xarray
from numpy import nansum
from numpy.testing import assert_almost_equal

Expand Down Expand Up @@ -48,8 +49,8 @@ def test_merge_arrays(squeeze):
),
)
assert merged.rio._cached_transform() == merged.rio.transform()
assert sorted(merged.coords) == sorted(rds.coords)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == {
"AREA_OR_POINT": "Area",
Expand All @@ -63,10 +64,12 @@ def test_merge_arrays(squeeze):
@pytest.mark.parametrize("dataset", [True, False])
def test_merge__different_crs(dataset):
dem_test = os.path.join(TEST_INPUT_DATA_DIR, "MODIS_ARRAY.nc")
with open_rasterio(dem_test) as rds:
with (
xarray.open_dataset(dem_test, mask_and_scale=False)
if dataset
else xarray.open_dataarray(dem_test, mask_and_scale=False)
) as rds:
crs = rds.rio.crs
if dataset:
rds = rds.to_dataset()
arrays = [
rds.isel(x=slice(100), y=slice(100)).rio.reproject("EPSG:3857"),
rds.isel(x=slice(100, 200), y=slice(100, 200)),
Expand Down Expand Up @@ -106,8 +109,7 @@ def test_merge__different_crs(dataset):
1.0,
),
)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert sorted(merged.coords) == sorted(list(rds.coords) + ["spatial_ref"])
assert merged.rio.crs == rds.rio.crs
if not dataset:
assert merged.attrs == {
Expand Down Expand Up @@ -151,8 +153,8 @@ def test_merge_arrays__res():
),
)
assert merged.rio._cached_transform() == merged.rio.transform()
assert sorted(merged.coords) == sorted(rds.coords)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert_almost_equal(merged.rio.nodata, rds.rio.nodata)
assert_almost_equal(merged.rio.encoded_nodata, rds.rio.encoded_nodata)
Expand Down Expand Up @@ -207,8 +209,8 @@ def test_merge_datasets():
1.0,
),
)
assert sorted(merged.coords) == sorted(rds.coords)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == rds.attrs
assert merged.encoding["grid_mapping"] == "spatial_ref"
Expand Down Expand Up @@ -263,7 +265,7 @@ def test_merge_datasets__res():
)
assert merged.rio.shape == (1112, 1112)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert sorted(merged.coords) == sorted(rds.coords)
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == rds.attrs
assert merged.encoding["grid_mapping"] == "spatial_ref"
Expand All @@ -282,6 +284,7 @@ def test_merge_datasets__mask_and_scale(mask_and_scale):
rds.isel(x=slice(100, None), y=slice(100)),
]
merged = merge_datasets(datasets)
assert sorted(merged.coords) == sorted(list(rds.coords) + ["spatial_ref"])
total = merged.air_temperature.sum()
if mask_and_scale:
assert_almost_equal(total, 133376696)
Expand Down

0 comments on commit c85613e

Please sign in to comment.