Skip to content

Commit

Permalink
Add skipna arg to all temporal APIs
Browse files Browse the repository at this point in the history
- Add unit tests
  • Loading branch information
tomvothecoder committed Nov 6, 2024
1 parent f74699c commit 9c4aa48
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 12 deletions.
22 changes: 22 additions & 0 deletions tests/test_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ def test_spatial_average_for_lat_region(self):

assert result.identical(expected)

def test_spatial_average_for_lat_region_and_skipna(self):
    """Average ``ts`` over a latitude band with ``skipna=True``.

    The first time step of ``ts`` is set to NaN before averaging. The
    expected result keeps NaN for that time step and 1.0 for the others.
    """
    ds = self.ds.copy(deep=True)
    ds.ts[0] = np.nan

    # The axis is passed as a list of str (``["Y"]``), not as a bare str.
    result = ds.spatial.average(
        "ts", axis=["Y"], lat_bounds=(-5.0, 5), skipna=True
    )

    expected = self.ds.copy()
    expected["ts"] = xr.DataArray(
        data=np.array(
            [
                [np.nan, np.nan, np.nan, np.nan],
                [1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0],
            ]
        ),
        coords={"time": expected.time, "lon": expected.lon},
        dims=["time", "lon"],
    )

    assert result.identical(expected)

def test_spatial_average_for_domain_wrapping_p_meridian_non_cf_conventions(
self,
):
Expand Down
169 changes: 169 additions & 0 deletions tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,57 @@ def test_weighted_annual_averages(self):
assert result.ts.attrs == expected.ts.attrs
assert result.time.attrs == expected.time.attrs

def test_weighted_annual_averages_and_skipna(self):
    """Check weighted yearly group averages computed with ``skipna=True``."""
    ds = self.ds.copy(deep=True)
    # Set the first time step to NaN before averaging with skipna=True.
    ds.ts[0] = np.nan

    result = ds.temporal.group_average("ts", "year", skipna=True)

    # One timestamp per year group in the expected output.
    year_starts = [
        cftime.DatetimeGregorian(2000, 1, 1),
        cftime.DatetimeGregorian(2001, 1, 1),
    ]
    time_coord = xr.DataArray(
        data=np.array(year_starts),
        coords={"time": np.array(year_starts)},
        dims=["time"],
        attrs={
            "axis": "T",
            "long_name": "time",
            "standard_name": "time",
            "bounds": "time_bnds",
        },
    )
    ts_attrs = {
        "test_attr": "test",
        "operation": "temporal_avg",
        "mode": "group_average",
        "freq": "year",
        "weighted": "True",
    }

    expected = ds.copy().drop_dims("time")
    expected["ts"] = xr.DataArray(
        name="ts",
        data=np.array([[[1]], [[2.0]]]),
        coords={
            "lat": expected.lat,
            "lon": expected.lon,
            "time": time_coord,
        },
        dims=["time", "lat", "lon"],
        attrs=ts_attrs,
    )

    xr.testing.assert_allclose(result, expected)
    assert result.ts.attrs == expected.ts.attrs
    assert result.time.attrs == expected.time.attrs

@requires_dask
def test_weighted_annual_averages_with_chunking(self):
ds = self.ds.copy().chunk({"time": 2})
Expand Down Expand Up @@ -1161,6 +1212,68 @@ def test_weighted_seasonal_climatology_with_DJF(self):

xr.testing.assert_identical(result, expected)

def test_weighted_seasonal_climatology_with_DJF_and_skipna(self):
    """Check the DJF-mode seasonal climatology when MAM is entirely NaN.

    All March, April and May values of ``ts`` are masked before the
    climatology is computed with ``skipna=True``. The expected result is
    1.0 for every season except MAM, which is expected to be NaN.
    """
    ds = self.ds.copy(deep=True)

    # Replace all MAM values with np.nan in a single masking pass.
    mam_months = [3, 4, 5]
    ds["ts"] = ds.ts.where(~ds.ts.time.dt.month.isin(mam_months), np.nan)

    result = ds.temporal.climatology(
        "ts",
        "season",
        season_config={"dec_mode": "DJF", "drop_incomplete_djf": True},
        skipna=True,
    )

    # One timestamp per season in the expected climatology.
    season_starts = [
        cftime.DatetimeGregorian(1, 1, 1),
        cftime.DatetimeGregorian(1, 4, 1),
        cftime.DatetimeGregorian(1, 7, 1),
        cftime.DatetimeGregorian(1, 10, 1),
    ]

    expected = ds.copy()
    expected = expected.drop_dims("time")
    expected_time = xr.DataArray(
        data=np.array(season_starts),
        coords={"time": np.array(season_starts)},
        attrs={
            "axis": "T",
            "long_name": "time",
            "standard_name": "time",
            "bounds": "time_bnds",
        },
    )
    expected["ts"] = xr.DataArray(
        name="ts",
        data=np.ones((4, 4, 4)),
        coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time},
        dims=["time", "lat", "lon"],
        attrs={
            "operation": "temporal_avg",
            "mode": "climatology",
            "freq": "season",
            "weighted": "True",
            "dec_mode": "DJF",
            "drop_incomplete_djf": "True",
        },
    )
    # MAM (index 1 on the time axis, the April timestamp) should be np.nan.
    expected.ts[1] = np.nan

    # assert_identical reports a readable diff on failure, unlike a bare
    # ``assert result.identical(expected)``.
    xr.testing.assert_identical(result, expected)

@requires_dask
def test_chunked_weighted_seasonal_climatology_with_DJF(self):
ds = self.ds.copy().chunk({"time": 2})
Expand Down Expand Up @@ -1947,6 +2060,62 @@ def test_weighted_seasonal_departures_with_DJF(self):

xr.testing.assert_identical(result, expected)

def test_weighted_seasonal_departures_with_DJF_and_skipna(self):
    """Check DJF-mode seasonal departures when MAM is entirely NaN.

    All March, April and May values of ``ts`` are masked before the
    departures are computed with ``skipna=True``. The expected departures
    are NaN for the MAM timestamp and 0.0 for the other seasons.
    """
    ds = self.ds.copy(deep=True)

    # Replace all MAM values with np.nan in a single masking pass.
    mam_months = [3, 4, 5]
    ds["ts"] = ds.ts.where(~ds.ts.time.dt.month.isin(mam_months), np.nan)

    result = ds.temporal.departures(
        "ts",
        "season",
        weighted=True,
        season_config={"dec_mode": "DJF", "drop_incomplete_djf": True},
        skipna=True,
    )

    expected = ds.copy()
    expected = expected.drop_dims("time")
    expected["ts"] = xr.DataArray(
        name="ts",
        # The first entry lines up with the 2000-04-01 (MAM) timestamp.
        data=np.array([[[np.nan]], [[0.0]], [[0.0]], [[0.0]]]),
        coords={
            "lat": expected.lat,
            "lon": expected.lon,
            "time": xr.DataArray(
                data=np.array(
                    [
                        cftime.DatetimeGregorian(2000, 4, 1),
                        cftime.DatetimeGregorian(2000, 7, 1),
                        cftime.DatetimeGregorian(2000, 10, 1),
                        cftime.DatetimeGregorian(2001, 1, 1),
                    ],
                ),
                dims=["time"],
                attrs={
                    "axis": "T",
                    "long_name": "time",
                    "standard_name": "time",
                    "bounds": "time_bnds",
                },
            ),
        },
        dims=["time", "lat", "lon"],
        attrs={
            "test_attr": "test",
            "operation": "temporal_avg",
            "mode": "departures",
            "freq": "season",
            "weighted": "True",
            "dec_mode": "DJF",
            "drop_incomplete_djf": "True",
        },
    )

    # assert_identical reports a readable diff on failure, unlike a bare
    # ``assert result.identical(expected)``.
    xr.testing.assert_identical(result, expected)

def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self):
ds = self.ds.copy()

Expand Down
41 changes: 29 additions & 12 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def average(
data_var: str,
weighted: bool = True,
keep_weights: bool = False,
skipna: Union[bool, None] = None,
skipna: bool | None = None,
):
"""
Returns a Dataset with the average of a data variable and the time
Expand Down Expand Up @@ -202,7 +202,7 @@ def average(
keep_weights : bool, optional
If calculating averages using weights, keep the weights in the
final dataset output, by default False.
skipna : bool or None, optional
skipna : bool | None, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
Expand Down Expand Up @@ -257,6 +257,7 @@ def group_average(
weighted: bool = True,
keep_weights: bool = False,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
skipna: bool | None = None,
):
"""Returns a Dataset with average of a data variable by time group.
Expand Down Expand Up @@ -335,6 +336,11 @@ def group_average(
>>> ["Jul", "Aug", "Sep"], # "JulAugSep"
>>> ["Oct", "Nov", "Dec"], # "OctNovDec"
>>> ]
skipna : bool | None, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
implemented (object, datetime64 or timedelta64).
Returns
-------
Expand Down Expand Up @@ -413,6 +419,7 @@ def group_average(
weighted=weighted,
keep_weights=keep_weights,
season_config=season_config,
skipna=skipna,
)

def climatology(
Expand All @@ -423,6 +430,7 @@ def climatology(
keep_weights: bool = False,
reference_period: Optional[Tuple[str, str]] = None,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
skipna: bool | None = None,
):
"""Returns a Dataset with the climatology of a data variable.
Expand Down Expand Up @@ -510,6 +518,11 @@ def climatology(
>>> ["Jul", "Aug", "Sep"], # "JulAugSep"
>>> ["Oct", "Nov", "Dec"], # "OctNovDec"
>>> ]
skipna : bool | None, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
implemented (object, datetime64 or timedelta64).
Returns
-------
Expand Down Expand Up @@ -593,6 +606,7 @@ def climatology(
keep_weights,
reference_period,
season_config,
skipna,
)

def departures(
Expand All @@ -603,6 +617,7 @@ def departures(
keep_weights: bool = False,
reference_period: Optional[Tuple[str, str]] = None,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
skipna: bool | None = None,
) -> xr.Dataset:
"""
Returns a Dataset with the climatological departures (anomalies) for a
Expand Down Expand Up @@ -697,6 +712,11 @@ def departures(
>>> ["Jul", "Aug", "Sep"], # "JulAugSep"
>>> ["Oct", "Nov", "Dec"], # "OctNovDec"
>>> ]
skipna : bool | None, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
implemented (object, datetime64 or timedelta64).
Returns
-------
Expand Down Expand Up @@ -777,11 +797,7 @@ def departures(
inferred_freq = _infer_freq(ds[self.dim])
if inferred_freq != freq:
ds_obs = ds_obs.temporal.group_average(
data_var,
freq,
weighted,
keep_weights,
season_config,
data_var, freq, weighted, keep_weights, season_config, skipna
)

# 4. Calculate the climatology of the data variable.
Expand All @@ -794,6 +810,7 @@ def departures(
keep_weights,
reference_period,
season_config,
skipna,
)

# 5. Calculate the departures for the data variable.
Expand All @@ -815,7 +832,7 @@ def _averager(
keep_weights: bool = False,
reference_period: Optional[Tuple[str, str]] = None,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
skipna: Union[bool, None] = None,
skipna: bool | None = None,
) -> xr.Dataset:
"""Averages a data variable based on the averaging mode and frequency."""
ds = self._dataset.copy()
Expand Down Expand Up @@ -1141,7 +1158,7 @@ def _drop_leap_days(self, ds: xr.Dataset):
return ds

def _average(
self, ds: xr.Dataset, data_var: str, skipna: Union[bool, None] = None
self, ds: xr.Dataset, data_var: str, skipna: bool | None = None
) -> xr.DataArray:
"""Averages a data variable with the time dimension removed.
Expand All @@ -1151,7 +1168,7 @@ def _average(
The dataset.
data_var : str
The key of the data variable.
skipna : bool or None, optional
skipna : bool | None, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
Expand All @@ -1178,7 +1195,7 @@ def _average(
return dv

def _group_average(
self, ds: xr.Dataset, data_var: str, skipna: Union[bool, None] = None
self, ds: xr.Dataset, data_var: str, skipna: bool | None = None
) -> xr.DataArray:
"""Averages a data variable by time group.
Expand All @@ -1188,7 +1205,7 @@ def _group_average(
The dataset.
data_var : str
The key of the data variable.
skipna : bool or None, optional
skipna : bool | None, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
Expand Down

0 comments on commit 9c4aa48

Please sign in to comment.