diff --git a/x/openai_realtime/examples/audio_example.py b/x/openai_realtime/examples/audio_example.py new file mode 100644 index 00000000..014220c7 --- /dev/null +++ b/x/openai_realtime/examples/audio_example.py @@ -0,0 +1,60 @@ +import asyncio +import os +from pydub import AudioSegment +import numpy as np +from openai_realtime import RealtimeClient, RealtimeUtils +import logging + +# Configure logging +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +# Helper function to load and convert audio files +def load_audio_sample(file_path): + audio = AudioSegment.from_file(file_path) + samples = np.array(audio.get_array_of_samples()) + return RealtimeUtils.array_buffer_to_base64(samples) + +async def main(): + # Initialize the RealtimeClient + client = RealtimeClient( + api_key=os.getenv("OPENAI_API_KEY"), + debug=True + ) + + # Update session with instructions + client.update_session( + instructions=( + "Please describe the content of the audio you receive.\n" + "Be concise in your responses. Speak quickly and answer shortly." + ) + ) + + # Connect to the RealtimeClient + await client.connect() + print("Connected to RealtimeClient") + + # Wait for session creation + await client.wait_for_session_created() + print("Session created") + + # Load audio sample + audio_file_path = './tests/samples/toronto.mp3' + audio_sample = load_audio_sample(audio_file_path) + + # Send audio content + content = [{'type': 'input_audio', 'audio': audio_sample}] + client.send_user_message_content(content) + print("Audio sent") + + # Wait for and print the assistant's response + assistant_item = await client.wait_for_next_completed_item() + print("Assistant's response:") + print(assistant_item['item']['formatted']['transcript']) + + # Disconnect from the client + client.disconnect() + print("Disconnected from RealtimeClient") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/x/openai_realtime/src/openai_realtime/client.py b/x/openai_realtime/src/openai_realtime/client.py index 4cdead3a..a43911c8 100644 --- a/x/openai_realtime/src/openai_realtime/client.py +++ b/x/openai_realtime/src/openai_realtime/client.py @@ -120,7 +120,10 @@ def disconnect(self): self.realtime.disconnect() def get_turn_detection_type(self): - return self.session_config.get('turn_detection', {}).get('type') + turn_detection = self.session_config.get('turn_detection') + if isinstance(turn_detection, dict): + return turn_detection.get('type') + return None def add_tool(self, definition, handler): if not definition.get('name'): diff --git a/x/openai_realtime/tests/test_mock.py b/x/openai_realtime/tests/test_mock.py index d64afed3..3bb9437c 100644 --- a/x/openai_realtime/tests/test_mock.py +++ b/x/openai_realtime/tests/test_mock.py @@ -1,3 +1,4 @@ +# x/openai_realtime/tests/test_mock.py import pytest from unittest.mock import Mock, patch, AsyncMock import numpy as np @@ -26,6 +27,9 @@ def client(): # Initialize other necessary attributes client.input_audio_buffer = np.array([], dtype=np.int16) + # Ensure session_config is properly initialized + client._reset_config() + return client def test_init(client): @@ -47,7 +51,9 @@ def test_reset(client): async def test_connect(client): await client.connect() client.realtime.connect.assert_awaited_once() - client.realtime.send.assert_called_once_with('session.update', {'session': client.session_config}) + + expected_session = client.session_config.copy() + client.realtime.send.assert_called_once_with('session.update', {'session': expected_session}) def test_add_tool(client): tool_definition = {'name': 'test_tool', 'description': 'A test tool'} @@ -58,13 +64,36 @@ def test_add_tool(client): assert 'test_tool' in client.tools assert client.tools['test_tool']['definition'] == tool_definition assert client.tools['test_tool']['handler'] == tool_handler - client.realtime.send.assert_called_once_with('session.update', {'session': client.session_config}) + + expected_session = client.session_config.copy() + expected_session['tools'] = [{ + 'name': 'test_tool', + 'description': 'A test tool', + 'type': 'function' + }] + + client.realtime.send.assert_called_once_with('session.update', {'session': expected_session}) def test_remove_tool(client): - client.tools = {'test_tool': {}} + # Setup: Add a tool first + client.tools = {'test_tool': {'definition': {'name': 'test_tool', 'description': 'A test tool'}}} + + # Remove the tool client.remove_tool('test_tool') + + # Assertions assert 'test_tool' not in client.tools - client.realtime.send.assert_called_once_with('session.update', {'session': client.session_config}) + + # Ensure 'session.update' was NOT called automatically + client.realtime.send.assert_not_called() + + # If session synchronization is needed, it should be done explicitly + # For example: + client.update_session() + expected_session = client.session_config.copy() + expected_session['tools'] = [] + + client.realtime.send.assert_called_once_with('session.update', {'session': expected_session}) def test_delete_item(client): client.delete_item('item_id') @@ -73,11 +102,15 @@ def test_delete_item(client): def test_update_session(client): client.update_session(modalities=['text']) assert client.session_config['modalities'] == ['text'] - client.realtime.send.assert_called_once_with('session.update', {'session': client.session_config}) + + expected_session = client.session_config.copy() + + client.realtime.send.assert_called_once_with('session.update', {'session': expected_session}) def test_send_user_message_content(client): content = [{'type': 'text', 'text': 'Hello'}] client.send_user_message_content(content) + expected_calls = [ ('conversation.item.create', { 'item': { @@ -88,6 +121,7 @@ def test_send_user_message_content(client): }), ('response.create',) ] + assert client.realtime.send.call_count == 2 client.realtime.send.assert_any_call('conversation.item.create', { 'item': { @@ -154,7 +188,8 @@ async def test_call_tool(client): tool_handler_mock = AsyncMock(return_value=tool_result) client.tools = { tool_name: { - 'handler': tool_handler_mock + 'handler': tool_handler_mock, + 'definition': {'name': tool_name, 'description': 'A test tool'} } } @@ -181,7 +216,8 @@ async def test_call_tool_error(client): tool_handler_mock = AsyncMock(side_effect=Exception(error_message)) client.tools = { tool_name: { - 'handler': tool_handler_mock + 'handler': tool_handler_mock, + 'definition': {'name': tool_name, 'description': 'A test tool'} } }