Skip to content

Commit

Permalink
Fix memory leak in metadata cache (#442)
Browse files Browse the repository at this point in the history
* Fix memory leak in metadata cache
  • Loading branch information
sandorkertesz authored Sep 5, 2024
1 parent 61d0174 commit f41d672
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 17 deletions.
9 changes: 9 additions & 0 deletions docs/release_notes/version_0.10_updates.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@ Version 0.10 Updates
/////////////////////////


Version 0.10.1
===============

Fixes
++++++

- Fixed a memory leak in the GRIB field metadata cache


Version 0.10.0
===============

Expand Down
21 changes: 13 additions & 8 deletions src/earthkit/data/core/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,10 @@

from abc import ABCMeta
from abc import abstractmethod
from functools import lru_cache

from earthkit.data.core.constants import DATETIME
from earthkit.data.core.constants import GRIDSPEC

try:
from functools import cache as memoise # noqa
except ImportError:
memoise = lru_cache


class Metadata(metaclass=ABCMeta):
r"""Base class to represent metadata.
Expand Down Expand Up @@ -54,8 +48,10 @@ class Metadata(metaclass=ABCMeta):
def __init__(self, extra=None, cache=False):
    """Initialise the optional extra metadata and the per-instance cache.

    Parameters
    ----------
    extra: object, optional
        When not None it is stored as ``self.extra``.
    cache: bool or mapping-like object
        - False: caching is disabled and ``self._cache`` is set to None.
        - True: a new empty dict is used as the cache.
        - any other object: used as the cache as it is. It is expected to support
          ``in``, item access and item assignment.
    """
    if extra is not None:
        self.extra = extra

    # The cache is a plain container owned by this instance. The ``get`` method is deliberately
    # not wrapped in a memoising decorator here: that would store a cache on the instance that
    # holds a reference back to the instance itself through the bound method.
    if cache is False:
        self._cache = None
    else:
        self._cache = dict() if cache is True else cache

def __iter__(self):
"""Return an iterator over the metadata keys."""
Expand Down Expand Up @@ -196,12 +192,21 @@ def get(self, key, default=None, *, astype=None, raise_on_missing=False):
a missing value.
"""
if self._cache is not None:
cache_id = (key, default, astype, raise_on_missing)
if cache_id in self._cache:
return self._cache[cache_id]

if self._is_extra_key(key):
v = self._get_extra_key(key, default=default, astype=astype)
elif self._is_custom_key(key):
v = self._get_custom_key(key, default=default, astype=astype, raise_on_missing=raise_on_missing)
else:
v = self._get(key, default=default, astype=astype, raise_on_missing=raise_on_missing)

if self._cache is not None:
self._cache[cache_id] = v

return v

@abstractmethod
Expand Down
6 changes: 4 additions & 2 deletions src/earthkit/data/readers/grib/codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ def _metadata(self):
cache = False
if self._manager is not None:
cache = self._manager.use_grib_metadata_cache
if cache:
cache = self._manager._make_metadata_cache()
return GribFieldMetadata(self, cache=cache)

def __repr__(self):
Expand Down Expand Up @@ -329,10 +331,10 @@ def message(self):
def _diag(self):
    """Return diagnostics about the metadata cache of this object.

    Returns
    -------
    defaultdict
        Integer counters under the keys "metadata_cache_size", "metadata_cache_hits" and
        "metadata_cache_misses". The collection is best-effort: the counters are filled in
        order and the first failure (e.g. no cache, or a cache object that does not provide
        ``hits``/``misses``) stops the collection without raising. Counters not filled
        read as 0.
    """
    r = defaultdict(int)
    try:
        md_cache = self._metadata._cache
        r["metadata_cache_size"] += len(md_cache)
        # only cache objects that count their accesses provide these attributes
        r["metadata_cache_hits"] += md_cache.hits
        r["metadata_cache_misses"] += md_cache.misses
    except Exception:
        # diagnostics must never break the caller
        pass
    return r
3 changes: 3 additions & 0 deletions src/earthkit/data/readers/grib/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ def _field_created(self):
def _handle_created(self):
    """Increase the ``handle_create_count`` counter by one."""
    self.handle_create_count = self.handle_create_count + 1

def _make_metadata_cache(self):
    """Create the object to be used as a metadata cache: a new empty dict on each call."""
    return {}

def diag(self):
r = defaultdict(int)
r["grib_field_policy"] = self.grib_field_policy
Expand Down
121 changes: 114 additions & 7 deletions tests/grib/test_grib_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,44 @@
from earthkit.data.testing import earthkit_examples_file


class TestMetadataCache:
    """Dict-like metadata cache counting its hits and misses.

    It offers the container operations ``in``, item access, item assignment and ``len()``.
    Each item access increases ``hits`` and each item assignment increases ``misses``.
    Membership tests and ``len()`` are not counted.
    """

    # The class name matches the default pattern pytest uses to collect test classes, but this
    # is only a helper. Without opting out, pytest tries to collect it and emits a collection
    # warning because the class defines __init__.
    __test__ = False

    def __init__(self):
        self.hits = 0
        self.misses = 0
        self.data = {}

    def __contains__(self, key):
        return key in self.data

    def __getitem__(self, key):
        # counted before the lookup, so a failing lookup is counted as well
        self.hits += 1
        return self.data[key]

    def __setitem__(self, key, value):
        self.misses += 1
        self.data[key] = value

    def __len__(self):
        return len(self.data)


@pytest.fixture
def patch_metadata_cache(monkeypatch):
    """Patch ``GribResourceManager._make_metadata_cache`` to return a ``TestMetadataCache``.

    Tests using this fixture get metadata caches that count their hits and misses.
    """
    from earthkit.data.readers.grib.index import GribResourceManager

    def _counting_cache_factory(self):
        return TestMetadataCache()

    target_name = "_make_metadata_cache"
    monkeypatch.setattr(GribResourceManager, target_name, _counting_cache_factory)


def _check_diag(diag, ref):
    """Assert that ``diag`` holds the value given in ``ref`` for every key of ``ref``.

    Keys of ``diag`` not present in ``ref`` are ignored. The first mismatch raises an
    AssertionError with a message of the form "key=actual != expected".
    """
    for name in ref:
        expected = ref[name]
        assert diag[name] == expected, f"{name}={diag[name]} != {expected}"


@pytest.mark.parametrize("handle_cache_size", [1, 5])
def test_grib_cache_basic(handle_cache_size):
def test_grib_cache_basic(handle_cache_size, patch_metadata_cache):

with settings.temporary(
{
Expand Down Expand Up @@ -99,7 +130,81 @@ def test_grib_cache_basic(handle_cache_size):
assert ds[0].handle == md._handle


def test_grib_cache_options_1():
def test_grib_cache_basic_non_patched():
    """Same scenario as test_grib_cache_basic but without the patch_metadata_cache fixture.

    The metadata cache hits and misses are not counted in this case, so only the size of the
    metadata cache is checked in the diagnostics.
    """
    keys = ("paramId", "levelist", "levtype", "valid_datetime")

    options = {
        "grib-field-policy": "persistent",
        "grib-handle-policy": "cache",
        "grib-handle-cache-size": 1,
        "use-grib-metadata-cache": True,
    }

    # The checked diagnostics are expected to be the same after each step. With a counting
    # cache "metadata_cache_misses" would be 18 * 6 throughout, while "metadata_cache_hits"
    # would be 0 at first and 18 * 4 after repeating the unique values call.
    expected = {
        "field_cache_size": 18,
        "field_create_count": 18,
        "handle_cache_size": 1,
        "handle_create_count": 18,
        "current_handle_count": 0,
        "metadata_cache_size": 18 * 6,
    }

    with settings.temporary(options):
        ds = from_source("file", earthkit_examples_file("tuv_pl.grib"))
        assert len(ds) == 18

        manager = ds._manager
        assert manager

        # unique values
        first_vals = ds.unique_values(*keys)
        _check_diag(ds._diag(), expected)

        for idx, field in enumerate(ds):
            assert idx in manager.field_cache, f"{idx} not in cache"
            assert id(field) == id(manager.field_cache[idx]), f"{idx} not the same object"

        _check_diag(ds._diag(), expected)

        # unique values repeated
        repeated_vals = ds.unique_values(*keys)
        assert repeated_vals == first_vals
        _check_diag(ds._diag(), expected)

        # order by
        ds.order_by(["levelist", "valid_datetime", "paramId", "levtype"])
        _check_diag(ds._diag(), expected)

        # metadata object is not decoupled from the field object
        md = ds[0].metadata()
        assert hasattr(md, "_field")
        assert ds[0].handle == md._handle


def test_grib_cache_options_1(patch_metadata_cache):
with settings.temporary(
{
"grib-field-policy": "persistent",
Expand Down Expand Up @@ -179,7 +284,7 @@ def test_grib_cache_options_1():
_check_diag(ds._diag(), ref)


def test_grib_cache_options_2():
def test_grib_cache_options_2(patch_metadata_cache):
with settings.temporary(
{
"grib-field-policy": "persistent",
Expand Down Expand Up @@ -261,7 +366,7 @@ def test_grib_cache_options_2():
_check_diag(ds._diag(), ref)


def test_grib_cache_options_3():
def test_grib_cache_options_3(patch_metadata_cache):
with settings.temporary(
{
"grib-field-policy": "persistent",
Expand Down Expand Up @@ -341,7 +446,7 @@ def test_grib_cache_options_3():
_check_diag(ds._diag(), ref)


def test_grib_cache_options_4():
def test_grib_cache_options_4(patch_metadata_cache):
with settings.temporary(
{
"grib-field-policy": "temporary",
Expand Down Expand Up @@ -420,6 +525,7 @@ def test_grib_cache_options_4():
_check_diag(
ds[0]._diag(), {"metadata_cache_hits": 0, "metadata_cache_misses": 0, "metadata_cache_size": 0}
)

ref["field_create_count"] += 2
ref["handle_create_count"] += 1
_check_diag(ds._diag(), ref)
Expand All @@ -428,12 +534,13 @@ def test_grib_cache_options_4():
_check_diag(
ds[0]._diag(), {"metadata_cache_hits": 0, "metadata_cache_misses": 0, "metadata_cache_size": 0}
)

ref["field_create_count"] += 2
ref["handle_create_count"] += 1
_check_diag(ds._diag(), ref)


def test_grib_cache_options_5():
def test_grib_cache_options_5(patch_metadata_cache):
with settings.temporary(
{
"grib-field-policy": "temporary",
Expand Down Expand Up @@ -529,7 +636,7 @@ def test_grib_cache_options_5():
_check_diag(ds._diag(), ref)


def test_grib_cache_options_6():
def test_grib_cache_options_6(patch_metadata_cache):
with settings.temporary(
{
"grib-field-policy": "temporary",
Expand Down

0 comments on commit f41d672

Please sign in to comment.