diff --git a/src/ell/providers/anthropic.py b/src/ell/providers/anthropic.py index 4a3d8ba2..ba379b09 100644 --- a/src/ell/providers/anthropic.py +++ b/src/ell/providers/anthropic.py @@ -197,7 +197,7 @@ def _content_block_to_anthropic_format(content_block: ContentBlock): type="tool_use", id=tool_call.tool_call_id, name=tool_call.tool.__name__, - input=tool_call.params.model_dump() if isinstance(tool_call.params, BaseModel) else tool_call.params, + input=tool_call.serialize_params(), ) elif (tool_result := content_block.tool_result): return dict( diff --git a/src/ell/providers/bedrock.py b/src/ell/providers/bedrock.py index 8c69948b..fc0a0cd6 100644 --- a/src/ell/providers/bedrock.py +++ b/src/ell/providers/bedrock.py @@ -202,7 +202,7 @@ def content_block_to_bedrock_format(content_block: ContentBlock) -> Dict[str, An "toolUse": { "toolUseId": content_block.tool_call.tool_call_id, "name": content_block.tool_call.tool.__name__, - "input": content_block.tool_call.params.model_dump() if isinstance(content_block.tool_call.params, BaseModel) else content_block.tool_call.params, + "input": content_block.tool_call.serialize_params(), } } elif content_block.tool_result: diff --git a/src/ell/providers/openai.py b/src/ell/providers/openai.py index b3cca688..ef85ff41 100644 --- a/src/ell/providers/openai.py +++ b/src/ell/providers/openai.py @@ -64,7 +64,7 @@ def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: type="function", function=dict( name=tool_call.tool.__name__, - arguments=tool_call.params.model_dump_json() if isinstance(tool_call.params,BaseModel) else json.dumps(tool_call.params, ensure_ascii=False) + arguments=json.dumps(tool_call.serialize_params(), ensure_ascii=False) ) ) for tool_call in tool_calls ], role="assistant", diff --git a/src/ell/types/message.py b/src/ell/types/message.py index 14e39841..852b4d6f 100644 --- a/src/ell/types/message.py +++ b/src/ell/types/message.py @@ -40,10 +40,15 @@ def text_only(self) -> str: def __repr__(self): return f"{self.__class__.__name__}(tool_call_id={self.tool_call_id}, result={_content_to_text(self.result)})" +class ToolReference(BaseModel): + """A reference to an invocable tool""" + fqn: str = Field(description="The fully qualified name of the tool") + hash: str = Field(description="The hash of the tool and its dependencies") + class ToolCall(BaseModel): - tool: Union[InvocableTool, str] = Field(description="The tool function to call or a reference to it when serialized") + tool: Union[InvocableTool, ToolReference] = Field(description="The tool function to call or a reference to it when serialized") tool_call_id: Optional[_lstr_generic] = Field(default=None) - params: Union[Dict[str, Any], BaseModel] + params: Union[Dict[str, Any], BaseModel] = Field(description="Arguments for the tool call provided by the model.") def __init__(self, tool, params: Optional[Union[BaseModel, Dict[str, Any]]], tool_call_id: Optional[_lstr_generic]=None): if (not isinstance(params, BaseModel)) and isinstance(tool, FunctionType) and hasattr(tool, '__ell_params_model__'): @@ -53,18 +58,25 @@ def __init__(self, tool, params: Optional[Union[BaseModel, Dict[str, Any]]], too super().__init__(tool=tool, tool_call_id=tool_call_id, params=params) - # TODO. This should reference a tool fqn + version if possible - # ell should have an InvocableTool with __ properties that have this info at serialization time @field_serializer('tool') - def serialize_tool(self, tool: InvocableTool, _info): - return tool.__name__ if hasattr(tool, '__name__') else str(tool) + def serialize_tool(self, tool: Union[InvocableTool, ToolReference], _info): + if isinstance(tool, ToolReference): + return tool + return ToolReference( + # todo(alex). add the value of fqn we want to standardize on to all lmps so we don't keep using qualname + fqn=tool.__qualname__, + hash=getattr(tool, '__ell_hash__', 'unknown') + ) @field_serializer('params') - def serialize_params(self, params: Union[Dict[str,Any],BaseModel], _info): + def _serialize_params(self, params: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]: if isinstance(params, dict): return params return params.model_dump(exclude_none=True, exclude_unset=True) + def serialize_params(self) -> Dict[str, Any]: + return self._serialize_params(self.params) + @field_serializer('tool_call_id') def serialize_tool_call_id(self, tool_call_id: _lstr_generic): if tool_call_id is None: @@ -76,19 +88,19 @@ def serialize_tool_call_id(self, tool_call_id: _lstr_generic): def __call__(self, **kwargs): assert not kwargs, "Unexpected arguments provided. Calling a tool uses the params provided in the ToolCall." - assert not isinstance(self.tool, str), "ToolCall.tool is a string. Tools are not invocable once serialized." + assert not isinstance(self.tool, ToolReference), f"Tools are not invocable once serialized. ToolCall.tool is a ToolReference: {self.tool}" # XXX: TODO: MOVE TRACKING CODE TO _TRACK AND OUT OF HERE AND API. - return self.tool(**self.params.model_dump()) + return self.tool(**self.serialize_params()) # XXX: Deprecate in 0.1.0 def call_and_collect_as_message_block(self): raise DeprecationWarning("call_and_collect_as_message_block is deprecated. Use collect_as_content_block instead.") def call_and_collect_as_content_block(self): - if isinstance(self.tool, str): - raise ValueError("Cannot call a tool that is a string reference.") - res = self.tool(**(self.params.model_dump() if isinstance(self.params, BaseModel) else self.params), + if isinstance(self.tool, ToolReference): + raise ValueError(f"Cannot call a tool that is a ToolReference: {self.tool}") + res = self.tool(**self.serialize_params(), _tool_call_id=self.tool_call_id) return ContentBlock(tool_result=res) @@ -203,7 +215,7 @@ def type(self): @property def content(self): - return getattr(self, self.type) + return getattr(self, self.type) # type: ignore @classmethod def coerce(cls, content: AnyContent) -> "ContentBlock": diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 421d6aeb..d478c176 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -218,6 +218,8 @@ def my_sample_tool(args: MySampleToolInput = Field( def test_invocation_json_round_trip(): + # pretend it's being tracked + my_sample_tool.__ell_hash__ = "lmp-123" invocation_id = "invocation-" + uuid4().hex tool_call = ToolCall( tool=my_sample_tool, @@ -281,6 +283,8 @@ def test_write_invocation_tool_call(async_sqlite_serializer: AsyncSQLiteSerializ print(response.json()) raise e + # pretend it's being tracked + my_sample_tool.__ell_hash__ = "lmp-123" invocation_id = "invocation-" + uuid4().hex tool_call = ToolCall( tool=my_sample_tool,