diff --git a/src/earthkit/data/readers/grib/codes.py b/src/earthkit/data/readers/grib/codes.py index 5fbeb30c..3a9986d3 100644 --- a/src/earthkit/data/readers/grib/codes.py +++ b/src/earthkit/data/readers/grib/codes.py @@ -245,18 +245,19 @@ class GribField(Field): _handle = None - def __init__(self, path, offset, length, backend, manager=None): + def __init__(self, path, offset, length, backend, handle_manager=None, use_metadata_cache=False): super().__init__(backend) self.path = path self._offset = offset self._length = length - self._manager = manager + self._handle_manager = handle_manager + self._use_metadata_cache = use_metadata_cache @property def handle(self): r""":class:`CodesHandle`: Get an object providing access to the low level GRIB message structure.""" - if self._manager is not None: - handle = self._manager.handle(self, self._create_handle) + if self._handle_manager is not None: + handle = self._handle_manager.handle(self, self._create_handle) if handle is None: raise RuntimeError(f"Could not get a handle for offset={self.offset} in {self.path}") return handle @@ -282,13 +283,14 @@ def offset(self): @cached_property def _metadata(self): - cache = False - if self._manager is not None: - cache = self._manager.use_grib_metadata_cache - if cache: - cache = self._manager._make_metadata_cache() + cache = self._use_metadata_cache + if cache: + cache = self._make_metadata_cache() return GribFieldMetadata(self, cache=cache) + def _make_metadata_cache(self): + return dict() + def __repr__(self): return "GribField(%s,%s,%s,%s,%s,%s)" % ( self._metadata.get("shortName", None), diff --git a/src/earthkit/data/readers/grib/index/__init__.py b/src/earthkit/data/readers/grib/index/__init__.py index b8be7648..fe59812d 100644 --- a/src/earthkit/data/readers/grib/index/__init__.py +++ b/src/earthkit/data/readers/grib/index/__init__.py @@ -217,126 +217,106 @@ def __init__(self, *args, **kwargs): FieldList._init_from_multi(self, self) -class GribResourceManager: - def __init__( - self, owner, grib_field_policy, grib_handle_policy, grib_handle_cache_size, use_grib_metadata_cache - ): - from earthkit.data.core.settings import SETTINGS - - def _get(v, name): - return v if v is not None else SETTINGS.get(name) +class GribFieldManager: + def __init__(self, policy, owner): + self.policy = policy + self.cache = None - self.grib_field_policy = _get(grib_field_policy, "grib-field-policy") - self.grib_handle_policy = _get(grib_handle_policy, "grib-handle-policy") - self.grib_handle_cache_size = _get(grib_handle_cache_size, "grib-handle-cache-size") - self.use_grib_metadata_cache = _get(use_grib_metadata_cache, "use-grib-metadata-cache") - - # fields - self.field_cache = None - if self.grib_field_policy == "persistent": + if self.policy == "persistent": from lru import LRU # TODO: the number of fields might only be available only later (e.g. fieldlists with - # an SQL index). Consider making _field_cache a cached property. + # an SQL index). Consider making cache a cached property. n = len(owner) if n > 0: - self.field_cache = LRU(n) - - # handles - self.handle_cache = None - if self.grib_handle_policy == "cache": - if self.grib_handle_cache_size > 0: - from lru import LRU + self.cache = LRU(n) - self.handle_cache = LRU(self.grib_handle_cache_size) - else: - raise ValueError( - 'grib_handle_cache_size must be greater than 0 when grib_handle_policy="cache"' - ) - - self.handle_create_count = 0 self.field_create_count = 0 # check consistency - if self.field_cache is not None: - self.grib_field_policy == "persistent" - else: - self.grib_field_policy == "temporary" - - if self.handle_cache is not None: - self.grib_handle_policy == "cache" - else: - self.grib_handle_policy in ["persistent", "temporary"] + if self.cache is not None: + assert self.policy == "persistent" def field(self, n, create): - if self.grib_field_policy == "persistent": - if n in self.field_cache: - return self.field_cache[n] + if self.cache is not None: + if n in self.cache: + return self.cache[n] else: field = create(n) self._field_created() - self.field_cache[n] = field + self.cache[n] = field return field else: self._field_created() return create(n) + def _field_created(self): + self.field_create_count += 1 + + def diag(self): + r = defaultdict(int) + r["grib_field_policy"] = self.policy + if self.cache is not None: + r["field_cache_size"] = len(self.cache) + + r["field_create_count"] = self.field_create_count + return r + + +class GribHandleManager: + def __init__(self, policy, cache_size): + self.policy = policy + self.max_cache_size = cache_size + self.cache = None + + if self.policy == "cache": + if self.max_cache_size > 0: + from lru import LRU + + self.cache = LRU(self.max_cache_size) + else: + raise ValueError( + 'grib_handle_cache_size must be greater than 0 when grib_handle_policy="cache"' + ) + + self.handle_create_count = 0 + + # check consistency + if self.cache is not None: + self.policy == "cache" + else: + self.policy in ["persistent", "temporary"] + def handle(self, field, create): - if self.grib_handle_policy == "cache": + if self.policy == "cache": key = (field.path, field._offset) - if key in self.handle_cache: - return self.handle_cache[key] + if key in self.cache: + return self.cache[key] else: handle = create() self._handle_created() - self.handle_cache[key] = handle + self.cache[key] = handle return handle - elif self.grib_handle_policy == "persistent": + elif self.policy == "persistent": if field._handle is None: field._handle = create() self._handle_created() return field._handle - elif self.grib_handle_policy == "temporary": + elif self.policy == "temporary": self._handle_created() return create() - def _field_created(self): - self.field_create_count += 1 - def _handle_created(self): self.handle_create_count += 1 - def _make_metadata_cache(self): - return dict() - def diag(self): r = defaultdict(int) - r["grib_field_policy"] = self.grib_field_policy - r["grib_handle_policy"] = self.grib_handle_policy - r["grib_handle_cache_size"] = self.grib_handle_cache_size - - if self.field_cache is not None: - r["field_cache_size"] = len(self.field_cache) - - r["field_create_count"] = self.field_create_count - - if self.handle_cache is not None: - r["handle_cache_size"] = len(self.handle_cache) + r["grib_handle_policy"] = self.policy + r["grib_handle_cache_size"] = self.max_cache_size + if self.cache is not None: + r["handle_cache_size"] = len(self.cache) r["handle_create_count"] = self.handle_create_count - - if self.field_cache is not None: - for f in self.field_cache.values(): - if f._handle is not None: - r["current_handle_count"] += 1 - - try: - md_cache = f._diag() - for k in ["metadata_cache_hits", "metadata_cache_misses", "metadata_cache_size"]: - r[k] += md_cache[k] - except Exception: - pass - return r @@ -351,10 +331,20 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) - self._manager = GribResourceManager( - self, grib_field_policy, grib_handle_policy, grib_handle_cache_size, use_grib_metadata_cache + + from earthkit.data.core.settings import SETTINGS + + def _get_opt(v, name): + return v if v is not None else SETTINGS.get(name) + + self._field_manager = GribFieldManager(_get_opt(grib_field_policy, "grib-field-policy"), self) + self._handle_manager = GribHandleManager( + _get_opt(grib_handle_policy, "grib-handle-policy"), + _get_opt(grib_handle_cache_size, "grib-handle-cache-size"), ) + self._use_metadata_cache = _get_opt(use_grib_metadata_cache, "use-grib-metadata-cache") + def _create_field(self, n): part = self.part(n) field = GribField( @@ -362,7 +352,8 @@ def _create_field(self, n): part.offset, part.length, self.array_backend, - manager=self._manager, + handle_manager=self._handle_manager, + use_metadata_cache=self._use_metadata_cache, ) if field is None: raise RuntimeError(f"Could not get a handle for part={part}") @@ -376,13 +367,29 @@ def _getitem(self, n): if n >= len(self): raise IndexError(f"Index {n} out of range") - return self._manager.field(n, self._create_field) + return self._field_manager.field(n, self._create_field) def __len__(self): return self.number_of_parts() def _diag(self): - return self._manager.diag() + r = defaultdict(int) + r.update(self._field_manager.diag()) + r.update(self._handle_manager.diag()) + + if self._field_manager.cache is not None: + for f in self._field_manager.cache.values(): + if f._handle is not None: + r["current_handle_count"] += 1 + + if self._use_metadata_cache: + try: + md_cache = f._diag() + for k in ["metadata_cache_hits", "metadata_cache_misses", "metadata_cache_size"]: + r[k] += md_cache[k] + except Exception: + pass + return r @abstractmethod def part(self, n): diff --git a/tests/grib/test_grib_cache.py b/tests/grib/test_grib_cache.py index 741157d7..93323fca 100644 --- a/tests/grib/test_grib_cache.py +++ b/tests/grib/test_grib_cache.py @@ -39,12 +39,12 @@ def __len__(self): @pytest.fixture def patch_metadata_cache(monkeypatch): - from earthkit.data.readers.grib.index import GribResourceManager + from earthkit.data.readers.grib.codes import GribField def patched_make_metadata_cache(self): return TestMetadataCache() - monkeypatch.setattr(GribResourceManager, "_make_metadata_cache", patched_make_metadata_cache) + monkeypatch.setattr(GribField, "_make_metadata_cache", patched_make_metadata_cache) def _check_diag(diag, ref): @@ -66,9 +66,6 @@ def test_grib_cache_basic(handle_cache_size, patch_metadata_cache): ds = from_source("file", earthkit_examples_file("tuv_pl.grib")) assert len(ds) == 18 - cache = ds._manager - assert cache - # unique values ref_vals = ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") @@ -86,8 +83,8 @@ def test_grib_cache_basic(handle_cache_size, patch_metadata_cache): _check_diag(ds._diag(), ref) for i, f in enumerate(ds): - assert i in cache.field_cache, f"{i} not in cache" - assert id(f) == id(cache.field_cache[i]), f"{i} not the same object" + assert i in ds._field_manager.cache, f"{i} not in cache" + assert id(f) == id(ds._field_manager.cache[i]), f"{i} not the same object" _check_diag(ds._diag(), ref) @@ -133,6 +130,7 @@ def test_grib_cache_basic(handle_cache_size, patch_metadata_cache): def test_grib_cache_basic_non_patched(): """This test is the same as test_grib_cache_basic but without the patch_metadata_cache fixture. So metadata cache hits and misses are not counted.""" + with settings.temporary( { "grib-field-policy": "persistent", @@ -144,9 +142,6 @@ def test_grib_cache_basic_non_patched(): ds = from_source("file", earthkit_examples_file("tuv_pl.grib")) assert len(ds) == 18 - cache = ds._manager - assert cache - # unique values ref_vals = ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") @@ -163,8 +158,8 @@ def test_grib_cache_basic_non_patched(): _check_diag(ds._diag(), ref) for i, f in enumerate(ds): - assert i in cache.field_cache, f"{i} not in cache" - assert id(f) == id(cache.field_cache[i]), f"{i} not the same object" + assert i in ds._field_manager.cache, f"{i} not in cache" + assert id(f) == id(ds._field_manager.cache[i]), f"{i} not the same object" _check_diag(ds._diag(), ref) @@ -216,14 +211,11 @@ def test_grib_cache_options_1(patch_metadata_cache): ds = from_source("file", earthkit_examples_file("tuv_pl.grib")) assert len(ds) == 18 - cache = ds._manager - assert cache - # unique values ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") - assert cache.field_cache is not None - assert cache.handle_cache is None + assert ds._field_manager.cache is not None + assert ds._handle_manager.cache is None ref = { "field_cache_size": 18, @@ -238,8 +230,8 @@ def test_grib_cache_options_1(patch_metadata_cache): _check_diag(ds._diag(), ref) for i, f in enumerate(ds): - assert i in cache.field_cache, f"{i} not in cache" - assert id(f) == id(cache.field_cache[i]), f"{i} not the same object" + assert i in ds._field_manager.cache, f"{i} not in cache" + assert id(f) == id(ds._field_manager.cache[i]), f"{i} not the same object" _check_diag(ds._diag(), ref) @@ -296,14 +288,11 @@ def test_grib_cache_options_2(patch_metadata_cache): ds = from_source("file", earthkit_examples_file("tuv_pl.grib")) assert len(ds) == 18 - cache = ds._manager - assert cache - # unique values ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") - assert cache.field_cache is not None - assert cache.handle_cache is None + assert ds._field_manager.cache is not None + assert ds._handle_manager.cache is None ref = { "field_cache_size": 18, @@ -318,8 +307,8 @@ def test_grib_cache_options_2(patch_metadata_cache): _check_diag(ds._diag(), ref) for i, f in enumerate(ds): - assert i in cache.field_cache, f"{i} not in cache" - assert id(f) == id(cache.field_cache[i]), f"{i} not the same object" + assert i in ds._field_manager.cache, f"{i} not in cache" + assert id(f) == id(ds._field_manager.cache[i]), f"{i} not the same object" _check_diag(ds._diag(), ref) @@ -378,14 +367,11 @@ def test_grib_cache_options_3(patch_metadata_cache): ds = from_source("file", earthkit_examples_file("tuv_pl.grib")) assert len(ds) == 18 - cache = ds._manager - assert cache - # unique values ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") - assert cache.field_cache is not None - assert cache.handle_cache is not None + assert ds._field_manager.cache is not None + assert ds._handle_manager.cache is not None ref = { "field_cache_size": 18, @@ -401,8 +387,8 @@ def test_grib_cache_options_3(patch_metadata_cache): _check_diag(ds._diag(), ref) for i, f in enumerate(ds): - assert i in cache.field_cache, f"{i} not in cache" - assert id(f) == id(cache.field_cache[i]), f"{i} not the same object" + assert i in ds._field_manager.cache, f"{i} not in cache" + assert id(f) == id(ds._field_manager.cache[i]), f"{i} not the same object" _check_diag(ds._diag(), ref) @@ -458,9 +444,6 @@ def test_grib_cache_options_4(patch_metadata_cache): ds = from_source("file", earthkit_examples_file("tuv_pl.grib")) assert len(ds) == 18 - cache = ds._manager - assert cache - # unique values ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") @@ -477,8 +460,8 @@ def test_grib_cache_options_4(patch_metadata_cache): _check_diag(ds._diag(), ref) - assert cache.field_cache is None - assert cache.handle_cache is None + assert ds._field_manager.cache is None + assert ds._handle_manager.cache is None # metadata object is not decoupled from the field object md = ds[0].metadata() @@ -552,9 +535,6 @@ def test_grib_cache_options_5(patch_metadata_cache): ds = from_source("file", earthkit_examples_file("tuv_pl.grib")) assert len(ds) == 18 - cache = ds._manager - assert cache - # unique values ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") @@ -571,8 +551,8 @@ def test_grib_cache_options_5(patch_metadata_cache): _check_diag(ds._diag(), ref) - assert cache.field_cache is None - assert cache.handle_cache is None + assert ds._field_manager.cache is None + assert ds._handle_manager.cache is None # metadata object is not decoupled from the field object md = ds[0].metadata() @@ -648,9 +628,6 @@ def test_grib_cache_options_6(patch_metadata_cache): ds = from_source("file", earthkit_examples_file("tuv_pl.grib")) assert len(ds) == 18 - cache = ds._manager - assert cache - # unique values ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") @@ -667,8 +644,8 @@ def test_grib_cache_options_6(patch_metadata_cache): _check_diag(ds._diag(), ref) - assert cache.field_cache is None - assert cache.handle_cache is not None + assert ds._field_manager.cache is None + assert ds._handle_manager.cache is not None # metadata object is not decoupled from the field object md = ds[0].metadata() @@ -736,9 +713,6 @@ def test_grib_cache_use_kwargs_1(): ds = from_source("file", earthkit_examples_file("tuv_pl.grib"), **_kwargs) assert len(ds) == 18 - cache = ds._manager - assert cache - # unique values ds.unique_values("paramId", "levelist", "levtype", "valid_datetime")