Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fix memory leak in metadata cache #443

Merged
merged 9 commits into from
Sep 7, 2024
20 changes: 11 additions & 9 deletions src/earthkit/data/readers/grib/codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,19 @@ class GribField(Field):

_handle = None

def __init__(self, path, offset, length, backend, manager=None):
def __init__(self, path, offset, length, backend, handle_manager=None, use_metadata_cache=False):
super().__init__(backend)
self.path = path
self._offset = offset
self._length = length
self._manager = manager
self._handle_manager = handle_manager
self._use_metadata_cache = use_metadata_cache

@property
def handle(self):
r""":class:`CodesHandle`: Get an object providing access to the low level GRIB message structure."""
if self._manager is not None:
handle = self._manager.handle(self, self._create_handle)
if self._handle_manager is not None:
handle = self._handle_manager.handle(self, self._create_handle)
if handle is None:
raise RuntimeError(f"Could not get a handle for offset={self.offset} in {self.path}")
return handle
Expand All @@ -282,13 +283,14 @@ def offset(self):

@cached_property
def _metadata(self):
cache = False
if self._manager is not None:
cache = self._manager.use_grib_metadata_cache
if cache:
cache = self._manager._make_metadata_cache()
cache = self._use_metadata_cache
if cache:
cache = self._make_metadata_cache()
return GribFieldMetadata(self, cache=cache)

def _make_metadata_cache(self):
return dict()

def __repr__(self):
return "GribField(%s,%s,%s,%s,%s,%s)" % (
self._metadata.get("shortName", None),
Expand Down
177 changes: 92 additions & 85 deletions src/earthkit/data/readers/grib/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,126 +217,106 @@ def __init__(self, *args, **kwargs):
FieldList._init_from_multi(self, self)


class GribResourceManager:
def __init__(
self, owner, grib_field_policy, grib_handle_policy, grib_handle_cache_size, use_grib_metadata_cache
):
from earthkit.data.core.settings import SETTINGS

def _get(v, name):
return v if v is not None else SETTINGS.get(name)
class GribFieldManager:
def __init__(self, policy, owner):
self.policy = policy
self.cache = None

self.grib_field_policy = _get(grib_field_policy, "grib-field-policy")
self.grib_handle_policy = _get(grib_handle_policy, "grib-handle-policy")
self.grib_handle_cache_size = _get(grib_handle_cache_size, "grib-handle-cache-size")
self.use_grib_metadata_cache = _get(use_grib_metadata_cache, "use-grib-metadata-cache")

# fields
self.field_cache = None
if self.grib_field_policy == "persistent":
if self.policy == "persistent":
from lru import LRU

# TODO: the number of fields might only be available only later (e.g. fieldlists with
# an SQL index). Consider making _field_cache a cached property.
# an SQL index). Consider making cache a cached property.
n = len(owner)
if n > 0:
self.field_cache = LRU(n)

# handles
self.handle_cache = None
if self.grib_handle_policy == "cache":
if self.grib_handle_cache_size > 0:
from lru import LRU
self.cache = LRU(n)

self.handle_cache = LRU(self.grib_handle_cache_size)
else:
raise ValueError(
'grib_handle_cache_size must be greater than 0 when grib_handle_policy="cache"'
)

self.handle_create_count = 0
self.field_create_count = 0

# check consistency
if self.field_cache is not None:
self.grib_field_policy == "persistent"
else:
self.grib_field_policy == "temporary"

if self.handle_cache is not None:
self.grib_handle_policy == "cache"
else:
self.grib_handle_policy in ["persistent", "temporary"]
if self.cache is not None:
assert self.policy == "persistent"

def field(self, n, create):
if self.grib_field_policy == "persistent":
if n in self.field_cache:
return self.field_cache[n]
if self.cache is not None:
if n in self.cache:
return self.cache[n]
else:
field = create(n)
self._field_created()
self.field_cache[n] = field
self.cache[n] = field
return field
else:
self._field_created()
return create(n)

def _field_created(self):
self.field_create_count += 1

def diag(self):
r = defaultdict(int)
r["grib_field_policy"] = self.policy
if self.cache is not None:
r["field_cache_size"] = len(self.cache)

r["field_create_count"] = self.field_create_count
return r


class GribHandleManager:
def __init__(self, policy, cache_size):
self.policy = policy
self.max_cache_size = cache_size
self.cache = None

if self.policy == "cache":
if self.max_cache_size > 0:
from lru import LRU

self.cache = LRU(self.max_cache_size)
else:
raise ValueError(
'grib_handle_cache_size must be greater than 0 when grib_handle_policy="cache"'
)

self.handle_create_count = 0

# check consistency
if self.cache is not None:
self.policy == "cache"
else:
self.policy in ["persistent", "temporary"]

def handle(self, field, create):
if self.grib_handle_policy == "cache":
if self.policy == "cache":
key = (field.path, field._offset)
if key in self.handle_cache:
return self.handle_cache[key]
if key in self.cache:
return self.cache[key]
else:
handle = create()
self._handle_created()
self.handle_cache[key] = handle
self.cache[key] = handle
return handle
elif self.grib_handle_policy == "persistent":
elif self.policy == "persistent":
if field._handle is None:
field._handle = create()
self._handle_created()
return field._handle
elif self.grib_handle_policy == "temporary":
elif self.policy == "temporary":
self._handle_created()
return create()

def _field_created(self):
self.field_create_count += 1

def _handle_created(self):
self.handle_create_count += 1

def _make_metadata_cache(self):
return dict()

def diag(self):
r = defaultdict(int)
r["grib_field_policy"] = self.grib_field_policy
r["grib_handle_policy"] = self.grib_handle_policy
r["grib_handle_cache_size"] = self.grib_handle_cache_size

if self.field_cache is not None:
r["field_cache_size"] = len(self.field_cache)

r["field_create_count"] = self.field_create_count

if self.handle_cache is not None:
r["handle_cache_size"] = len(self.handle_cache)
r["grib_handle_policy"] = self.policy
r["grib_handle_cache_size"] = self.max_cache_size
if self.cache is not None:
r["handle_cache_size"] = len(self.cache)

r["handle_create_count"] = self.handle_create_count

if self.field_cache is not None:
for f in self.field_cache.values():
if f._handle is not None:
r["current_handle_count"] += 1

try:
md_cache = f._diag()
for k in ["metadata_cache_hits", "metadata_cache_misses", "metadata_cache_size"]:
r[k] += md_cache[k]
except Exception:
pass

return r


Expand All @@ -351,18 +331,29 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)
self._manager = GribResourceManager(
self, grib_field_policy, grib_handle_policy, grib_handle_cache_size, use_grib_metadata_cache

from earthkit.data.core.settings import SETTINGS

def _get_opt(v, name):
return v if v is not None else SETTINGS.get(name)

self._field_manager = GribFieldManager(_get_opt(grib_field_policy, "grib-field-policy"), self)
self._handle_manager = GribHandleManager(
_get_opt(grib_handle_policy, "grib-handle-policy"),
_get_opt(grib_handle_cache_size, "grib-handle-cache-size"),
)

self._use_metadata_cache = _get_opt(use_grib_metadata_cache, "use-grib-metadata-cache")

def _create_field(self, n):
part = self.part(n)
field = GribField(
part.path,
part.offset,
part.length,
self.array_backend,
manager=self._manager,
handle_manager=self._handle_manager,
use_metadata_cache=self._use_metadata_cache,
)
if field is None:
raise RuntimeError(f"Could not get a handle for part={part}")
Expand All @@ -376,13 +367,29 @@ def _getitem(self, n):
if n >= len(self):
raise IndexError(f"Index {n} out of range")

return self._manager.field(n, self._create_field)
return self._field_manager.field(n, self._create_field)

def __len__(self):
return self.number_of_parts()

def _diag(self):
return self._manager.diag()
r = defaultdict(int)
r.update(self._field_manager.diag())
r.update(self._handle_manager.diag())

if self._field_manager.cache is not None:
for f in self._field_manager.cache.values():
if f._handle is not None:
r["current_handle_count"] += 1

if self._use_metadata_cache:
try:
md_cache = f._diag()
for k in ["metadata_cache_hits", "metadata_cache_misses", "metadata_cache_size"]:
r[k] += md_cache[k]
except Exception:
pass
return r

@abstractmethod
def part(self, n):
Expand Down
Loading
Loading