diff --git a/src/ell/py.typed b/src/ell/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/ell/types/message.py b/src/ell/types/message.py index d55302ed..7bc61e6a 100644 --- a/src/ell/types/message.py +++ b/src/ell/types/message.py @@ -7,18 +7,19 @@ from io import BytesIO from PIL import Image as PILImage -from pydantic import BaseModel, ConfigDict, Field, model_validator, field_validator, field_serializer +from pydantic import BaseModel, ConfigDict, model_validator, field_serializer from sqlmodel import Field from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Union from ell.util.serialization import serialize_image _lstr_generic = Union[_lstr, str] InvocableTool = Callable[..., Union["ToolResult", _lstr_generic, List["ContentBlock"], ]] - +# AnyContent represents any type that can be passed to Message. +AnyContent = Union["ContentBlock", str, "ToolCall", "ToolResult", "ImageContent", np.ndarray, PILImage.Image, BaseModel] class ToolResult(BaseModel): @@ -178,7 +179,7 @@ def content(self): return getattr(self, self.type) @classmethod - def coerce(cls, content: Union["ContentBlock", str, ToolCall, ToolResult, ImageContent, np.ndarray, PILImage.Image, BaseModel]) -> "ContentBlock": + def coerce(cls, content: AnyContent) -> "ContentBlock": """ Coerce various types of content into a ContentBlock. @@ -266,7 +267,7 @@ def serialize_parsed(self, value: Optional[BaseModel], _info): def to_content_blocks( - content: Optional[Union[str, List[ContentBlock], List[Union[ContentBlock, str, ToolCall, ToolResult, ImageContent, np.ndarray, PILImage.Image, BaseModel]]]] = None, + content: Optional[Union[AnyContent, List[AnyContent]]] = None, **content_block_kwargs ) -> List[ContentBlock]: """ @@ -313,10 +314,10 @@ class Message(BaseModel): content: List[ContentBlock] - def __init__(self, role, content: Union[str, List[ContentBlock], List[Union[ContentBlock, str, ToolCall, ToolResult, ImageContent, np.ndarray, PILImage.Image, BaseModel]]] = None, **content_block_kwargs): - content = to_content_blocks(content, **content_block_kwargs) + def __init__(self, role: str, content: Union[AnyContent, List[AnyContent], None] = None, **content_block_kwargs): + content_blocks = to_content_blocks(content, **content_block_kwargs) - super().__init__(content=content, role=role) + super().__init__(role=role, content=content_blocks) # XXX: This choice of naming is unfortunate, but it is what it is. @property @@ -423,7 +424,7 @@ def call_tools_and_collect_as_message(self, parallel=False, max_workers=None): return Message(role="user", content=content) # HELPERS -def system(content: Union[str, List[ContentBlock]]) -> Message: +def system(content: Union[AnyContent, List[AnyContent]]) -> Message: """ Create a system message with the given content. @@ -436,7 +437,7 @@ def system(content: Union[str, List[ContentBlock]]) -> Message: return Message(role="system", content=content) -def user(content: Union[str, List[ContentBlock]]) -> Message: +def user(content: Union[AnyContent, List[AnyContent]]) -> Message: """ Create a user message with the given content. @@ -449,7 +450,7 @@ def user(content: Union[str, List[ContentBlock]]) -> Message: return Message(role="user", content=content) -def assistant(content: Union[str, List[ContentBlock]]) -> Message: +def assistant(content: Union[AnyContent, List[AnyContent]]) -> Message: """ Create an assistant message with the given content.