Skip to content

Commit

Permalink
Feature/fix memory leak in metadata cache (#443)
Browse files Browse the repository at this point in the history
* Fix memory leak in metadata cache
  • Loading branch information
sandorkertesz authored Sep 7, 2024
1 parent f41d672 commit bf1136a
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 145 deletions.
20 changes: 11 additions & 9 deletions src/earthkit/data/readers/grib/codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,19 @@ class GribField(Field):

_handle = None

def __init__(self, path, offset, length, backend, manager=None):
def __init__(self, path, offset, length, backend, handle_manager=None, use_metadata_cache=False):
super().__init__(backend)
self.path = path
self._offset = offset
self._length = length
self._manager = manager
self._handle_manager = handle_manager
self._use_metadata_cache = use_metadata_cache

@property
def handle(self):
r""":class:`CodesHandle`: Get an object providing access to the low level GRIB message structure."""
if self._manager is not None:
handle = self._manager.handle(self, self._create_handle)
if self._handle_manager is not None:
handle = self._handle_manager.handle(self, self._create_handle)
if handle is None:
raise RuntimeError(f"Could not get a handle for offset={self.offset} in {self.path}")
return handle
Expand All @@ -282,13 +283,14 @@ def offset(self):

@cached_property
def _metadata(self):
cache = False
if self._manager is not None:
cache = self._manager.use_grib_metadata_cache
if cache:
cache = self._manager._make_metadata_cache()
cache = self._use_metadata_cache
if cache:
cache = self._make_metadata_cache()
return GribFieldMetadata(self, cache=cache)

def _make_metadata_cache(self):
return dict()

def __repr__(self):
return "GribField(%s,%s,%s,%s,%s,%s)" % (
self._metadata.get("shortName", None),
Expand Down
177 changes: 92 additions & 85 deletions src/earthkit/data/readers/grib/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,126 +217,106 @@ def __init__(self, *args, **kwargs):
FieldList._init_from_multi(self, self)


class GribResourceManager:
def __init__(
self, owner, grib_field_policy, grib_handle_policy, grib_handle_cache_size, use_grib_metadata_cache
):
from earthkit.data.core.settings import SETTINGS

def _get(v, name):
return v if v is not None else SETTINGS.get(name)
class GribFieldManager:
def __init__(self, policy, owner):
self.policy = policy
self.cache = None

self.grib_field_policy = _get(grib_field_policy, "grib-field-policy")
self.grib_handle_policy = _get(grib_handle_policy, "grib-handle-policy")
self.grib_handle_cache_size = _get(grib_handle_cache_size, "grib-handle-cache-size")
self.use_grib_metadata_cache = _get(use_grib_metadata_cache, "use-grib-metadata-cache")

# fields
self.field_cache = None
if self.grib_field_policy == "persistent":
if self.policy == "persistent":
from lru import LRU

# TODO: the number of fields might only be available later (e.g. fieldlists with
# an SQL index). Consider making _field_cache a cached property.
# an SQL index). Consider making cache a cached property.
n = len(owner)
if n > 0:
self.field_cache = LRU(n)

# handles
self.handle_cache = None
if self.grib_handle_policy == "cache":
if self.grib_handle_cache_size > 0:
from lru import LRU
self.cache = LRU(n)

self.handle_cache = LRU(self.grib_handle_cache_size)
else:
raise ValueError(
'grib_handle_cache_size must be greater than 0 when grib_handle_policy="cache"'
)

self.handle_create_count = 0
self.field_create_count = 0

# check consistency
if self.field_cache is not None:
self.grib_field_policy == "persistent"
else:
self.grib_field_policy == "temporary"

if self.handle_cache is not None:
self.grib_handle_policy == "cache"
else:
self.grib_handle_policy in ["persistent", "temporary"]
if self.cache is not None:
assert self.policy == "persistent"

def field(self, n, create):
if self.grib_field_policy == "persistent":
if n in self.field_cache:
return self.field_cache[n]
if self.cache is not None:
if n in self.cache:
return self.cache[n]
else:
field = create(n)
self._field_created()
self.field_cache[n] = field
self.cache[n] = field
return field
else:
self._field_created()
return create(n)

def _field_created(self):
self.field_create_count += 1

def diag(self):
r = defaultdict(int)
r["grib_field_policy"] = self.policy
if self.cache is not None:
r["field_cache_size"] = len(self.cache)

r["field_create_count"] = self.field_create_count
return r


class GribHandleManager:
def __init__(self, policy, cache_size):
self.policy = policy
self.max_cache_size = cache_size
self.cache = None

if self.policy == "cache":
if self.max_cache_size > 0:
from lru import LRU

self.cache = LRU(self.max_cache_size)
else:
raise ValueError(
'grib_handle_cache_size must be greater than 0 when grib_handle_policy="cache"'
)

self.handle_create_count = 0

# check consistency
if self.cache is not None:
self.policy == "cache"
else:
self.policy in ["persistent", "temporary"]

def handle(self, field, create):
if self.grib_handle_policy == "cache":
if self.policy == "cache":
key = (field.path, field._offset)
if key in self.handle_cache:
return self.handle_cache[key]
if key in self.cache:
return self.cache[key]
else:
handle = create()
self._handle_created()
self.handle_cache[key] = handle
self.cache[key] = handle
return handle
elif self.grib_handle_policy == "persistent":
elif self.policy == "persistent":
if field._handle is None:
field._handle = create()
self._handle_created()
return field._handle
elif self.grib_handle_policy == "temporary":
elif self.policy == "temporary":
self._handle_created()
return create()

def _field_created(self):
self.field_create_count += 1

def _handle_created(self):
self.handle_create_count += 1

def _make_metadata_cache(self):
return dict()

def diag(self):
r = defaultdict(int)
r["grib_field_policy"] = self.grib_field_policy
r["grib_handle_policy"] = self.grib_handle_policy
r["grib_handle_cache_size"] = self.grib_handle_cache_size

if self.field_cache is not None:
r["field_cache_size"] = len(self.field_cache)

r["field_create_count"] = self.field_create_count

if self.handle_cache is not None:
r["handle_cache_size"] = len(self.handle_cache)
r["grib_handle_policy"] = self.policy
r["grib_handle_cache_size"] = self.max_cache_size
if self.cache is not None:
r["handle_cache_size"] = len(self.cache)

r["handle_create_count"] = self.handle_create_count

if self.field_cache is not None:
for f in self.field_cache.values():
if f._handle is not None:
r["current_handle_count"] += 1

try:
md_cache = f._diag()
for k in ["metadata_cache_hits", "metadata_cache_misses", "metadata_cache_size"]:
r[k] += md_cache[k]
except Exception:
pass

return r


Expand All @@ -351,18 +331,29 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)
self._manager = GribResourceManager(
self, grib_field_policy, grib_handle_policy, grib_handle_cache_size, use_grib_metadata_cache

from earthkit.data.core.settings import SETTINGS

def _get_opt(v, name):
return v if v is not None else SETTINGS.get(name)

self._field_manager = GribFieldManager(_get_opt(grib_field_policy, "grib-field-policy"), self)
self._handle_manager = GribHandleManager(
_get_opt(grib_handle_policy, "grib-handle-policy"),
_get_opt(grib_handle_cache_size, "grib-handle-cache-size"),
)

self._use_metadata_cache = _get_opt(use_grib_metadata_cache, "use-grib-metadata-cache")

def _create_field(self, n):
part = self.part(n)
field = GribField(
part.path,
part.offset,
part.length,
self.array_backend,
manager=self._manager,
handle_manager=self._handle_manager,
use_metadata_cache=self._use_metadata_cache,
)
if field is None:
raise RuntimeError(f"Could not get a handle for part={part}")
Expand All @@ -376,13 +367,29 @@ def _getitem(self, n):
if n >= len(self):
raise IndexError(f"Index {n} out of range")

return self._manager.field(n, self._create_field)
return self._field_manager.field(n, self._create_field)

def __len__(self):
return self.number_of_parts()

def _diag(self):
return self._manager.diag()
r = defaultdict(int)
r.update(self._field_manager.diag())
r.update(self._handle_manager.diag())

if self._field_manager.cache is not None:
for f in self._field_manager.cache.values():
if f._handle is not None:
r["current_handle_count"] += 1

if self._use_metadata_cache:
try:
md_cache = f._diag()
for k in ["metadata_cache_hits", "metadata_cache_misses", "metadata_cache_size"]:
r[k] += md_cache[k]
except Exception:
pass
return r

@abstractmethod
def part(self, n):
Expand Down
Loading

0 comments on commit bf1136a

Please sign in to comment.