-
Notifications
You must be signed in to change notification settings - Fork 333
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(serialize): remove redundant code in sql store wrappers
- Loading branch information
1 parent
4745314
commit e572492
Showing
3 changed files
with
17 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,95 +1,16 @@ | ||
from typing import List, Optional, Dict, Any | ||
from typing import Optional | ||
|
||
from ell.serialize.sql import SQLSerializer, AsyncSQLSerializer | ||
from ell.stores.sql import PostgresStore | ||
from ell.stores.store import BlobStore, AsyncBlobStore | ||
from ell.stores.studio import Invocation, SerializedLMP | ||
from ell.types.serialize import LMP, WriteLMPInput, WriteInvocationInput | ||
from ell.serialize.protocol import EllSerializer, EllAsyncSerializer | ||
|
||
|
||
class PostgresSerializer(EllSerializer): | ||
class PostgresSerializer(SQLSerializer): | ||
def __init__(self, db_uri: str, blob_store: Optional[BlobStore] = None): | ||
self.store = PostgresStore(db_uri, blob_store) | ||
self.supports_blobs = blob_store is not None | ||
|
||
def get_lmp(self, lmp_id: str): | ||
lmp = self.store.get_lmp(lmp_id) | ||
if lmp: | ||
return LMP(**lmp.model_dump()) | ||
return None | ||
|
||
def get_lmp_versions(self, fqn: str) -> List[LMP]: | ||
slmps = self.store.get_versions_by_fqn(fqn) | ||
return [LMP(**slmp.model_dump()) for slmp in slmps] | ||
|
||
def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None: | ||
model = SerializedLMP.from_api(lmp) | ||
self.store.write_lmp(model, uses) | ||
|
||
def write_invocation(self, input: WriteInvocationInput) -> None: | ||
invocation = Invocation.from_api(input.invocation) | ||
self.store.write_invocation(invocation, set(input.consumes)) | ||
return None | ||
|
||
def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: | ||
if self.store.blob_store is None: | ||
raise ValueError("Blob store is not enabled") | ||
return self.store.blob_store.store_blob(blob=blob, blob_id=blob_id) | ||
|
||
def retrieve_blob(self, blob_id: str) -> bytes: | ||
if self.store.blob_store is None: | ||
raise ValueError("Blob store is not enabled") | ||
return self.store.blob_store.retrieve_blob(blob_id) | ||
|
||
def close(self): | ||
pass | ||
super().__init__(PostgresStore(db_uri, blob_store)) | ||
|
||
|
||
# todo(async): the underlying store is not async-aware | ||
class AsyncPostgresSerializer(EllAsyncSerializer): | ||
class AsyncPostgresSerializer(AsyncSQLSerializer): | ||
def __init__(self, db_uri: str, blob_store: Optional[AsyncBlobStore] = None): | ||
self.store = PostgresStore(db_uri, blob_store) | ||
self.blob_store = blob_store | ||
self.supports_blobs = blob_store is not None | ||
|
||
async def get_lmp(self, lmp_id: str) -> Optional[LMP]: | ||
lmp = self.store.get_lmp(lmp_id) | ||
if lmp: | ||
return LMP(**lmp.model_dump()) | ||
return None | ||
|
||
async def get_lmp_versions(self, fqn: str) -> List[LMP]: | ||
slmps = self.store.get_versions_by_fqn(fqn) | ||
return [LMP(**slmp.model_dump()) for slmp in slmps] | ||
|
||
async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None: | ||
model = SerializedLMP.from_api(lmp) | ||
self.store.write_lmp(model, uses) | ||
|
||
async def write_invocation(self, input: WriteInvocationInput) -> None: | ||
invocation = Invocation.from_api(input.invocation) | ||
self.store.write_invocation( | ||
invocation, | ||
set(input.consumes) | ||
) | ||
return None | ||
|
||
async def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: | ||
if self.blob_store is None: | ||
raise ValueError("Blob store is not enabled") | ||
return await self.blob_store.store_blob(blob=blob, blob_id=blob_id) | ||
|
||
async def retrieve_blob(self, blob_id: str) -> bytes: | ||
if self.blob_store is None: | ||
raise ValueError("Blob store is not enabled") | ||
return await self.blob_store.retrieve_blob(blob_id) | ||
|
||
async def close(self): | ||
# todo. Do we have a close method? | ||
pass | ||
|
||
async def __aenter__(self): | ||
return self | ||
|
||
async def __aexit__(self): | ||
await self.close() | ||
super().__init__(PostgresStore(db_uri, blob_store)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,18 @@ | ||
from typing import List, Optional, Dict, Any | ||
from typing import Optional | ||
|
||
from ell.serialize.protocol import EllSerializer, EllAsyncSerializer | ||
from ell.serialize.sql import SQLSerializer, AsyncSQLSerializer | ||
from ell.stores.sql import SQLiteStore | ||
from ell.stores.store import AsyncBlobStore, BlobStore | ||
from ell.stores.studio import SerializedLMP, Invocation | ||
from ell.types.serialize import WriteLMPInput, WriteInvocationInput, LMP | ||
|
||
|
||
|
||
class SQLiteSerializer(EllSerializer): | ||
class SQLiteSerializer(SQLSerializer): | ||
def __init__(self, storage_dir: str, blob_store: Optional[BlobStore] = None): | ||
self.store = SQLiteStore(storage_dir, blob_store) | ||
self.supports_blobs = True | ||
super().__init__(SQLiteStore(storage_dir, blob_store)) | ||
|
||
def get_lmp(self, lmp_id: str): | ||
lmp = self.store.get_lmp(lmp_id) | ||
if lmp: | ||
return LMP(**lmp.model_dump()) | ||
return None | ||
|
||
def get_lmp_versions(self, fqn: str) -> List[LMP]: | ||
slmps = self.store.get_versions_by_fqn(fqn) | ||
return [LMP(**slmp.model_dump()) for slmp in slmps] | ||
|
||
def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None: | ||
serialized_lmp = SerializedLMP.from_api(lmp) | ||
self.store.write_lmp(serialized_lmp, uses) | ||
|
||
def write_invocation(self, input: WriteInvocationInput) -> None: | ||
invocation = Invocation.from_api(input.invocation) | ||
self.store.write_invocation(invocation, set(input.consumes)) | ||
return None | ||
|
||
def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: | ||
return self.store.blob_store.store_blob(blob, blob_id) # type: ignore | ||
|
||
def retrieve_blob(self, blob_id: str) -> bytes: | ||
return self.store.blob_store.retrieve_blob(blob_id) # type: ignore | ||
|
||
def close(self): | ||
pass | ||
|
||
|
||
|
||
# todo(async). underlying store is not async-aware | ||
class AsyncSQLiteSerializer(EllAsyncSerializer): | ||
# todo(async). underlying store is not async | ||
class AsyncSQLiteSerializer(AsyncSQLSerializer): | ||
def __init__(self, storage_dir: str, blob_store: Optional[AsyncBlobStore] = None): | ||
self.store = SQLiteStore(storage_dir, blob_store) | ||
self.blob_store = blob_store | ||
self.supports_blobs = True | ||
|
||
async def get_lmp(self, lmp_id: str): | ||
lmp = self.store.get_lmp(lmp_id) | ||
if lmp: | ||
return LMP(**lmp.model_dump()) | ||
return None | ||
|
||
async def get_lmp_versions(self, fqn: str) -> List[LMP]: | ||
slmps = self.store.get_versions_by_fqn(fqn) | ||
return [LMP(**slmp.model_dump()) for slmp in slmps] | ||
|
||
async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None: | ||
serialized_lmp = SerializedLMP.from_api(lmp) | ||
self.store.write_lmp(serialized_lmp, uses) | ||
|
||
async def write_invocation(self, input: WriteInvocationInput) -> None: | ||
invocation = Invocation.from_api(input.invocation) | ||
self.store.write_invocation(invocation, set(input.consumes)) | ||
return None | ||
|
||
async def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: | ||
return await self.blob_store.store_blob(blob, blob_id) # type: ignore | ||
|
||
async def retrieve_blob(self, blob_id: str) -> bytes: | ||
return await self.blob_store.retrieve_blob(blob_id) # type: ignore | ||
|
||
async def close(self): | ||
pass | ||
|
||
async def __aenter__(self): | ||
return self | ||
|
||
async def __aexit__(self): | ||
await self.close() | ||
super().__init__(SQLiteStore(storage_dir, blob_store)) | ||
|