Skip to content

Commit

Permalink
refactor: dedupe http error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-dixon committed Nov 19, 2024
1 parent 3440233 commit 3db907c
Showing 1 changed file with 36 additions and 61 deletions.
97 changes: 36 additions & 61 deletions src/ell/serialize/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,8 @@
# pydantic_ltype_aware_cattr.unstructure(obj),
# sort_keys=True, default=repr, ensure_ascii=False)


class EllHTTPSerializer(EllSerializer):
def __init__(self, base_url: str):
self.base_url = base_url
self.client = httpx.Client(base_url=base_url)
self.supports_blobs = True # we assume the server does, if not will find out later
self.logger = logging.getLogger(
__name__).getChild(self.__class__.__name__)

def _handle_http_error(
self,
def make_handle_http_error(logger: logging.Logger):
def handle_http_error(
error: HTTPStatusError,
span: str,
message: Optional[str] = None,
Expand All @@ -36,7 +27,7 @@ def _handle_http_error(
if error.response.status_code == 422:
error_detail = error.response.json().get(
"detail", "No detailed error message provided")
self.logger.error(
logger.error(
message or f"HTTP {error.response.status_code} Error in {span}",
extra={
**(extra or {}),
Expand All @@ -50,14 +41,25 @@ def _handle_http_error(
raise ValueError(f"Invalid input: {error_detail}") from error
raise

return handle_http_error


class EllHTTPSerializer(EllSerializer):
def __init__(self, base_url: str):
self.base_url = base_url
self.client = httpx.Client(base_url=base_url)
self.supports_blobs = True # we assume the server does, if not will find out later
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
self._handle_http_error = make_handle_http_error(self.logger)

def get_lmp(self, lmp_id: str) -> GetLMPOutput:
try:
response = self.client.get(f"/lmp/{lmp_id}")
response.raise_for_status()
data = response.json()
return None if data is None else LMP(**data)
except HTTPStatusError as e:
self._handle_http_error(e, "get_lmp")
self._handle_http_error(error=e, span="get_lmp", message="Failed to get LMP", extra={lmp_id: lmp_id})
raise

def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None:
Expand All @@ -71,9 +73,9 @@ def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None:
response.raise_for_status()
except HTTPStatusError as e:
self._handle_http_error(
error=e,
span="write_lmp",
message="Failed to write LMP",
span="write_lmp",
error=e,
extra={'lmp_id': lmp.lmp_id, 'lmp_version': lmp.version_number}
)
raise
Expand Down Expand Up @@ -154,30 +156,7 @@ def __init__(self, base_url: str):
self.supports_blobs = True # we assume the server does, if not will find out later
self.logger = logging.getLogger(
__name__).getChild(self.__class__.__name__)

def _handle_http_error(
self,
error: HTTPStatusError,
span: str,
message: Optional[str] = None,
extra: Optional[Dict[str, Any]] = None
) -> None:
if error.response.status_code == 422:
error_detail = error.response.json().get(
"detail", "No detailed error message provided")
self.logger.error(
message or f"HTTP {error.response.status_code} Error in {span}",
extra={
**(extra or {}),
"status_code": error.response.status_code,
"error_detail": error_detail,
"span": span,
"url": str(error.response.url),
"request_id": error.response.headers.get("x-request-id"),
}
)
raise ValueError(f"Invalid input: {error_detail}") from error
raise
self._handle_http_error = make_handle_http_error(self.logger)

async def get_lmp(self, lmp_id: str) -> GetLMPOutput:
try:
Expand All @@ -196,6 +175,21 @@ async def get_lmp(self, lmp_id: str) -> GetLMPOutput:
)
raise

async def get_lmp_versions(self, fqn: str) -> List[LMP]:
try:
response = await self.client.get("/lmp/versions", params={"fqn": fqn})
response.raise_for_status()
data = response.json()
return [LMP(**lmp_data) for lmp_data in data]
except HTTPStatusError as e:
self._handle_http_error(
error=e,
span="get_lmp_versions",
message="Failed to get LMP versions",
extra={'fqn': fqn}
)
raise

async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None:
try:
response = await self.client.post("/lmp", json={
Expand All @@ -217,17 +211,13 @@ async def write_invocation(self, input: WriteInvocationInput) -> None:
response = await self.client.post(
"/invocation",
headers={"Content-Type": "application/json"},
content=input.model_dump_json(exclude_none=True, exclude_unset=True),
content=input.model_dump_json(exclude_none=True, exclude_unset=True)
)
response.raise_for_status()
return None
except HTTPStatusError as e:
self._handle_http_error(
error=e,
span="write_invocation",
message="Failed to write invocation",
extra={'invocation_id': input.invocation.id}
)
self._handle_http_error(message="Failed to write invocation", span="write_invocation", error=e,
extra={'invocation_id': input.invocation.id})
raise

async def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str:
Expand Down Expand Up @@ -270,18 +260,3 @@ async def __aenter__(self):

async def __aexit__(self):
await self.close()

async def get_lmp_versions(self, fqn: str) -> List[LMP]:
try:
response = await self.client.get("/lmp/versions", params={"fqn": fqn})
response.raise_for_status()
data = response.json()
return [LMP(**lmp_data) for lmp_data in data]
except HTTPStatusError as e:
self._handle_http_error(
error=e,
span="get_lmp_versions",
message="Failed to get LMP versions",
extra={'fqn': fqn}
)
raise

0 comments on commit 3db907c

Please sign in to comment.