Skip to content

Commit

Permalink
Switch to pooled objects
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Feb 27, 2025
1 parent 8f948d1 commit eb118bb
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypeVar, Union, runtime_checkable
from functools import cached_property
from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable
from weakref import WeakValueDictionary

from typing_extensions import Self, assert_never

Expand Down Expand Up @@ -569,29 +571,46 @@ class CallDecl:
# Used for pretty printing classmethod calls with type parameters
bound_tp_params: tuple[JustTypeRef, ...] | None = None

# pool objects for faster __eq__
_args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({})

def __new__(cls, *args: object, **kwargs: object) -> Self:
"""
Pool CallDecls so that they can be compared by identity more quickly.
Neccessary bc we search for common parents when serializing CallDecl trees to egglog to
only serialize each sub-tree once.
"""
# normalize the args/kwargs to a tuple so that they can be compared
callable = args[0] if args else kwargs["callable"]
args_ = args[1] if len(args) > 1 else kwargs.get("args", ())
bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params")

normalized_args = (callable, args_, bound_tp_params)
try:
return cast(Self, cls._args_to_value[normalized_args])
except KeyError:
res = super().__new__(cls)
cls._args_to_value[normalized_args] = res
return res

def __post_init__(self) -> None:
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
msg = "Cannot bind type parameters to a non-class method callable."
raise ValueError(msg)

# def __hash__(self) -> int:
# return self._cached_hash

# @cached_property
# def _cached_hash(self) -> int:
# return hash((self.callable, self.args, self.bound_tp_params))

# def __eq__(self, other: object) -> bool:
# # Override eq to use cached hash for perf
# if not isinstance(other, CallDecl):
# return False
# if hash(self) != hash(other):
# return False
# return (
# self.callable == other.callable
# and self.args == other.args
# and self.bound_tp_params == other.bound_tp_params
# )
def __hash__(self) -> int:
return self._cached_hash

@cached_property
def _cached_hash(self) -> int:
return hash((self.callable, self.args, self.bound_tp_params))

def __eq__(self, other: object) -> bool:
return self is other

def __ne__(self, other: object) -> bool:
return self is not other


@dataclass(frozen=True)
Expand Down

0 comments on commit eb118bb

Please sign in to comment.