Skip to content

Commit

Permalink
refactor(serialize): remove redundant code in sql store wrappers
Browse files: browse the repository at this point in the history
  • Loading branch information
alex-dixon committed Nov 19, 2024
1 parent 4745314 commit e572492
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 164 deletions.
91 changes: 6 additions & 85 deletions src/ell/serialize/postgres.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,16 @@
from typing import List, Optional, Dict, Any
from typing import Optional

from ell.serialize.sql import SQLSerializer, AsyncSQLSerializer
from ell.stores.sql import PostgresStore
from ell.stores.store import BlobStore, AsyncBlobStore
from ell.stores.studio import Invocation, SerializedLMP
from ell.types.serialize import LMP, WriteLMPInput, WriteInvocationInput
from ell.serialize.protocol import EllSerializer, EllAsyncSerializer


class PostgresSerializer(EllSerializer):
class PostgresSerializer(SQLSerializer):
def __init__(self, db_uri: str, blob_store: Optional[BlobStore] = None):
self.store = PostgresStore(db_uri, blob_store)
self.supports_blobs = blob_store is not None

def get_lmp(self, lmp_id: str):
lmp = self.store.get_lmp(lmp_id)
if lmp:
return LMP(**lmp.model_dump())
return None

def get_lmp_versions(self, fqn: str) -> List[LMP]:
slmps = self.store.get_versions_by_fqn(fqn)
return [LMP(**slmp.model_dump()) for slmp in slmps]

def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None:
model = SerializedLMP.from_api(lmp)
self.store.write_lmp(model, uses)

def write_invocation(self, input: WriteInvocationInput) -> None:
invocation = Invocation.from_api(input.invocation)
self.store.write_invocation(invocation, set(input.consumes))
return None

def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str:
if self.store.blob_store is None:
raise ValueError("Blob store is not enabled")
return self.store.blob_store.store_blob(blob=blob, blob_id=blob_id)

def retrieve_blob(self, blob_id: str) -> bytes:
if self.store.blob_store is None:
raise ValueError("Blob store is not enabled")
return self.store.blob_store.retrieve_blob(blob_id)

def close(self):
pass
super().__init__(PostgresStore(db_uri, blob_store))


# todo(async): the underlying store is not async-aware
class AsyncPostgresSerializer(EllAsyncSerializer):
class AsyncPostgresSerializer(AsyncSQLSerializer):
def __init__(self, db_uri: str, blob_store: Optional[AsyncBlobStore] = None):
self.store = PostgresStore(db_uri, blob_store)
self.blob_store = blob_store
self.supports_blobs = blob_store is not None

async def get_lmp(self, lmp_id: str) -> Optional[LMP]:
lmp = self.store.get_lmp(lmp_id)
if lmp:
return LMP(**lmp.model_dump())
return None

async def get_lmp_versions(self, fqn: str) -> List[LMP]:
slmps = self.store.get_versions_by_fqn(fqn)
return [LMP(**slmp.model_dump()) for slmp in slmps]

async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None:
model = SerializedLMP.from_api(lmp)
self.store.write_lmp(model, uses)

async def write_invocation(self, input: WriteInvocationInput) -> None:
invocation = Invocation.from_api(input.invocation)
self.store.write_invocation(
invocation,
set(input.consumes)
)
return None

async def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str:
if self.blob_store is None:
raise ValueError("Blob store is not enabled")
return await self.blob_store.store_blob(blob=blob, blob_id=blob_id)

async def retrieve_blob(self, blob_id: str) -> bytes:
if self.blob_store is None:
raise ValueError("Blob store is not enabled")
return await self.blob_store.retrieve_blob(blob_id)

async def close(self):
# todo. Do we have a close method?
pass

async def __aenter__(self):
return self

async def __aexit__(self):
await self.close()
super().__init__(PostgresStore(db_uri, blob_store))
8 changes: 4 additions & 4 deletions src/ell/serialize/sql.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import List, Optional, Dict, Any

import ell.stores.store
from ell.stores.store import Store
from ell.stores.studio import Invocation, SerializedLMP
from ell.types.serialize import LMP, WriteLMPInput, WriteInvocationInput
from ell.serialize.protocol import EllSerializer, EllAsyncSerializer


class SQLSerializer(EllSerializer):
def __init__(self, store: ell.stores.store.Store ):
def __init__(self, store: Store):
self.store = store
self.supports_blobs = store.has_blob_storage

@@ -46,7 +46,7 @@ def close(self):

# todo(async): the underlying store and blob store are not async-aware
class AsyncSQLSerializer(EllAsyncSerializer):
def __init__(self, store: ell.stores.store.Store):
def __init__(self, store: Store):
self.store = store
self.supports_blobs = store.has_blob_storage

@@ -78,7 +78,7 @@ async def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[st
return self.store.blob_store.store_blob(blob=blob, blob_id=blob_id)

async def retrieve_blob(self, blob_id: str) -> bytes:
if self.blob_store is None:
if self.store.blob_store is None:
raise ValueError("Blob store is not enabled")
return self.store.blob_store.retrieve_blob(blob_id)

Expand Down
82 changes: 7 additions & 75 deletions src/ell/serialize/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,18 @@
from typing import List, Optional, Dict, Any
from typing import Optional

from ell.serialize.protocol import EllSerializer, EllAsyncSerializer
from ell.serialize.sql import SQLSerializer, AsyncSQLSerializer
from ell.stores.sql import SQLiteStore
from ell.stores.store import AsyncBlobStore, BlobStore
from ell.stores.studio import SerializedLMP, Invocation
from ell.types.serialize import WriteLMPInput, WriteInvocationInput, LMP



class SQLiteSerializer(EllSerializer):
class SQLiteSerializer(SQLSerializer):
def __init__(self, storage_dir: str, blob_store: Optional[BlobStore] = None):
self.store = SQLiteStore(storage_dir, blob_store)
self.supports_blobs = True
super().__init__(SQLiteStore(storage_dir, blob_store))

def get_lmp(self, lmp_id: str):
lmp = self.store.get_lmp(lmp_id)
if lmp:
return LMP(**lmp.model_dump())
return None

def get_lmp_versions(self, fqn: str) -> List[LMP]:
slmps = self.store.get_versions_by_fqn(fqn)
return [LMP(**slmp.model_dump()) for slmp in slmps]

def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None:
serialized_lmp = SerializedLMP.from_api(lmp)
self.store.write_lmp(serialized_lmp, uses)

def write_invocation(self, input: WriteInvocationInput) -> None:
invocation = Invocation.from_api(input.invocation)
self.store.write_invocation(invocation, set(input.consumes))
return None

def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str:
return self.store.blob_store.store_blob(blob, blob_id) # type: ignore

def retrieve_blob(self, blob_id: str) -> bytes:
return self.store.blob_store.retrieve_blob(blob_id) # type: ignore

def close(self):
pass



# todo(async). underlying store is not async-aware
class AsyncSQLiteSerializer(EllAsyncSerializer):
# todo(async). underlying store is not async
class AsyncSQLiteSerializer(AsyncSQLSerializer):
def __init__(self, storage_dir: str, blob_store: Optional[AsyncBlobStore] = None):
self.store = SQLiteStore(storage_dir, blob_store)
self.blob_store = blob_store
self.supports_blobs = True

async def get_lmp(self, lmp_id: str):
lmp = self.store.get_lmp(lmp_id)
if lmp:
return LMP(**lmp.model_dump())
return None

async def get_lmp_versions(self, fqn: str) -> List[LMP]:
slmps = self.store.get_versions_by_fqn(fqn)
return [LMP(**slmp.model_dump()) for slmp in slmps]

async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None:
serialized_lmp = SerializedLMP.from_api(lmp)
self.store.write_lmp(serialized_lmp, uses)

async def write_invocation(self, input: WriteInvocationInput) -> None:
invocation = Invocation.from_api(input.invocation)
self.store.write_invocation(invocation, set(input.consumes))
return None

async def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str:
return await self.blob_store.store_blob(blob, blob_id) # type: ignore

async def retrieve_blob(self, blob_id: str) -> bytes:
return await self.blob_store.retrieve_blob(blob_id) # type: ignore

async def close(self):
pass

async def __aenter__(self):
return self

async def __aexit__(self):
await self.close()
super().__init__(SQLiteStore(storage_dir, blob_store))

0 comments on commit e572492

Please sign in to comment.