Skip to content

Commit

Permalink
Merge pull request #337 from Merlinvt/main
Browse files Browse the repository at this point in the history
Fix Round trip Message serialization/deserialization
  • Loading branch information
MadcowD authored Oct 29, 2024
2 parents 01c5611 + 05fd3e8 commit 37fb5d6
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
40 changes: 40 additions & 0 deletions src/ell/types/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,46 @@ def call_tools_and_collect_as_message(self, parallel=False, max_workers=None):
content = [c.tool_call.call_and_collect_as_content_block() for c in self.content if c.tool_call]
return Message(role="user", content=content)

@field_serializer('content')
def serialize_content(self, content: List[ContentBlock]):
"""Serialize content blocks to a format suitable for JSON"""
return [
{k: v for k, v in {
'text': str(block.text) if block.text is not None else None,
'image': block.image.model_dump() if block.image is not None else None,
'audio': block.audio.tolist() if isinstance(block.audio, np.ndarray) else block.audio,
'tool_call': block.tool_call.model_dump() if block.tool_call is not None else None,
'parsed': block.parsed.model_dump() if block.parsed is not None else None,
'tool_result': block.tool_result.model_dump() if block.tool_result is not None else None,
}.items() if v is not None}
for block in content
]

@classmethod
def model_validate(cls, obj: Any) -> 'Message':
"""Custom validation to handle deserialization"""
if isinstance(obj, dict):
if 'content' in obj and isinstance(obj['content'], list):
content_blocks = []
for block in obj['content']:
if isinstance(block, dict):
if 'text' in block:
block['text'] = str(block['text']) if block['text'] is not None else None
content_blocks.append(ContentBlock.model_validate(block))
else:
content_blocks.append(ContentBlock.coerce(block))
obj['content'] = content_blocks
return super().model_validate(obj)

@classmethod
def model_validate_json(cls, json_str: str) -> 'Message':
"""Custom validation to handle deserialization from JSON string"""
try:
data = json.loads(json_str)
return cls.model_validate(data)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON: {str(e)}")

# HELPERS
def system(content: Union[AnyContent, List[AnyContent]]) -> Message:
"""
Expand Down
12 changes: 11 additions & 1 deletion tests/test_message_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,14 @@ def test_content_block_audio_validation():
ContentBlock.model_validate(ContentBlock(audio=valid_audio))

with pytest.raises(ValueError):
ContentBlock.model_validate(ContentBlock(audio=invalid_audio))
ContentBlock.model_validate(ContentBlock(audio=invalid_audio))

def test_message_json_serialization():
original_message = Message(role='assistant', content='Hello, this is a test message.')

message_json = original_message.model_dump_json()
loaded_message = Message.model_validate_json(message_json)

assert loaded_message.role == original_message.role
assert len(loaded_message.content) == len(original_message.content)
assert str(loaded_message.content[0].text) == str(original_message.content[0].text)

0 comments on commit 37fb5d6

Please sign in to comment.