Skip to content

Commit

Permalink
Review netcdf to_xarray args (#369)
Browse files Browse the repository at this point in the history
* Review netcdf to_xarray args
  • Loading branch information
sandorkertesz authored Apr 18, 2024
1 parent 0212f7f commit 99eace2
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 20 deletions.
23 changes: 23 additions & 0 deletions docs/guide/sources.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ We can get data from a given source by using :func:`from_source`:
- read data from a stream
* - :ref:`data-sources-memory`
- read data from a memory buffer
* - :ref:`data-sources-multi`
- read data from multiple sources
* - :ref:`data-sources-ads`
- retrieve data from the `Copernicus Atmosphere Data Store <https://ads.atmosphere.copernicus.eu/>`_ (ADS)
* - :ref:`data-sources-cds`
Expand Down Expand Up @@ -418,6 +420,27 @@ memory
print(f.metadata("param"))
.. _data-sources-multi:

multi
--------------

.. py:function:: from_source("multi", *sources, merger=None, **kwargs)
:noindex:

The ``multi`` source reads multiple sources.

:param tuple *sources: the sources
:param merger: if it is None, an attempt is made to merge/concatenate the sources by their classes (using the nearest common class). Otherwise the sources are merged/concatenated using the merger in a lazy way. The merger can be one of the following:
- class/object implementing the :func:`to_xarray` or :func:`to_pandas` methods
- callable
- str, describing a call either to "concat" or "merge". E.g.: "concat(concat_dim=time)"
- tuple with 2 elements. The first element is a str, either "concat" or "merge", and the second element is a dict with the keyword arguments for the call. E.g.: ("concat", {"concat_dim": "time"})
:param dict **kwargs: other keyword arguments


.. _data-sources-ads:

ads
Expand Down
28 changes: 28 additions & 0 deletions docs/howtos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,31 @@ save the results of a :ref:`MARS retrieval <data-sources-mars>` into a file:
)
ds.save("my_data.grib")
How to call to_xarray() with arguments for NetCDF data?
---------------------------------------------------------

When :func:`to_xarray` is called on NetCDF data, it calls ``xarray.open_mfdataset`` internally. You can pass arguments to this xarray function by using the ``xarray_open_mfdataset_kwargs`` option. For example:


.. code-block:: python
import earthkit.data
req = {
"format": "zip",
"origin": "c3s",
"sensor": "olci",
"version": "1_1",
"year": "2022",
"month": "04",
"nominal_day": "01",
"variable": "pixel_variables",
"region": "europe",
}
ds = earthkit.data.from_source("cds", "satellite-fire-burned-area", req)
r = ds.to_xarray(
xarray_open_mfdataset_kwargs=dict(decode_cf=False, decode_times=False)
)
16 changes: 1 addition & 15 deletions earthkit/data/mergers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

LOG = logging.getLogger(__name__)

FORWARDS = ("to_xarray", "to_pandas", "to_tfdataset")
FORWARDS = ("to_xarray", "to_pandas")


def _nearest_common_class(objects):
Expand Down Expand Up @@ -86,16 +86,6 @@ def to_pandas(self, **kwargs):
**kwargs,
)

def to_tfdataset(self, **kwargs):
from .tfdataset import merge

return merge(
sources=self.sources,
paths=self.paths,
reader_class=self.reader_class,
**kwargs,
)

def to_xarray(self, **kwargs):
from .xarray import merge

Expand All @@ -118,9 +108,6 @@ def to_xarray(self, *args, **kwargs):
def to_pandas(self, *args, **kwargs):
return self.obj.to_pandas(self.paths_or_sources, **kwargs)

def to_tfdataset(self, *args, **kwargs):
return self.obj.to_tfdataset(self.paths_or_sources, **kwargs)


class CallableMerger(Merger):
def __init__(self, func, sources, *args, **kwargs):
Expand All @@ -132,7 +119,6 @@ def _call_func(self, *args, **kwargs):

to_xarray = _call_func
to_pandas = _call_func
to_tfdataset = _call_func


class XarrayGenericMerger(Merger):
Expand Down
2 changes: 2 additions & 0 deletions earthkit/data/readers/netcdf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def to_xarray_multi_from_paths(cls, paths, **kwargs):

options = dict()
options.update(kwargs.get("xarray_open_mfdataset_kwargs", {}))
if not options:
options = dict(**kwargs)

return xr.open_mfdataset(
paths,
Expand Down
10 changes: 5 additions & 5 deletions earthkit/data/readers/netcdf/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def to_xarray_multi_from_paths(cls, paths, **kwargs):

options = dict()
options.update(kwargs.get("xarray_open_mfdataset_kwargs", {}))
if not options:
options = dict(**kwargs)

return xr.open_mfdataset(
paths,
Expand Down Expand Up @@ -248,7 +250,7 @@ def __init__(self, *args, **kwargs):
def to_xarray(self, **kwargs):
import xarray as xr

return xr.merge([x.ds for x in self._indexes], **kwargs)
return xr.merge([x._ds for x in self._indexes], **kwargs)


class NetCDFFieldList(XArrayFieldListCore):
Expand All @@ -268,10 +270,8 @@ def new_mask_index(cls, *args, **kwargs):
return NetCDFMaskFieldList(*args, **kwargs)

def to_xarray(self, **kwargs):
import xarray as xr

if self.path.startswith("http"):
return xr.open_dataset(self.path, **kwargs)
# if self.path.startswith("http"):
# return xr.open_dataset(self.path, **kwargs)
return type(self).to_xarray_multi_from_paths([self.path], **kwargs)

def write(self, *args, **kwargs):
Expand Down
49 changes: 49 additions & 0 deletions tests/netcdf/test_netcdf_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,55 @@ def test_netcdf_concat(mode):
],
}

import xarray as xr

target = xr.merge([ds1.to_xarray(), ds2.to_xarray()])
merged = ds.to_xarray()
assert target.identical(merged)


def test_netcdf_read_multiple_files():
    """Load two ERA5 2t NetCDF files through a single "file" source.

    Verifies the number of fields, the ``variable`` metadata, the per-field
    and whole-list datetimes, and that ``to_xarray()`` is identical to
    merging the two files opened individually with xarray.
    """
    file_names = ("era5_2t_1.nc", "era5_2t_2.nc")
    ds = from_source(
        "file",
        [earthkit_test_data_file(name) for name in file_names],
    )

    assert len(ds) == 2
    assert ds.metadata("variable") == ["t2m", "t2m"]

    # One field per file; each field's base and valid time coincide.
    first_time = datetime.datetime(2021, 3, 1, 12, 0)
    second_time = datetime.datetime(2021, 3, 2, 12, 0)

    for index, stamp in enumerate((first_time, second_time)):
        assert ds[index].datetime() == {
            "base_time": stamp,
            "valid_time": stamp,
        }

    assert ds.datetime() == {
        "base_time": [first_time, second_time],
        "valid_time": [first_time, second_time],
    }

    import xarray as xr

    # Reference dataset: open each file on its own, then merge.
    opened = [xr.open_dataset(earthkit_test_data_file(name)) for name in file_names]
    target = xr.merge(opened)
    merged = ds.to_xarray()
    assert target.identical(merged)


@pytest.mark.parametrize("custom_merger", (merger_func, Merger_obj()))
def test_netdcf_merge_custom(custom_merger):
Expand Down
66 changes: 66 additions & 0 deletions tests/netcdf/test_netcdf_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/env python3

# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import numpy as np
import pytest

from earthkit.data import from_source
from earthkit.data.testing import earthkit_remote_test_data_file


@pytest.mark.long_test
@pytest.mark.download
def test_netcdf_to_xarray_args():
    """Check that ``to_xarray()`` honours xarray options for NetCDF data.

    The same remote file is converted three times: with the options wrapped
    in ``xarray_open_mfdataset_kwargs``, with the same options passed
    directly as keyword arguments, and with no options at all.
    """
    # The JD variable in the NetCDF is defined as follows:
    #
    #   short JD(time, lat, lon) ;
    #       string JD:long_name = "Date of the first detection" ;
    #       string JD:units = "days since 2022-01-01" ;
    #       string JD:comment = "Possible values: 0 when the pixel is not burned; 1 to 366 day of
    #           the first detection when the pixel is burned; -1 when the pixel is not observed
    #           in the month; -2 when pixel is not burnable: water bodies, bare areas, urban areas,
    #           and permanent snow and ice." ;
    #
    # When loaded with xarray.open_dataset/xarray.open_mfdataset without any kwargs the
    # type of the JD variable is datetime64[ns], which is wrong. The correct type should
    # be int16.

    ds = from_source(
        "url",
        earthkit_remote_test_data_file(
            "test-data", "20220401-C3S-L3S_FIRE-BA-OLCI-AREA_3-fv1.1.nc"
        ),
    )

    jd_shape = (1, 20880, 28440)

    # Options wrapped in the dedicated keyword: decoding is disabled, so JD
    # keeps the integer type it is declared with.
    r = ds.to_xarray(
        xarray_open_mfdataset_kwargs=dict(decode_cf=False, decode_times=False)
    )
    assert r["JD"].dtype == "int16"
    assert r["JD"].shape == jd_shape
    assert np.isclose(r["JD"].values.min(), -2)
    assert np.isclose(r["JD"].values.max(), 120)

    # The same options passed directly as keyword arguments must give the
    # same result.
    r = ds.to_xarray(decode_cf=False, decode_times=False)
    assert r["JD"].dtype == "int16"
    assert r["JD"].shape == jd_shape
    assert np.isclose(r["JD"].values.min(), -2)
    assert np.isclose(r["JD"].values.max(), 120)

    # Default behaviour: xarray decodes JD as a datetime because of its
    # "days since ..." units.
    r = ds.to_xarray()
    assert r["JD"].dtype == "<M8[ns]"
    assert r["JD"].shape == jd_shape


# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
from earthkit.data.testing import main

# NOTE(review): ``main`` presumably runs the tests found in this file — confirm.
main(__file__)

0 comments on commit 99eace2

Please sign in to comment.