Skip to content

Commit

Permalink
Fix indices() on fieldlist filtered with sel() (#265)
Browse files Browse the repository at this point in the history
* Fix indices on fieldlist filtered with sel
  • Loading branch information
sandorkertesz authored Nov 29, 2023
1 parent 8854b9f commit 2d7af06
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 33 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,5 @@ notebooks/data/*/

# local code
_dev

test.db
14 changes: 7 additions & 7 deletions earthkit/data/core/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def _attributes(self, names):
class FieldList(Index):
r"""Represents a list of :obj:`Field` \s."""

_indices = {}
_md_indices = {}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -622,12 +622,12 @@ def indices(self, squeeze=False):
used in :obj:`indices`.
"""
if not self._indices:
self._indices = self._find_all_index_dict()
if not self._md_indices:
self._md_indices = self._find_all_index_dict()
if squeeze:
return {k: v for k, v in self._indices.items() if len(v) > 1}
return {k: v for k, v in self._md_indices.items() if len(v) > 1}
else:
return self._indices
return self._md_indices

def index(self, key):
r"""Return the unique, sorted values of the specified metadata ``key`` from all the fields.
Expand Down Expand Up @@ -659,8 +659,8 @@ def index(self, key):
if key in self.indices():
return self.indices()[key]

self._indices[key] = self._find_index_values(key)
return self._indices[key]
self._md_indices[key] = self._find_index_values(key)
return self._md_indices[key]

def to_numpy(self, **kwargs):
r"""Return the field values as an ndarray. It is formed as the array of the
Expand Down
36 changes: 18 additions & 18 deletions earthkit/data/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,65 +550,65 @@ def full(self, *coords):

class MaskIndex(Index):
def __init__(self, index, indices):
self.index = index
self.indices = list(indices)
self._index = index
self._indices = list(indices)
# super().__init__(
# *self.index._init_args,
# order_by=self.index._init_order_by,
# **self.index._init_kwargs,
# )

def _getitem(self, n):
n = self.indices[n]
return self.index[n]
n = self._indices[n]
return self._index[n]

def __len__(self):
return len(self.indices)
return len(self._indices)

def __repr__(self):
return "MaskIndex(%r,%s)" % (self.index, self.indices)
return "MaskIndex(%r,%s)" % (self._index, self._indices)


class MultiIndex(Index):
def __init__(self, indexes, *args, **kwargs):
self.indexes = list(indexes)
self._indexes = list(indexes)
super().__init__(*args, **kwargs)
# self.indexes = list(i for i in indexes if len(i))
# TODO: propagate index._init_args, index._init_order_by, index._init_kwargs, for each i in indexes?

def sel(self, *args, **kwargs):
if not args and not kwargs:
return self
return self.__class__(i.sel(*args, **kwargs) for i in self.indexes)
return self.__class__(i.sel(*args, **kwargs) for i in self._indexes)

def _getitem(self, n):
k = 0
while n >= len(self.indexes[k]):
n -= len(self.indexes[k])
while n >= len(self._indexes[k]):
n -= len(self._indexes[k])
k += 1
return self.indexes[k][n]
return self._indexes[k][n]

def __len__(self):
return sum(len(i) for i in self.indexes)
return sum(len(i) for i in self._indexes)

def graph(self, depth=0):
print(" " * depth, self.__class__.__name__)
for s in self.indexes:
for s in self._indexes:
s.graph(depth + 3)

def __repr__(self):
return "%s(%s)" % (
self.__class__.__name__,
",".join(repr(i) for i in self.indexes),
",".join(repr(i) for i in self._indexes),
)


class ForwardingIndex(Index):
def __init__(self, index):
self.index = index
self._index = index

def __len__(self):
return len(self.index)
return len(self._index)


class ScaledField:
Expand All @@ -635,7 +635,7 @@ class FullIndex(Index):
def __init__(self, index, *coords):
import numpy as np

self.index = index
self._index = index

# Pass1, unique values
unique = index.unique_values(*coords)
Expand Down Expand Up @@ -663,4 +663,4 @@ def __len__(self):

def _getitem(self, n):
assert self.holes[n], f"Attempting to access hole {n}"
return self.index[sum(self.holes[:n])]
return self._index[sum(self.holes[:n])]
4 changes: 2 additions & 2 deletions earthkit/data/readers/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def __init__(self, *args, **kwargs):
def to_xarray(self, **kwargs):
import xarray as xr

return xr.merge([x.ds for x in self.indexes], **kwargs)
return xr.merge([x.ds for x in self._indexes], **kwargs)


class NetCDFMetadata(XArrayMetadata):
Expand Down Expand Up @@ -635,7 +635,7 @@ def __init__(self, *args, **kwargs):
def to_xarray(self, **kwargs):
try:
return NetCDFFieldList.to_xarray_multi_from_paths(
[x.path for x in self.indexes], **kwargs
[x.path for x in self._indexes], **kwargs
)
except AttributeError:
# TODO: Implement this, but discussion required
Expand Down
195 changes: 195 additions & 0 deletions tests/grib/test_grib_indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
#!/usr/bin/env python3

# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import os
import sys

import pytest

here = os.path.dirname(__file__)
sys.path.insert(0, here)
from grib_fixtures import load_file_or_numpy_fs # noqa: E402


@pytest.mark.parametrize("mode", ["file", "numpy_fs"])
def test_grib_indices_base(mode):
ds = load_file_or_numpy_fs("tuv_pl.grib", mode)

ref = {
"class": ["od"],
"stream": ["oper"],
"levtype": ["pl"],
"type": ["an"],
"expver": ["0001"],
"date": [20180801],
"time": [1200],
"domain": ["g"],
"number": [0],
"levelist": [300, 400, 500, 700, 850, 1000],
"param": ["t", "u", "v"],
}

r = ds.indices()
assert r == ref

ref = {
"levelist": [300, 400, 500, 700, 850, 1000],
"param": ["t", "u", "v"],
}
r = ds.indices(squeeze=True)
assert r == ref

ref = ["t", "u", "v"]
r = ds.index("param")
assert r == ref


@pytest.mark.parametrize("mode", ["file", "numpy_fs"])
def test_grib_indices_sel(mode):
ds = load_file_or_numpy_fs("tuv_pl.grib", mode)

ref = {
"class": ["od"],
"stream": ["oper"],
"levtype": ["pl"],
"type": ["an"],
"expver": ["0001"],
"date": [20180801],
"time": [1200],
"domain": ["g"],
"number": [0],
"levelist": [300, 400, 500, 700, 850, 1000],
"param": ["t"],
}

ds1 = ds.sel(param="t")
r = ds1.indices()
assert r == ref

ref = {
"levelist": [300, 400, 500, 700, 850, 1000],
}
r = ds1.indices(squeeze=True)
assert r == ref


@pytest.mark.parametrize("mode", ["file", "numpy_fs"])
def test_grib_indices_multi(mode):
f1 = load_file_or_numpy_fs("tuv_pl.grib", mode)
f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data")
ds = f1 + f2

ref = {
"class": ["od"],
"stream": ["oper"],
"levtype": ["ml", "pl"],
"type": ["an", "fc"],
"expver": ["0001"],
"date": [20180111, 20180801],
"time": [1200],
"domain": ["g"],
"number": [0],
"levelist": [
1,
5,
9,
13,
17,
21,
25,
29,
33,
37,
41,
45,
49,
53,
57,
61,
65,
69,
73,
77,
81,
85,
89,
93,
97,
101,
105,
109,
113,
117,
121,
125,
129,
133,
137,
300,
400,
500,
700,
850,
1000,
],
"param": ["lnsp", "t", "u", "v"],
}

r = ds.indices()
assert r == ref


@pytest.mark.parametrize("mode", ["file", "numpy_fs"])
def test_grib_indices_multi_Del(mode):
f1 = load_file_or_numpy_fs("tuv_pl.grib", mode)
f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data")
ds = f1 + f2

ref = {
"class": ["od"],
"stream": ["oper"],
"levtype": ["ml", "pl"],
"type": ["an", "fc"],
"expver": ["0001"],
"date": [20180111, 20180801],
"time": [1200],
"domain": ["g"],
"number": [0],
"levelist": [93, 500],
"param": ["t"],
}

ds1 = ds.sel(param="t", level=[93, 500])
r = ds1.indices()
assert r == ref


@pytest.mark.parametrize("mode", ["file", "numpy_fs"])
def test_grib_indices_order_by(mode):
ds = load_file_or_numpy_fs("tuv_pl.grib", mode)

ref = {
"class": ["od"],
"stream": ["oper"],
"levtype": ["pl"],
"type": ["an"],
"expver": ["0001"],
"date": [20180801],
"time": [1200],
"domain": ["g"],
"number": [0],
"levelist": [300, 400, 500, 700, 850, 1000],
"param": ["t", "u", "v"],
}

ds1 = ds.order_by(levelist="descending")
r = ds1.indices()
assert r == ref
4 changes: 2 additions & 2 deletions tests/indexing/test_indexing_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def test_indexing_db_file_multi():
ds = from_source("file", path, indexing=True)

counts = [6, 6, 6]
assert len(counts) == len(ds.indexes)
for i, d in enumerate(ds.indexes):
assert len(counts) == len(ds._indexes)
for i, d in enumerate(ds._indexes):
assert hasattr(d, "db"), f"db,{i}"
assert d.db.count() == counts[i]

Expand Down
8 changes: 4 additions & 4 deletions tests/sources/test_cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_cds_grib_split_on_var():
assert len(s) == 2
assert s.metadata("param") == ["2t", "msl"]
assert not hasattr(s, "path")
assert len(s.indexes) == 2
assert len(s._indexes) == 2


@pytest.mark.parametrize(
Expand Down Expand Up @@ -152,10 +152,10 @@ def test_cds_split_on(split_on, expected_file_num, expected_param, expected_time

if expected_file_num == 1:
assert hasattr(s, "path")
assert not hasattr(s, "indexes")
assert not hasattr(s, "_indexes")
else:
assert not hasattr(s, "path")
assert len(s.indexes) == expected_file_num
assert len(s._indexes) == expected_file_num

assert len(s) == 4
assert s.metadata("param") == expected_param
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_cds_multiple_requests(
base_request | {"variable": "2t", "split_on": split_on1},
base_request | {"variable": "msl", "split_on": split_on2},
)
assert len(s.indexes) == expected_file_num
assert len(s._indexes) == expected_file_num
assert len(s) == 4
assert s.metadata("param") == expected_param
assert s.metadata("time") == expected_time
Expand Down

0 comments on commit 2d7af06

Please sign in to comment.