Skip to content

Commit

Permalink
BUG:merge: Ensure dims and coords match input array
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Nov 8, 2024
1 parent 87103e1 commit 2c8ed76
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 29 deletions.
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History
Latest
-------
- DEP: Pin rasterio>=1.3.7 (pull #826)
- BUG:merge: Ensure dims and coords match input array (pull #828)

0.18.0
------
Expand Down
52 changes: 32 additions & 20 deletions rioxarray/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from xarray import DataArray, Dataset

from rioxarray._io import open_rasterio
from rioxarray.rioxarray import _get_nonspatial_coords, _make_coords
from rioxarray.rioxarray import _get_nonspatial_coords


class RasterioDatasetDuck:
Expand Down Expand Up @@ -167,16 +167,35 @@ def merge_arrays(
memfile.name,
parse_coordinates=parse_coordinates,
mask_and_scale=rioduckarrays[0]._mask_and_scale,
) as xda:
xda = xda.load()
xda.coords.update(
{
coord: value
for coord, value in _get_nonspatial_coords(representative_array).items()
if coord not in xda.coords
}
)
return xda # type: ignore
) as merged_data:
merged_data = merged_data.load()

# make sure old & new coorinate names match & dimensions are correct
rename_map = {}
original_extra_dim = representative_array.rio._check_dimensions()
new_extra_dim = merged_data.rio._check_dimensions()
# make sure the output merged data shape is 2D if the
# original data was 2D. this can happen if the
# xarray datasarray was squeezed.
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
merged_data = merged_data.squeeze(
dim=new_extra_dim, drop=original_extra_dim is None
)
new_extra_dim = merged_data.rio._check_dimensions()
if (
original_extra_dim is not None
and new_extra_dim is not None
and original_extra_dim != new_extra_dim
):
rename_map[new_extra_dim] = original_extra_dim
if representative_array.rio.x_dim != merged_data.rio.x_dim:
rename_map[merged_data.rio.x_dim] = representative_array.rio.x_dim
if representative_array.rio.y_dim != merged_data.rio.y_dim:
rename_map[merged_data.rio.y_dim] = representative_array.rio.y_dim
if rename_map:
merged_data = merged_data.rename(rename_map)
merged_data.coords.update(_get_nonspatial_coords(representative_array))
return merged_data # type: ignore


def merge_datasets(
Expand Down Expand Up @@ -227,7 +246,7 @@ def merge_datasets(

representative_ds = datasets[0]
merged_data = {}
for data_var in representative_ds.data_vars:
for iii, data_var in enumerate(representative_ds.data_vars):
merged_data[data_var] = merge_arrays(
[dataset[data_var] for dataset in datasets],
bounds=bounds,
Expand All @@ -236,18 +255,11 @@ def merge_datasets(
precision=precision,
method=method,
crs=crs,
parse_coordinates=False,
parse_coordinates=iii == 0,
)
data_var = list(representative_ds.data_vars)[0]
xds = Dataset(
merged_data,
coords=_make_coords(
src_data_array=merged_data[data_var],
dst_affine=merged_data[data_var].rio.transform(),
dst_width=merged_data[data_var].shape[-1],
dst_height=merged_data[data_var].shape[-2],
force_generate=True,
),
attrs=representative_ds.attrs,
)
xds.rio.write_crs(merged_data[data_var].rio.crs, inplace=True)
Expand Down
33 changes: 24 additions & 9 deletions test/integration/test_integration_merge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import pytest
import xarray
from numpy import nansum
from numpy.testing import assert_almost_equal

Expand Down Expand Up @@ -48,8 +49,8 @@ def test_merge_arrays(squeeze):
),
)
assert merged.rio._cached_transform() == merged.rio.transform()
assert sorted(merged.coords) == sorted(rds.coords)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == {
"AREA_OR_POINT": "Area",
Expand All @@ -63,10 +64,12 @@ def test_merge_arrays(squeeze):
@pytest.mark.parametrize("dataset", [True, False])
def test_merge__different_crs(dataset):
dem_test = os.path.join(TEST_INPUT_DATA_DIR, "MODIS_ARRAY.nc")
with open_rasterio(dem_test) as rds:
with (
xarray.open_dataset(dem_test, mask_and_scale=False)
if dataset
else xarray.open_dataarray(dem_test, mask_and_scale=False)
) as rds:
crs = rds.rio.crs
if dataset:
rds = rds.to_dataset()
arrays = [
rds.isel(x=slice(100), y=slice(100)).rio.reproject("EPSG:3857"),
rds.isel(x=slice(100, 200), y=slice(100, 200)),
Expand Down Expand Up @@ -106,8 +109,7 @@ def test_merge__different_crs(dataset):
1.0,
),
)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert sorted(merged.coords) == sorted(list(rds.coords) + ["spatial_ref"])
assert merged.rio.crs == rds.rio.crs
if not dataset:
assert merged.attrs == {
Expand Down Expand Up @@ -151,8 +153,8 @@ def test_merge_arrays__res():
),
)
assert merged.rio._cached_transform() == merged.rio.transform()
assert sorted(merged.coords) == sorted(rds.coords)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert_almost_equal(merged.rio.nodata, rds.rio.nodata)
assert_almost_equal(merged.rio.encoded_nodata, rds.rio.encoded_nodata)
Expand Down Expand Up @@ -207,8 +209,8 @@ def test_merge_datasets():
1.0,
),
)
assert sorted(merged.coords) == sorted(rds.coords)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == rds.attrs
assert merged.encoding["grid_mapping"] == "spatial_ref"
Expand Down Expand Up @@ -263,7 +265,7 @@ def test_merge_datasets__res():
)
assert merged.rio.shape == (1112, 1112)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert sorted(merged.coords) == sorted(rds.coords)
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == rds.attrs
assert merged.encoding["grid_mapping"] == "spatial_ref"
Expand All @@ -282,8 +284,21 @@ def test_merge_datasets__mask_and_scale(mask_and_scale):
rds.isel(x=slice(100, None), y=slice(100)),
]
merged = merge_datasets(datasets)
assert sorted(merged.coords) == sorted(list(rds.coords) + ["spatial_ref"])
total = merged.air_temperature.sum()
if mask_and_scale:
assert_almost_equal(total, 133376696)
else:
assert_almost_equal(total, 10981781386)


def test_merge_datasets__preserve_dimension_names():
sentinel_2_geographic = os.path.join(
TEST_INPUT_DATA_DIR, "sentinel_2_L1C_geographic.nc"
)
with xarray.open_dataset(sentinel_2_geographic) as mda:
merged = merge_datasets([mda])
assert sorted(merged.coords) == sorted(mda.coords)
for data_var in mda.data_vars:
assert_almost_equal(merged[data_var].sum(), mda[data_var].sum())
assert merged.rio.crs == mda.rio.crs

0 comments on commit 2c8ed76

Please sign in to comment.