Skip to content

Commit 911c3ca

Browse files
authored
Fixes #1240 (#1241)
* Fixes #1240: The cache store assumed that every persister took a `path` argument. That is not the case, because the savers/loaders wrap external APIs and we decided not to build our own abstraction layer around them but to mirror them instead. E.g. polars takes `file`, whereas pandas takes `path`. This means future changes may require updating this code as well. * Adds tests: one to catch the case with `file`, and one for the case with neither `path` nor `file`.
1 parent 47a5146 commit 911c3ca

File tree

2 files changed

+126
-2
lines changed

2 files changed

+126
-2
lines changed

hamilton/caching/stores/file.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import shutil
23
from pathlib import Path
34
from typing import Any, Optional
@@ -68,8 +69,24 @@ def set(
6869
if saver_cls is not None:
6970
# materialized_path
7071
materialized_path = self._materialized_path(data_version, saver_cls)
71-
saver = saver_cls(path=str(materialized_path.absolute()))
72-
loader = loader_cls(path=str(materialized_path.absolute()))
72+
saver_argspec = inspect.getfullargspec(saver_cls.__init__)
73+
loader_argspec = inspect.getfullargspec(loader_cls.__init__)
74+
if "file" in saver_argspec.args:
75+
saver = saver_cls(file=str(materialized_path.absolute()))
76+
elif "path" in saver_argspec.args:
77+
saver = saver_cls(path=str(materialized_path.absolute()))
78+
else:
79+
raise ValueError(
80+
f"Saver [{saver_cls.name()}] must have either `file` or `path` as an argument."
81+
)
82+
if "file" in loader_argspec.args:
83+
loader = loader_cls(file=str(materialized_path.absolute()))
84+
elif "path" in loader_argspec.args:
85+
loader = loader_cls(path=str(materialized_path.absolute()))
86+
else:
87+
raise ValueError(
88+
f"Loader [{loader_cls.name()}] must have either `file` or `path` as an argument."
89+
)
7390
else:
7491
saver = None
7592
loader = None

tests/caching/test_result_store.py

+107
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import pathlib
22
import pickle
3+
from typing import Any, Collection, Dict, Tuple, Type
34

45
import pytest
56

67
from hamilton.caching import fingerprinting
78
from hamilton.caching.stores.base import search_data_adapter_registry
89
from hamilton.caching.stores.file import FileResultStore
10+
from hamilton.io.data_adapters import DataLoader, DataSaver
911

1012

1113
@pytest.fixture
@@ -114,3 +116,108 @@ def test_save_and_load_materializer(format, value, result_store):
114116

115117
assert materialized_path.exists()
116118
assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value)
119+
120+
121+
class FakeParquetSaver(DataSaver):
    """Test-only saver whose constructor takes ``file`` instead of ``path``.

    It writes the string form of the data to the given file, which lets the
    tests exercise a persister that has no ``path`` argument.
    """

    def __init__(self, file):
        # Destination used verbatim by ``save_data``.
        self.file = file

    def save_data(self, data: Any) -> Dict[str, Any]:
        """Write ``str(data)`` to ``self.file`` and return fixed metadata."""
        with open(self.file, "w") as handle:
            handle.write(str(data))
        return {"meta": "data"}

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Stub for the adapter interface; returns ``None``."""
        return None

    @classmethod
    def name(cls) -> str:
        """Return the adapter name used by these tests."""
        return "fake_parquet"
137+
138+
139+
class FakeParquetLoader(DataLoader):
    """Test-only loader whose constructor takes ``file`` instead of ``path``.

    It reads the text stored in the given file and parses it back into a
    Python object.
    """

    def __init__(self, file):
        # Source file read by ``load_data``.
        self.file = file

    def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
        """Parse the contents of ``self.file`` and return ``(data, metadata)``.

        ``type_`` is not used by this implementation.

        Raises:
            ValueError or SyntaxError: if the file content is not a valid
                Python literal.
        """
        # Local import keeps this block self-contained.
        import ast

        with open(self.file, "r") as f:
            # ``literal_eval`` only accepts literals (dicts, lists, numbers,
            # strings, ...), so file content is parsed, never executed.
            data = ast.literal_eval(f.read())
        return data, {"meta": data}

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Stub for the adapter interface; returns ``None``."""
        pass

    @classmethod
    def name(cls) -> str:
        """Return the adapter name used by these tests."""
        return "fake_parquet"
155+
156+
157+
def test_save_and_load_file_in_init(result_store):
    """Round-trip a value through a saver/loader pair whose ``__init__`` takes ``file``."""
    expected = {"a": 1}
    version = "foo"
    target = result_store._materialized_path(version, FakeParquetSaver)

    result_store.set(
        data_version=version,
        result=expected,
        saver_cls=FakeParquetSaver,
        loader_cls=FakeParquetLoader,
    )
    actual = result_store.get(version)

    # The materialized file must exist and the loaded value must hash the same.
    assert target.exists()
    assert fingerprinting.hash_value(expected) == fingerprinting.hash_value(actual)
170+
171+
172+
class BadSaver(DataSaver):
    """Test-only saver whose constructor takes neither ``file`` nor ``path``.

    The only constructor argument is named ``file123``.
    """

    def __init__(self, file123):
        # Stored under ``file`` even though the argument is named ``file123``.
        self.file = file123

    def save_data(self, data: Any) -> Dict[str, Any]:
        """Write ``str(data)`` to ``self.file`` and return fixed metadata."""
        with open(self.file, "w") as handle:
            handle.write(str(data))
        return {"meta": "data"}

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Stub for the adapter interface; returns ``None``."""
        return None

    @classmethod
    def name(cls) -> str:
        """Return the adapter name used by these tests."""
        return "fake_parquet"
188+
189+
190+
class BadLoader(DataLoader):
    """Test-only loader whose constructor takes neither ``file`` nor ``path``.

    The only constructor argument is named ``file123``.
    """

    def __init__(self, file123):
        # Stored under ``file`` even though the argument is named ``file123``.
        self.file = file123

    def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
        """Parse the contents of ``self.file`` and return ``(data, metadata)``.

        ``type_`` is not used by this implementation.

        Raises:
            ValueError or SyntaxError: if the file content is not a valid
                Python literal.
        """
        # Local import keeps this block self-contained.
        import ast

        with open(self.file, "r") as f:
            # ``literal_eval`` parses literals only; file content is never
            # executed as code.
            data = ast.literal_eval(f.read())
        return data, {"meta": data}

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Stub for the adapter interface; returns ``None``."""
        pass

    @classmethod
    def name(cls) -> str:
        """Return the adapter name used by these tests."""
        return "fake_parquet"
206+
207+
208+
def test_save_and_load_not_path_not_file_init_error(result_store):
    """Check that ``set`` raises ``ValueError`` for persisters lacking ``file``/``path``."""
    payload = {"a": 1}
    version = "foo"

    # Both saver and loader lack a ``file``/``path`` constructor argument.
    with pytest.raises(ValueError):
        result_store.set(
            data_version=version,
            result=payload,
            saver_cls=BadSaver,
            loader_cls=BadLoader,
        )

    # Valid saver, invalid loader.
    with pytest.raises(ValueError):
        result_store.set(  # make something store it in the result store
            data_version=version,
            result=payload,
            saver_cls=FakeParquetSaver,
            loader_cls=BadLoader,
        )
        # NOTE(review): if ``set`` already raises for the bad loader, this
        # line is presumably never reached -- confirm the intent.
        result_store.get(version)

0 commit comments

Comments
 (0)