Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: speech to text #26

Merged
merged 5 commits into the base branch from the author's branch (branch names not captured in this page extract) on
Feb 22, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor: better TTS struct design
cs50victor committed Jan 15, 2024
commit 9ca980e1cdbfb21c01597fb78e1a565b8d8c7c9a
21 changes: 11 additions & 10 deletions lkgpt/src/main.rs
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ use serde::{Deserialize, Serialize};

use crate::{
controls::WorldControlChannel, llm::LLMChannel, server::RoomData, stt::AudioInputChannel,
tts::create_tts, video::VideoChannel,
tts::TTS, video::VideoChannel,
};

pub const LIVEKIT_API_SECRET: &str = "LIVEKIT_API_SECRET";
@@ -275,8 +275,8 @@ pub async fn publish_tracks(
) -> Result<TracksPublicationData, RoomError> {
let audio_src = NativeAudioSource::new(
AudioSourceOptions::default(),
STT::SAMPLE_RATE,
STT::NUM_OF_CHANNELS,
TTS::SAMPLE_RATE,
TTS::NUM_OF_CHANNELS,
);

let audio_track =
@@ -385,7 +385,8 @@ pub fn sync_bevy_and_server_resources(
} = room_data;

info!("initializing required bevy resources");
let tts = async_runtime.rt.block_on(create_tts(audio_src)).unwrap();

let tts = async_runtime.rt.block_on(TTS::new(audio_src)).unwrap();

commands.init_resource::<LLMChannel>();
commands.init_resource::<WorldControlChannel>();
@@ -484,12 +485,12 @@ fn main() {
.run_if(resource_exists::<LivekitRoom>()),
);

app.add_systems(
Update,
receive_audio_input
.run_if(resource_exists::<stt::AudioInputChannel>())
.run_if(resource_exists::<STT>()),
);
// app.add_systems(
// Update,
// receive_audio_input
// .run_if(resource_exists::<stt::AudioInputChannel>())
// .run_if(resource_exists::<STT>()),
// );

app.add_systems(
Update,
5 changes: 0 additions & 5 deletions lkgpt/src/stt.rs
Original file line number Diff line number Diff line change
@@ -39,11 +39,6 @@ pub struct STT {
rx: crossbeam_channel::Receiver<String>,
}

impl STT {
pub const NUM_OF_CHANNELS: u32 = 1;
pub const SAMPLE_RATE: u32 = 44100;
}

impl FromWorld for STT {
fn from_world(world: &mut World) -> Self {
let async_rt = world.get_resource::<AsyncRuntime>().unwrap();
75 changes: 31 additions & 44 deletions lkgpt/src/tts.rs
Original file line number Diff line number Diff line change
@@ -16,7 +16,13 @@ use log::{error, info};
use parking_lot::Mutex;
use serde::Serialize;
use serde_json::Value;
use std::{sync::Arc, time::Duration};
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};

use std::io::Cursor;

@@ -55,14 +61,19 @@ struct RegularMessage {
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Resource)]
pub struct TTS {
ws_client: Option<Client<WSClient>>,
pub started: bool,
ws_client: Client<WSClient>,
pub started: Arc<AtomicBool>,
eleven_labs_api_key: String,
}

impl TTS {
pub const NUM_OF_CHANNELS: u32 = 1;
pub const SAMPLE_RATE: u32 = 44100;
}

struct WSClient {
audio_src: NativeAudioSource,
tts_client_ref: Arc<Mutex<TTS>>,
tts_ws_started: Arc<AtomicBool>,
}

fn decode_base64_audio(base64_audio: &str) -> anyhow::Result<Vec<i16>> {
@@ -83,22 +94,22 @@ impl ezsockets::ClientExt for WSClient {

info!("incoming speech from eleven labs");
if base64_audio != Value::Null {
let data = decode_base64_audio(base64_audio.as_str().unwrap())?;
let data = std::borrow::Cow::from(decode_base64_audio(base64_audio.as_str().unwrap())?);

const FRAME_DURATION: Duration = Duration::from_millis(500); // Write 0.5s of audio at a time
let ms = FRAME_DURATION.as_millis() as u32;

let num_channels = self.audio_src.num_channels();
let sample_rate = self.audio_src.sample_rate();
let samples_per_channel = 1_u32;

let num_samples = (sample_rate / 1000 * ms) as usize;
let samples_per_channel = num_samples as u32;

let audio_frame =
AudioFrame { data: data.into(), num_channels, sample_rate, samples_per_channel };
let audio_frame = AudioFrame { data, num_channels, sample_rate, samples_per_channel };

self.audio_src.capture_frame(&audio_frame).await?;
} else {
error!("received null message from eleven labs: {text:?}");
error!("received null audio from eleven labs: {text:?}");
}

Ok(())
@@ -116,8 +127,7 @@ impl ezsockets::ClientExt for WSClient {
}

async fn on_connect(&mut self) -> Result<(), ezsockets::Error> {
let mut tts = self.tts_client_ref.lock();
tts.started = true;
self.tts_ws_started.store(true, Ordering::Relaxed);
info!("ELEVEN LABS CONNECTED 🎉");
Ok(())
}
@@ -135,8 +145,7 @@ impl ezsockets::ClientExt for WSClient {
_frame: Option<CloseFrame>,
) -> Result<ClientCloseMode, ezsockets::Error> {
info!("ELEVEN LABS connection CLOSE");
let mut tts = self.tts_client_ref.lock();
tts.started = false;
self.tts_ws_started.store(false, Ordering::Relaxed);
Ok(ClientCloseMode::Reconnect)
}

@@ -147,23 +156,10 @@ impl ezsockets::ClientExt for WSClient {
}

impl TTS {
pub fn new() -> anyhow::Result<Self> {
pub async fn new(audio_src: NativeAudioSource) -> anyhow::Result<Self> {
let eleven_labs_api_key = std::env::var(ELEVENLABS_API_KEY).unwrap();
let started = Arc::new(AtomicBool::new(true));

Ok(Self { ws_client: None, started: false, eleven_labs_api_key })
}

pub async fn setup_ws_client(&mut self, audio_src: NativeAudioSource) -> anyhow::Result<()> {
let ws_client = self.connect_ws_client(audio_src).await?;
self.started = true;
self.ws_client = Some(ws_client);
Ok(())
}

async fn connect_ws_client(
&self,
audio_src: NativeAudioSource,
) -> anyhow::Result<Client<WSClient>> {
let voice_id = "21m00Tcm4TlvDq8ikWAM";
let model = "eleven_turbo_v2";

@@ -189,10 +185,10 @@ impl TTS {
)
}),
})
.header("xi-api-key", &self.eleven_labs_api_key);
.header("xi-api-key", eleven_labs_api_key.clone());

let (ws_client, _) = ezsockets::connect(
|_client| WSClient { audio_src, tts_client_ref: Arc::new(Mutex::new(self.clone())) },
|_client| WSClient { audio_src, tts_ws_started: started.clone() },
config,
)
.await;
@@ -203,11 +199,12 @@ impl TTS {
voice_settings: VoiceSettings { stability: 0.8, similarity_boost: true },
generation_config: GenerationConfig { chunk_length_schedule: [50] },
})?)?;
Ok(ws_client)

Ok(Self { ws_client, started, eleven_labs_api_key })
}

pub fn start(&mut self) -> anyhow::Result<()> {
self.started = true;
self.started.store(true, Ordering::Relaxed);
self.send(" ".to_string())?;
Ok(())
}
@@ -228,22 +225,12 @@ impl TTS {
};
let msg = msg?;

if !self.started {
if !self.started.load(Ordering::Relaxed) {
self.start()?;
}

if self.ws_client.as_ref().is_none() {
bail!("ws_client is none");
}

info!("sending to eleven labs {msg}");

Ok(self.ws_client.as_ref().unwrap().text(msg)?.status())
Ok(self.ws_client.text(msg)?.status())
}
}

pub async fn create_tts(audio_src: NativeAudioSource) -> anyhow::Result<TTS> {
let mut tts = TTS::new()?;
tts.setup_ws_client(audio_src).await?;
Ok(tts)
}