diff --git a/.gitignore b/.gitignore index 84565514..fe6a0172 100644 --- a/.gitignore +++ b/.gitignore @@ -359,3 +359,5 @@ notebooks/data/*/ # local code _dev + +test.db diff --git a/earthkit/data/core/fieldlist.py b/earthkit/data/core/fieldlist.py index b44c1dba..03ed3226 100644 --- a/earthkit/data/core/fieldlist.py +++ b/earthkit/data/core/fieldlist.py @@ -537,7 +537,7 @@ def _attributes(self, names): class FieldList(Index): r"""Represents a list of :obj:`Field` \s.""" - _indices = {} + _md_indices = {} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -622,12 +622,12 @@ def indices(self, squeeze=False): used in :obj:`indices`. """ - if not self._indices: - self._indices = self._find_all_index_dict() + if not self._md_indices: + self._md_indices = self._find_all_index_dict() if squeeze: - return {k: v for k, v in self._indices.items() if len(v) > 1} + return {k: v for k, v in self._md_indices.items() if len(v) > 1} else: - return self._indices + return self._md_indices def index(self, key): r"""Return the unique, sorted values of the specified metadata ``key`` from all the fields. @@ -659,8 +659,8 @@ def index(self, key): if key in self.indices(): return self.indices()[key] - self._indices[key] = self._find_index_values(key) - return self._indices[key] + self._md_indices[key] = self._find_index_values(key) + return self._md_indices[key] def to_numpy(self, **kwargs): r"""Return the field values as an ndarray. 
It is formed as the array of the diff --git a/earthkit/data/core/index.py b/earthkit/data/core/index.py index 455d9097..76c8d8e1 100644 --- a/earthkit/data/core/index.py +++ b/earthkit/data/core/index.py @@ -550,8 +550,8 @@ def full(self, *coords): class MaskIndex(Index): def __init__(self, index, indices): - self.index = index - self.indices = list(indices) + self._index = index + self._indices = list(indices) # super().__init__( # *self.index._init_args, # order_by=self.index._init_order_by, @@ -559,19 +559,19 @@ def __init__(self, index, indices): # ) def _getitem(self, n): - n = self.indices[n] - return self.index[n] + n = self._indices[n] + return self._index[n] def __len__(self): - return len(self.indices) + return len(self._indices) def __repr__(self): - return "MaskIndex(%r,%s)" % (self.index, self.indices) + return "MaskIndex(%r,%s)" % (self._index, self._indices) class MultiIndex(Index): def __init__(self, indexes, *args, **kwargs): - self.indexes = list(indexes) + self._indexes = list(indexes) super().__init__(*args, **kwargs) # self.indexes = list(i for i in indexes if len(i)) # TODO: propagate index._init_args, index._init_order_by, index._init_kwargs, for each i in indexes? 
@@ -579,36 +579,36 @@ def __init__(self, indexes, *args, **kwargs): def sel(self, *args, **kwargs): if not args and not kwargs: return self - return self.__class__(i.sel(*args, **kwargs) for i in self.indexes) + return self.__class__(i.sel(*args, **kwargs) for i in self._indexes) def _getitem(self, n): k = 0 - while n >= len(self.indexes[k]): - n -= len(self.indexes[k]) + while n >= len(self._indexes[k]): + n -= len(self._indexes[k]) k += 1 - return self.indexes[k][n] + return self._indexes[k][n] def __len__(self): - return sum(len(i) for i in self.indexes) + return sum(len(i) for i in self._indexes) def graph(self, depth=0): print(" " * depth, self.__class__.__name__) - for s in self.indexes: + for s in self._indexes: s.graph(depth + 3) def __repr__(self): return "%s(%s)" % ( self.__class__.__name__, - ",".join(repr(i) for i in self.indexes), + ",".join(repr(i) for i in self._indexes), ) class ForwardingIndex(Index): def __init__(self, index): - self.index = index + self._index = index def __len__(self): - return len(self.index) + return len(self._index) class ScaledField: @@ -635,7 +635,7 @@ class FullIndex(Index): def __init__(self, index, *coords): import numpy as np - self.index = index + self._index = index # Pass1, unique values unique = index.unique_values(*coords) @@ -663,4 +663,4 @@ def __len__(self): def _getitem(self, n): assert self.holes[n], f"Attempting to access hole {n}" - return self.index[sum(self.holes[:n])] + return self._index[sum(self.holes[:n])] diff --git a/earthkit/data/readers/netcdf.py b/earthkit/data/readers/netcdf.py index 4cd0062e..acd1ba97 100644 --- a/earthkit/data/readers/netcdf.py +++ b/earthkit/data/readers/netcdf.py @@ -518,7 +518,7 @@ def __init__(self, *args, **kwargs): def to_xarray(self, **kwargs): import xarray as xr - return xr.merge([x.ds for x in self.indexes], **kwargs) + return xr.merge([x.ds for x in self._indexes], **kwargs) class NetCDFMetadata(XArrayMetadata): @@ -635,7 +635,7 @@ def __init__(self, *args, 
**kwargs): def to_xarray(self, **kwargs): try: return NetCDFFieldList.to_xarray_multi_from_paths( - [x.path for x in self.indexes], **kwargs + [x.path for x in self._indexes], **kwargs ) except AttributeError: # TODO: Implement this, but discussion required diff --git a/tests/grib/test_grib_indices.py b/tests/grib/test_grib_indices.py new file mode 100644 index 00000000..9d4f52b8 --- /dev/null +++ b/tests/grib/test_grib_indices.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import os +import sys + +import pytest + +here = os.path.dirname(__file__) +sys.path.insert(0, here) +from grib_fixtures import load_file_or_numpy_fs  # noqa: E402 + + +@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +def test_grib_indices_base(mode): + ds = load_file_or_numpy_fs("tuv_pl.grib", mode) + + ref = { + "class": ["od"], + "stream": ["oper"], + "levtype": ["pl"], + "type": ["an"], + "expver": ["0001"], + "date": [20180801], + "time": [1200], + "domain": ["g"], + "number": [0], + "levelist": [300, 400, 500, 700, 850, 1000], + "param": ["t", "u", "v"], + } + + r = ds.indices() + assert r == ref + + ref = { + "levelist": [300, 400, 500, 700, 850, 1000], + "param": ["t", "u", "v"], + } + r = ds.indices(squeeze=True) + assert r == ref + + ref = ["t", "u", "v"] + r = ds.index("param") + assert r == ref + + +@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +def test_grib_indices_sel(mode): + ds = load_file_or_numpy_fs("tuv_pl.grib", mode) + + ref = { + "class": ["od"], + "stream": ["oper"], + "levtype": ["pl"], + "type": ["an"], + "expver": ["0001"], + "date": [20180801], + "time": [1200], + 
"domain": ["g"], + "number": [0], + "levelist": [300, 400, 500, 700, 850, 1000], + "param": ["t"], + } + + ds1 = ds.sel(param="t") + r = ds1.indices() + assert r == ref + + ref = { + "levelist": [300, 400, 500, 700, 850, 1000], + } + r = ds1.indices(squeeze=True) + assert r == ref + + +@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +def test_grib_indices_multi(mode): + f1 = load_file_or_numpy_fs("tuv_pl.grib", mode) + f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") + ds = f1 + f2 + + ref = { + "class": ["od"], + "stream": ["oper"], + "levtype": ["ml", "pl"], + "type": ["an", "fc"], + "expver": ["0001"], + "date": [20180111, 20180801], + "time": [1200], + "domain": ["g"], + "number": [0], + "levelist": [ + 1, + 5, + 9, + 13, + 17, + 21, + 25, + 29, + 33, + 37, + 41, + 45, + 49, + 53, + 57, + 61, + 65, + 69, + 73, + 77, + 81, + 85, + 89, + 93, + 97, + 101, + 105, + 109, + 113, + 117, + 121, + 125, + 129, + 133, + 137, + 300, + 400, + 500, + 700, + 850, + 1000, + ], + "param": ["lnsp", "t", "u", "v"], + } + + r = ds.indices() + assert r == ref + + +@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +def test_grib_indices_multi_Del(mode): + f1 = load_file_or_numpy_fs("tuv_pl.grib", mode) + f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") + ds = f1 + f2 + + ref = { + "class": ["od"], + "stream": ["oper"], + "levtype": ["ml", "pl"], + "type": ["an", "fc"], + "expver": ["0001"], + "date": [20180111, 20180801], + "time": [1200], + "domain": ["g"], + "number": [0], + "levelist": [93, 500], + "param": ["t"], + } + + ds1 = ds.sel(param="t", level=[93, 500]) + r = ds1.indices() + assert r == ref + + +@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +def test_grib_indices_order_by(mode): + ds = load_file_or_numpy_fs("tuv_pl.grib", mode) + + ref = { + "class": ["od"], + "stream": ["oper"], + "levtype": ["pl"], + "type": ["an"], + "expver": ["0001"], + "date": [20180801], + "time": [1200], + "domain": ["g"], + "number": [0], + 
"levelist": [300, 400, 500, 700, 850, 1000], + "param": ["t", "u", "v"], + } + + ds1 = ds.order_by(levelist="descending") + r = ds1.indices() + assert r == ref diff --git a/tests/indexing/test_indexing_db.py b/tests/indexing/test_indexing_db.py index 84b3b931..77f30871 100644 --- a/tests/indexing/test_indexing_db.py +++ b/tests/indexing/test_indexing_db.py @@ -35,8 +35,8 @@ def test_indexing_db_file_multi(): ds = from_source("file", path, indexing=True) counts = [6, 6, 6] - assert len(counts) == len(ds.indexes) - for i, d in enumerate(ds.indexes): + assert len(counts) == len(ds._indexes) + for i, d in enumerate(ds._indexes): assert hasattr(d, "db"), f"db,{i}" assert d.db.count() == counts[i] diff --git a/tests/sources/test_cds.py b/tests/sources/test_cds.py index a2a359c8..bbdd08a4 100644 --- a/tests/sources/test_cds.py +++ b/tests/sources/test_cds.py @@ -86,7 +86,7 @@ def test_cds_grib_split_on_var(): assert len(s) == 2 assert s.metadata("param") == ["2t", "msl"] assert not hasattr(s, "path") - assert len(s.indexes) == 2 + assert len(s._indexes) == 2 @pytest.mark.parametrize( @@ -152,10 +152,10 @@ def test_cds_split_on(split_on, expected_file_num, expected_param, expected_time if expected_file_num == 1: assert hasattr(s, "path") - assert not hasattr(s, "indexes") + assert not hasattr(s, "_indexes") else: assert not hasattr(s, "path") - assert len(s.indexes) == expected_file_num + assert len(s._indexes) == expected_file_num assert len(s) == 4 assert s.metadata("param") == expected_param @@ -188,7 +188,7 @@ def test_cds_multiple_requests( base_request | {"variable": "2t", "split_on": split_on1}, base_request | {"variable": "msl", "split_on": split_on2}, ) - assert len(s.indexes) == expected_file_num + assert len(s._indexes) == expected_file_num assert len(s) == 4 assert s.metadata("param") == expected_param assert s.metadata("time") == expected_time