From f41d67291a6698466a8cacc44bbf821efc7b741f Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Thu, 5 Sep 2024 16:42:46 +0100 Subject: [PATCH] Fix memory leak in metadata cache (#442) * Fix memory leak in metadata cache --- docs/release_notes/version_0.10_updates.rst | 9 ++ src/earthkit/data/core/metadata.py | 21 +-- src/earthkit/data/readers/grib/codes.py | 6 +- .../data/readers/grib/index/__init__.py | 3 + tests/grib/test_grib_cache.py | 121 +++++++++++++++++- 5 files changed, 143 insertions(+), 17 deletions(-) diff --git a/docs/release_notes/version_0.10_updates.rst b/docs/release_notes/version_0.10_updates.rst index 66522bdd..a85937ef 100644 --- a/docs/release_notes/version_0.10_updates.rst +++ b/docs/release_notes/version_0.10_updates.rst @@ -2,6 +2,15 @@ Version 0.10 Updates ///////////////////////// +Version 0.10.1 +=============== + +Fixes +++++++ + +- Fixed memory leak in GRIB field metadata cache + + Version 0.10.0 =============== diff --git a/src/earthkit/data/core/metadata.py b/src/earthkit/data/core/metadata.py index 851b5ad4..61ab9b8b 100644 --- a/src/earthkit/data/core/metadata.py +++ b/src/earthkit/data/core/metadata.py @@ -9,16 +9,10 @@ from abc import ABCMeta from abc import abstractmethod -from functools import lru_cache from earthkit.data.core.constants import DATETIME from earthkit.data.core.constants import GRIDSPEC -try: - from functools import cache as memoise # noqa -except ImportError: - memoise = lru_cache - class Metadata(metaclass=ABCMeta): r"""Base class to represent metadata. @@ -54,8 +48,10 @@ class Metadata(metaclass=ABCMeta): def __init__(self, extra=None, cache=False): if extra is not None: self.extra = extra - if cache: - self.get = memoise(self.get) + if cache is False: + self._cache = None + else: + self._cache = dict() if cache is True else cache def __iter__(self): """Return an iterator over the metadata keys.""" @@ -196,12 +192,21 @@ def get(self, key, default=None, *, astype=None, raise_on_missing=False): a missing value. """ + if self._cache is not None: + cache_id = (key, default, astype, raise_on_missing) + if cache_id in self._cache: + return self._cache[cache_id] + if self._is_extra_key(key): v = self._get_extra_key(key, default=default, astype=astype) elif self._is_custom_key(key): v = self._get_custom_key(key, default=default, astype=astype, raise_on_missing=raise_on_missing) else: v = self._get(key, default=default, astype=astype, raise_on_missing=raise_on_missing) + + if self._cache is not None: + self._cache[cache_id] = v + return v @abstractmethod diff --git a/src/earthkit/data/readers/grib/codes.py b/src/earthkit/data/readers/grib/codes.py index c61a6e71..5fbeb30c 100644 --- a/src/earthkit/data/readers/grib/codes.py +++ b/src/earthkit/data/readers/grib/codes.py @@ -285,6 +285,8 @@ def _metadata(self): cache = False if self._manager is not None: cache = self._manager.use_grib_metadata_cache + if cache: + cache = self._manager._make_metadata_cache() return GribFieldMetadata(self, cache=cache) def __repr__(self): @@ -329,10 +331,10 @@ def message(self): def _diag(self): r = r = defaultdict(int) try: - md_cache = self._metadata.get.cache_info() + md_cache = self._metadata._cache + r["metadata_cache_size"] += len(md_cache) r["metadata_cache_hits"] += md_cache.hits r["metadata_cache_misses"] += md_cache.misses - r["metadata_cache_size"] += md_cache.currsize except Exception: pass return r diff --git a/src/earthkit/data/readers/grib/index/__init__.py b/src/earthkit/data/readers/grib/index/__init__.py index 0ffaa85d..b8be7648 100644 --- a/src/earthkit/data/readers/grib/index/__init__.py +++ b/src/earthkit/data/readers/grib/index/__init__.py @@ -306,6 +306,9 @@ def _field_created(self): def _handle_created(self): self.handle_create_count += 1 + def _make_metadata_cache(self): + return dict() + def diag(self): r = defaultdict(int) r["grib_field_policy"] = self.grib_field_policy diff --git a/tests/grib/test_grib_cache.py b/tests/grib/test_grib_cache.py index fd9c255f..741157d7 100644 --- a/tests/grib/test_grib_cache.py +++ b/tests/grib/test_grib_cache.py @@ -16,13 +16,44 @@ from earthkit.data.testing import earthkit_examples_file +class TestMetadataCache: + def __init__(self): + self.hits = 0 + self.misses = 0 + self.data = {} + + def __contains__(self, key): + return key in self.data + + def __getitem__(self, key): + self.hits += 1 + return self.data[key] + + def __setitem__(self, key, value): + self.misses += 1 + self.data[key] = value + + def __len__(self): + return len(self.data) + + +@pytest.fixture +def patch_metadata_cache(monkeypatch): + from earthkit.data.readers.grib.index import GribResourceManager + + def patched_make_metadata_cache(self): + return TestMetadataCache() + + monkeypatch.setattr(GribResourceManager, "_make_metadata_cache", patched_make_metadata_cache) + + def _check_diag(diag, ref): for k, v in ref.items(): assert diag[k] == v, f"{k}={diag[k]} != {v}" @pytest.mark.parametrize("handle_cache_size", [1, 5]) -def test_grib_cache_basic(handle_cache_size): +def test_grib_cache_basic(handle_cache_size, patch_metadata_cache): with settings.temporary( { @@ -99,7 +130,81 @@ def test_grib_cache_basic(handle_cache_size): assert ds[0].handle == md._handle -def test_grib_cache_options_1(): +def test_grib_cache_basic_non_patched(): + """This test is the same as test_grib_cache_basic but without the patch_metadata_cache fixture. + So metadata cache hits and misses are not counted.""" + with settings.temporary( + { + "grib-field-policy": "persistent", + "grib-handle-policy": "cache", + "grib-handle-cache-size": 1, + "use-grib-metadata-cache": True, + } + ): + ds = from_source("file", earthkit_examples_file("tuv_pl.grib")) + assert len(ds) == 18 + + cache = ds._manager + assert cache + + # unique values + ref_vals = ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") + + ref = { + "field_cache_size": 18, + "field_create_count": 18, + "handle_cache_size": 1, + "handle_create_count": 18, + "current_handle_count": 0, + # "metadata_cache_hits": 0, + # "metadata_cache_misses": 18 * 6, + "metadata_cache_size": 18 * 6, + } + _check_diag(ds._diag(), ref) + + for i, f in enumerate(ds): + assert i in cache.field_cache, f"{i} not in cache" + assert id(f) == id(cache.field_cache[i]), f"{i} not the same object" + + _check_diag(ds._diag(), ref) + + # unique values repeated + vals = ds.unique_values("paramId", "levelist", "levtype", "valid_datetime") + + assert vals == ref_vals + + ref = { + "field_cache_size": 18, + "field_create_count": 18, + "handle_cache_size": 1, + "handle_create_count": 18, + "current_handle_count": 0, + # "metadata_cache_hits": 18 * 4, + # "metadata_cache_misses": 18 * 6, + "metadata_cache_size": 18 * 6, + } + _check_diag(ds._diag(), ref) + + # order by + ds.order_by(["levelist", "valid_datetime", "paramId", "levtype"]) + ref = { + "field_cache_size": 18, + "field_create_count": 18, + "handle_cache_size": 1, + "handle_create_count": 18, + "current_handle_count": 0, + # "metadata_cache_misses": 18 * 6, + "metadata_cache_size": 18 * 6, + } + _check_diag(ds._diag(), ref) + + # metadata object is not decoupled from the field object + md = ds[0].metadata() + assert hasattr(md, "_field") + assert ds[0].handle == md._handle + + +def test_grib_cache_options_1(patch_metadata_cache): with settings.temporary( { "grib-field-policy": "persistent", @@ -179,7 +284,7 @@ def test_grib_cache_options_1(): _check_diag(ds._diag(), ref) -def test_grib_cache_options_2(): +def test_grib_cache_options_2(patch_metadata_cache): with settings.temporary( { "grib-field-policy": "persistent", @@ -261,7 +366,7 @@ def test_grib_cache_options_2(): _check_diag(ds._diag(), ref) -def test_grib_cache_options_3(): +def test_grib_cache_options_3(patch_metadata_cache): with settings.temporary( { "grib-field-policy": "persistent", @@ -341,7 +446,7 @@ def test_grib_cache_options_3(): _check_diag(ds._diag(), ref) -def test_grib_cache_options_4(): +def test_grib_cache_options_4(patch_metadata_cache): with settings.temporary( { "grib-field-policy": "temporary", @@ -420,6 +525,7 @@ def test_grib_cache_options_4(): _check_diag( ds[0]._diag(), {"metadata_cache_hits": 0, "metadata_cache_misses": 0, "metadata_cache_size": 0} ) + ref["field_create_count"] += 2 ref["handle_create_count"] += 1 _check_diag(ds._diag(), ref) @@ -428,12 +534,13 @@ def test_grib_cache_options_4(): _check_diag( ds[0]._diag(), {"metadata_cache_hits": 0, "metadata_cache_misses": 0, "metadata_cache_size": 0} ) + ref["field_create_count"] += 2 ref["handle_create_count"] += 1 _check_diag(ds._diag(), ref) -def test_grib_cache_options_5(): +def test_grib_cache_options_5(patch_metadata_cache): with settings.temporary( { "grib-field-policy": "temporary", @@ -529,7 +636,7 @@ def test_grib_cache_options_5(): _check_diag(ds._diag(), ref) -def test_grib_cache_options_6(): +def test_grib_cache_options_6(patch_metadata_cache): with settings.temporary( { "grib-field-policy": "temporary",