diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 73c6aaf..deeeafd 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -7,7 +7,9 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypeVar, Union, runtime_checkable +from functools import cached_property +from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable +from weakref import WeakValueDictionary from typing_extensions import Self, assert_never @@ -569,29 +571,46 @@ class CallDecl: # Used for pretty printing classmethod calls with type parameters bound_tp_params: tuple[JustTypeRef, ...] | None = None + # pool objects for faster __eq__ + _args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({}) + + def __new__(cls, *args: object, **kwargs: object) -> Self: + """ + Pool CallDecls so that they can be compared by identity more quickly. + + Neccessary bc we search for common parents when serializing CallDecl trees to egglog to + only serialize each sub-tree once. + """ + # normalize the args/kwargs to a tuple so that they can be compared + callable = args[0] if args else kwargs["callable"] + args_ = args[1] if len(args) > 1 else kwargs.get("args", ()) + bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params") + + normalized_args = (callable, args_, bound_tp_params) + try: + return cast(Self, cls._args_to_value[normalized_args]) + except KeyError: + res = super().__new__(cls) + cls._args_to_value[normalized_args] = res + return res + def __post_init__(self) -> None: if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef): msg = "Cannot bind type parameters to a non-class method callable." raise ValueError(msg) - # def __hash__(self) -> int: - # return self._cached_hash - - # @cached_property - # def _cached_hash(self) -> int: - # return hash((self.callable, self.args, self.bound_tp_params)) - - # def __eq__(self, other: object) -> bool: - # # Override eq to use cached hash for perf - # if not isinstance(other, CallDecl): - # return False - # if hash(self) != hash(other): - # return False - # return ( - # self.callable == other.callable - # and self.args == other.args - # and self.bound_tp_params == other.bound_tp_params - # ) + def __hash__(self) -> int: + return self._cached_hash + + @cached_property + def _cached_hash(self) -> int: + return hash((self.callable, self.args, self.bound_tp_params)) + + def __eq__(self, other: object) -> bool: + return self is other + + def __ne__(self, other: object) -> bool: + return self is not other @dataclass(frozen=True)