Skip to content

Commit

Permalink
example
Browse files Browse the repository at this point in the history
  • Loading branch information
MadcowD committed Oct 1, 2024
1 parent 07de122 commit 8b46cb5
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 8 deletions.
60 changes: 60 additions & 0 deletions x/openai_realtime/examples/audio_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import asyncio
import os
from pydub import AudioSegment
import numpy as np
from openai_realtime import RealtimeClient, RealtimeUtils
import logging

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Helper function to load and convert audio files
def load_audio_sample(file_path):
audio = AudioSegment.from_file(file_path)
samples = np.array(audio.get_array_of_samples())
return RealtimeUtils.array_buffer_to_base64(samples)

async def main():
# Initialize the RealtimeClient
client = RealtimeClient(
api_key=os.getenv("OPENAI_API_KEY"),
debug=True
)

# Update session with instructions
client.update_session(
instructions=(
"Please describe the content of the audio you receive.\n"
"Be concise in your responses. Speak quickly and answer shortly."
)
)

# Connect to the RealtimeClient
await client.connect()
print("Connected to RealtimeClient")

# Wait for session creation
await client.wait_for_session_created()
print("Session created")

# Load audio sample
audio_file_path = './tests/samples/toronto.mp3'
audio_sample = load_audio_sample(audio_file_path)

# Send audio content
content = [{'type': 'input_audio', 'audio': audio_sample}]
client.send_user_message_content(content)
print("Audio sent")

# Wait for and print the assistant's response
assistant_item = await client.wait_for_next_completed_item()
print("Assistant's response:")
print(assistant_item['item']['formatted']['transcript'])

# Disconnect from the client
client.disconnect()
print("Disconnected from RealtimeClient")

if __name__ == "__main__":
asyncio.run(main())
5 changes: 4 additions & 1 deletion x/openai_realtime/src/openai_realtime/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def disconnect(self):
self.realtime.disconnect()

def get_turn_detection_type(self):
return self.session_config.get('turn_detection', {}).get('type')
turn_detection = self.session_config.get('turn_detection')
if isinstance(turn_detection, dict):
return turn_detection.get('type')
return None

def add_tool(self, definition, handler):
if not definition.get('name'):
Expand Down
50 changes: 43 additions & 7 deletions x/openai_realtime/tests/test_mock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# x/openai_realtime/tests/test_mock.py
import pytest
from unittest.mock import Mock, patch, AsyncMock
import numpy as np
Expand Down Expand Up @@ -26,6 +27,9 @@ def client():
# Initialize other necessary attributes
client.input_audio_buffer = np.array([], dtype=np.int16)

# Ensure session_config is properly initialized
client._reset_config()

return client

def test_init(client):
Expand All @@ -47,7 +51,9 @@ def test_reset(client):
async def test_connect(client):
await client.connect()
client.realtime.connect.assert_awaited_once()
client.realtime.send.assert_called_once_with('session.update', {'session': client.session_config})

expected_session = client.session_config.copy()
client.realtime.send.assert_called_once_with('session.update', {'session': expected_session})

def test_add_tool(client):
tool_definition = {'name': 'test_tool', 'description': 'A test tool'}
Expand All @@ -58,13 +64,36 @@ def test_add_tool(client):
assert 'test_tool' in client.tools
assert client.tools['test_tool']['definition'] == tool_definition
assert client.tools['test_tool']['handler'] == tool_handler
client.realtime.send.assert_called_once_with('session.update', {'session': client.session_config})

expected_session = client.session_config.copy()
expected_session['tools'] = [{
'name': 'test_tool',
'description': 'A test tool',
'type': 'function'
}]

client.realtime.send.assert_called_once_with('session.update', {'session': expected_session})

def test_remove_tool(client):
client.tools = {'test_tool': {}}
# Setup: Add a tool first
client.tools = {'test_tool': {'definition': {'name': 'test_tool', 'description': 'A test tool'}}}

# Remove the tool
client.remove_tool('test_tool')

# Assertions
assert 'test_tool' not in client.tools
client.realtime.send.assert_called_once_with('session.update', {'session': client.session_config})

# Ensure 'session.update' was NOT called automatically
client.realtime.send.assert_not_called()

# If session synchronization is needed, it should be done explicitly
# For example:
client.update_session()
expected_session = client.session_config.copy()
expected_session['tools'] = []

client.realtime.send.assert_called_once_with('session.update', {'session': expected_session})

def test_delete_item(client):
client.delete_item('item_id')
Expand All @@ -73,11 +102,15 @@ def test_delete_item(client):
def test_update_session(client):
client.update_session(modalities=['text'])
assert client.session_config['modalities'] == ['text']
client.realtime.send.assert_called_once_with('session.update', {'session': client.session_config})

expected_session = client.session_config.copy()

client.realtime.send.assert_called_once_with('session.update', {'session': expected_session})

def test_send_user_message_content(client):
content = [{'type': 'text', 'text': 'Hello'}]
client.send_user_message_content(content)

expected_calls = [
('conversation.item.create', {
'item': {
Expand All @@ -88,6 +121,7 @@ def test_send_user_message_content(client):
}),
('response.create',)
]

assert client.realtime.send.call_count == 2
client.realtime.send.assert_any_call('conversation.item.create', {
'item': {
Expand Down Expand Up @@ -154,7 +188,8 @@ async def test_call_tool(client):
tool_handler_mock = AsyncMock(return_value=tool_result)
client.tools = {
tool_name: {
'handler': tool_handler_mock
'handler': tool_handler_mock,
'definition': {'name': tool_name, 'description': 'A test tool'}
}
}

Expand All @@ -181,7 +216,8 @@ async def test_call_tool_error(client):
tool_handler_mock = AsyncMock(side_effect=Exception(error_message))
client.tools = {
tool_name: {
'handler': tool_handler_mock
'handler': tool_handler_mock,
'definition': {'name': tool_name, 'description': 'A test tool'}
}
}

Expand Down

0 comments on commit 8b46cb5

Please sign in to comment.