Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not implicitly set generatingProcessIdentifier when writing GRIB data #275

Merged
merged 2 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions earthkit/data/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def metadata(self, *args, **kwargs):

# I/O
@abstractmethod
def save(self, path):
def save(self, path, **kwargs):
"""Writes data into the specified ``path``."""
self._not_implemented()

@abstractmethod
def write(self, f):
def write(self, f, **kwargs):
"""Writes data to the ``f`` file object."""
self._not_implemented()

Expand Down
8 changes: 4 additions & 4 deletions earthkit/data/core/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,7 @@ def _is_shared_grid(self):
)
return False

def save(self, filename, append=False):
def save(self, filename, append=False, **kwargs):
r"""Write all the fields into a file.

Parameters
Expand All @@ -1177,9 +1177,9 @@ def save(self, filename, append=False):
"""
flag = "wb" if not append else "ab"
with open(filename, flag) as f:
self.write(f)
self.write(f, **kwargs)

def write(self, f):
def write(self, f, **kwargs):
r"""Write all the fields to a file object.

Parameters
Expand All @@ -1188,7 +1188,7 @@ def write(self, f):
The target file object.
"""
for s in self:
s.write(f)
s.write(f, **kwargs)

def to_fieldlist(self, backend, **kwargs):
r"""Convert to a new :class:`FieldList` based on the ``backend``.
Expand Down
6 changes: 3 additions & 3 deletions earthkit/data/readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ def ignore(self):
def cache_file(self, *args, **kwargs):
return self.source.cache_file(*args, **kwargs)

def save(self, path):
def save(self, path, **kwargs):
mode = "wb" if self.binary else "w"
with open(path, mode) as f:
self.write(f)
self.write(f, **kwargs)

def write(self, f):
def write(self, f, **kwargs):
if not self.appendable:
assert f.tell() == 0
mode = "rb" if self.binary else "r"
Expand Down
4 changes: 2 additions & 2 deletions earthkit/data/readers/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ def mutate_source(self):
merger=self.merger,
)

def save(self, path):
def save(self, path, **kwargs):
shutil.copytree(self.path, path)

def write(self, f):
def write(self, f, **kwargs):
raise NotImplementedError()


Expand Down
2 changes: 0 additions & 2 deletions earthkit/data/readers/grib/codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,6 @@ def set_values(self, values):
try:
assert self.path is None, "Only cloned handles can have values changed"
eccodes.codes_set_values(self._handle, values.flatten())
# This is writing on the GRIB that something has been modified (255=unknown)
eccodes.codes_set_long(self._handle, "generatingProcessIdentifier", 255)
except Exception as e:
LOG.error("Error setting values")
LOG.exception(e)
Expand Down
3 changes: 3 additions & 0 deletions earthkit/data/readers/grib/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def write(
k: v for k, v in sorted(metadata.items(), key=lambda x: order(x[0]))
}

if "generatingProcessIdentifier" not in metadata:
metadata["generatingProcessIdentifier"] = 255

LOG.debug("GribOutput.metadata %s", metadata)

for k, v in metadata.items():
Expand Down
8 changes: 4 additions & 4 deletions earthkit/data/sources/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ def to_numpy(self, **kwargs):
def values(self):
return self._reader.values

def save(self, path):
return self._reader.save(path)
def save(self, path, **kwargs):
return self._reader.save(path, **kwargs)

def write(self, f):
return self._reader.write(f)
def write(self, f, **kwargs):
return self._reader.write(f, **kwargs)

def scaled(self, *args, **kwargs):
return self._reader.scaled(*args, **kwargs)
Expand Down
8 changes: 4 additions & 4 deletions earthkit/data/sources/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def values(self):
def to_fieldlist(self, *args, **kwargs):
return self._reader.to_fieldlist(*args, **kwargs)

def save(self, path):
return self._reader.save(path)
def save(self, path, **kwargs):
return self._reader.save(path, **kwargs)

def write(self, f):
return self._reader.write(f)
def write(self, f, **kwargs):
return self._reader.write(f, **kwargs)

def scaled(self, *args, **kwargs):
return self._reader.scaled(*args, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions earthkit/data/sources/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def __repr__(self) -> str:
string = ",".join(repr(s) for s in self.sources)
return f"{self.__class__.__name__}({string})"

def save(self, path):
def save(self, path, **kwargs):
with open(path, "wb") as f:
for s in self.sources:
s.write(f)
s.write(f, **kwargs)

def graph(self, depth=0):
print(" " * depth, self.__class__.__name__, self.merger)
Expand Down
17 changes: 13 additions & 4 deletions earthkit/data/sources/numpy_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


class NumpyField(Field):
r"""Represents a field consisting of an ndarray and metadata object.
r"""Represent a field consisting of an ndarray and metadata object.

Parameters
----------
Expand All @@ -46,10 +46,19 @@ def _values(self, dtype=None):
def __repr__(self):
return f"{self.__class__.__name__}()"

def write(self, f):
def write(self, f, **kwargs):
r"""Write the field to a file object.

Parameters
----------
f: file object
The target file object.
**kwargs: dict, optional
Other keyword arguments passed to :meth:`data.writers.grib.GribWriter.write`.
"""
from earthkit.data.writers import write

write(f, self.values, self._metadata, check_nans=True)
write(f, self.values, self._metadata, **kwargs)


class NumpyFieldListCore(PandasMixIn, XarrayMixIn, FieldList):
Expand Down Expand Up @@ -158,7 +167,7 @@ def to_fieldlist(self):


class NumpyFieldList(NumpyFieldListCore):
r"""Represents a list of :obj:`NumpyField <data.sources.numpy_list.NumpyField>`\ s.
r"""Represent a list of :obj:`NumpyField <data.sources.numpy_list.NumpyField>`\ s.

The preferred way to create a NumpyFieldList is to use either the
static :obj:`from_numpy` method or the :obj:`to_fieldlist` method.
Expand Down
13 changes: 13 additions & 0 deletions earthkit/data/writers/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ class GribWriter(Writer):
DATA_FORMAT = "grib"

def write(self, f, values, metadata, check_nans=True):
r"""Write a GRIB field to a file object.

Parameters
----------
f: file object
The target file object.
values: ndarray
Values of the GRIB field/message.
metadata: :class:`GribMetadata`
Metadata of the GRIB field/message.
check_nans: bool
Replace nans in ``values`` with GRIB missing values when writing to ``f``.
"""
handle = metadata._handle
if check_nans:
import numpy as np
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ install_requires =
tqdm
xarray>=0.19.0
earthkit-meteo>=0.0.1
aws-requests-auth
include_package_data = True

[options.packages.find]
Expand Down
6 changes: 6 additions & 0 deletions tests/grib/test_grib_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_grib_output_latlon():
assert ds[0].metadata("param") == "2t"
assert ds[0].metadata("levtype") == "sfc"
assert ds[0].metadata("edition") == 2
assert ds[0].metadata("generatingProcessIdentifier") == 255

assert np.allclose(ds[0].to_numpy(), data, rtol=EPSILON, atol=EPSILON)

Expand All @@ -79,6 +80,7 @@ def test_grib_output_o96():
assert ds[0].metadata("param") == "2t"
assert ds[0].metadata("levtype") == "sfc"
assert ds[0].metadata("edition") == 2
assert ds[0].metadata("generatingProcessIdentifier") == 255

assert np.allclose(ds[0].to_numpy(), data, rtol=EPSILON, atol=EPSILON)

Expand All @@ -103,6 +105,7 @@ def test_grib_output_o160():
assert ds[0].metadata("edition") == 2
assert ds[0].metadata("levtype") == "sfc"
assert ds[0].metadata("param") == "2t"
assert ds[0].metadata("generatingProcessIdentifier") == 255

assert np.allclose(ds[0].to_numpy(), data, rtol=EPSILON, atol=EPSILON)

Expand Down Expand Up @@ -130,6 +133,7 @@ def test_grib_output_mars_labeling():
assert ds[0].metadata("levtype") == "sfc"
assert ds[0].metadata("param") == "msl"
assert ds[0].metadata("type") == "fc"
assert ds[0].metadata("generatingProcessIdentifier") == 255

assert np.allclose(ds[0].to_numpy(), data, rtol=EPSILON, atol=EPSILON)

Expand Down Expand Up @@ -158,6 +162,7 @@ def test_grib_output_pl(levtype):
assert ds[0].metadata("level") == 850
assert ds[0].metadata("levtype") == "pl"
assert ds[0].metadata("param") == "t"
assert ds[0].metadata("generatingProcessIdentifier") == 255

assert np.allclose(ds[0].to_numpy(), data, rtol=EPSILON, atol=EPSILON)

Expand Down Expand Up @@ -185,6 +190,7 @@ def test_grib_output_tp():
assert ds[0].metadata("levtype") == "sfc"
assert ds[0].metadata("edition") == 1
assert ds[0].metadata("step") == 48
assert ds[0].metadata("generatingProcessIdentifier") == 255

assert np.allclose(ds[0].to_numpy(), data, rtol=EPSILON, atol=EPSILON)

Expand Down
74 changes: 65 additions & 9 deletions tests/numpy_fs/test_numpy_fs_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os

import numpy as np
import pytest

from earthkit.data import from_source
from earthkit.data.core.fieldlist import FieldList
Expand All @@ -22,7 +23,8 @@
LOG = logging.getLogger(__name__)


def test_numpy_fs_grib_write_missing():
@pytest.mark.parametrize("_kwargs", [{}, {"check_nans": True}])
def test_numpy_fs_grib_write_missing(_kwargs):
ds = from_source("file", earthkit_examples_file("test.grib"))

assert ds[0].metadata("shortName") == "2t"
Expand All @@ -42,14 +44,40 @@ def test_numpy_fs_grib_write_missing():
assert np.isnan(r[0].values[0])
assert not np.isnan(r[0].values[1])

# save to disk
tmp = temp_file()
r.save(tmp.path)
assert os.path.exists(tmp.path)
r_tmp = from_source("file", tmp.path)
v_tmp = r_tmp[0].values
assert np.isnan(v_tmp[0])
assert not np.isnan(v_tmp[1])
with temp_file() as tmp:
r.save(tmp, **_kwargs)
assert os.path.exists(tmp)
r_tmp = from_source("file", tmp)
v_tmp = r_tmp[0].values
assert np.isnan(v_tmp[0])
assert not np.isnan(v_tmp[1])


def test_numpy_fs_grib_write_check_nans_bad():
ds = from_source("file", earthkit_examples_file("test.grib"))

assert ds[0].metadata("shortName") == "2t"

v = ds[0].values
v1 = v + 1
assert not np.isnan(v1[0])
assert not np.isnan(v1[1])
v1[0] = np.nan
assert np.isnan(v1[0])
assert not np.isnan(v1[1])

md = ds[0].metadata()
md1 = md.override(shortName="msl")
r = FieldList.from_numpy(v1, md1)

assert np.isnan(r[0].values[0])
assert not np.isnan(r[0].values[1])

with temp_file() as tmp:
from eccodes import EncodingError

with pytest.raises(EncodingError):
r.save(tmp, check_nans=False)


def test_numpy_fs_grib_write_append():
Expand Down Expand Up @@ -85,6 +113,34 @@ def test_numpy_fs_grib_write_append():
assert r_tmp.metadata("shortName") == ["msl", "2d"]


def test_numpy_fs_grib_write_generating_proc_id():
ds = from_source("file", earthkit_examples_file("test.grib"))

assert ds[0].metadata("shortName") == "2t"

v = ds[0].values
v1 = v + 1
v2 = v + 2

md = ds[0].metadata()
md1 = md.override(shortName="msl", generatingProcessIdentifier=255)
md2 = md.override(shortName="2d")

r1 = FieldList.from_numpy([v1, v2], [md1, md2])

# save to disk: generatingProcessIdentifier is written as stored in the
# metadata (255 only where explicitly overridden, original value otherwise)
with temp_file() as tmp:
r1.save(tmp)
assert os.path.exists(tmp)
r_tmp = from_source("file", tmp)
assert len(r_tmp) == 2
assert r_tmp.metadata("shortName") == ["msl", "2d"]
assert r_tmp.metadata("generatingProcessIdentifier") == [
255,
150,
]


if __name__ == "__main__":
from earthkit.data.testing import main

Expand Down
Loading