diff --git a/src/ell/types/message.py b/src/ell/types/message.py index 7bc61e6a..93b8fc5c 100644 --- a/src/ell/types/message.py +++ b/src/ell/types/message.py @@ -423,6 +423,46 @@ def call_tools_and_collect_as_message(self, parallel=False, max_workers=None): content = [c.tool_call.call_and_collect_as_content_block() for c in self.content if c.tool_call] return Message(role="user", content=content) + @field_serializer('content') + def serialize_content(self, content: List[ContentBlock]): + """Serialize content blocks to a format suitable for JSON""" + return [ + {k: v for k, v in { + 'text': str(block.text) if block.text is not None else None, + 'image': block.image.model_dump() if block.image is not None else None, + 'audio': block.audio.tolist() if isinstance(block.audio, np.ndarray) else block.audio, + 'tool_call': block.tool_call.model_dump() if block.tool_call is not None else None, + 'parsed': block.parsed.model_dump() if block.parsed is not None else None, + 'tool_result': block.tool_result.model_dump() if block.tool_result is not None else None, + }.items() if v is not None} + for block in content + ] + + @classmethod + def model_validate(cls, obj: Any) -> 'Message': + """Custom validation to handle deserialization""" + if isinstance(obj, dict): + if 'content' in obj and isinstance(obj['content'], list): + content_blocks = [] + for block in obj['content']: + if isinstance(block, dict): + if 'text' in block: + block['text'] = str(block['text']) if block['text'] is not None else None + content_blocks.append(ContentBlock.model_validate(block)) + else: + content_blocks.append(ContentBlock.coerce(block)) + obj['content'] = content_blocks + return super().model_validate(obj) + + @classmethod + def model_validate_json(cls, json_str: str) -> 'Message': + """Custom validation to handle deserialization from JSON string""" + try: + data = json.loads(json_str) + return cls.model_validate(data) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON: {str(e)}") + # HELPERS def system(content: Union[AnyContent, List[AnyContent]]) -> Message: """ diff --git a/tests/test_message_type.py b/tests/test_message_type.py index 6b784b6d..8fd476ad 100644 --- a/tests/test_message_type.py +++ b/tests/test_message_type.py @@ -133,4 +133,14 @@ def test_content_block_audio_validation(): ContentBlock.model_validate(ContentBlock(audio=valid_audio)) with pytest.raises(ValueError): - ContentBlock.model_validate(ContentBlock(audio=invalid_audio)) \ No newline at end of file + ContentBlock.model_validate(ContentBlock(audio=invalid_audio)) + +def test_message_json_serialization(): + original_message = Message(role='assistant', content='Hello, this is a test message.') + + message_json = original_message.model_dump_json() + loaded_message = Message.model_validate_json(message_json) + + assert loaded_message.role == original_message.role + assert len(loaded_message.content) == len(original_message.content) + assert str(loaded_message.content[0].text) == str(original_message.content[0].text) \ No newline at end of file