Merge pull request #11 from ColoredCow/feat/training-refactoring
fix(setup bug fix)
pankaj-ag authored Nov 12, 2024
2 parents 65f4bc3 + de45884 commit 0b09a53
Showing 4 changed files with 54 additions and 30 deletions.
69 changes: 42 additions & 27 deletions app.py
@@ -1,5 +1,4 @@
from flask import Flask, jsonify, render_template, request, url_for
import whisper
import torch
from transformers import pipeline
from gtts import gTTS
@@ -8,7 +7,7 @@
from dotenv import load_dotenv
import os
from huggingface_hub import login
# from mlx_lm import load, generate
from transcription import load_asr_model, translate_with_base_whisper, translate_audio

# Load the environment variables from the .env file
load_dotenv()
@@ -17,6 +16,12 @@
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
login(token=huggingface_token)

RECORDING_SAVE_PATH = "static/recordings"
OUTPUT_SAVE_PATH = "static/outputs"
ASR_MODEL_NAME = 'pankaj-ag/whisper-small-mr-en-translation'

processor, asr_model = load_asr_model(ASR_MODEL_NAME)

def get_current_time():
"""Get the current time"""
current_time = datetime.now()
@@ -27,7 +32,6 @@ def get_current_time():
app = Flask(__name__)

# Load Whisper model
whisper_model = whisper.load_model("base")

model_id = "coloredcow/paani-1b-instruct-marathi"

@@ -77,25 +81,18 @@ def get_chatbot_response(input_text, language):

return None

def timestamped_print(*args, **kwargs):
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print("\n\n")
print( f"[{timestamp}]", *args, **kwargs)
print("\n\n")

# Route to render the HTML page with the recording UI
@app.route('/')
def index():
return render_template('record.html')

# Flask route to record and transcribe audio
@app.route('/process-audio', methods=['POST'])
def record_audio_endpoint():
# Print current time
print(f"Query start time: {get_current_time()}")

if 'audio_data' not in request.files:
def save_audio(files):
if 'audio_data' not in files:
return jsonify({"error": "No audio file uploaded"}), 400

audio_data = request.files['audio_data']
audio_bytes = audio_data.read()

RECORDING_SAVE_PATH = "static/recordings"
os.makedirs(RECORDING_SAVE_PATH, exist_ok=True)

audio_filename = os.path.join(RECORDING_SAVE_PATH, "recording.wav")
@@ -104,26 +101,44 @@ def record_audio_endpoint():

file_name = f"{RECORDING_SAVE_PATH}/recording.wav"

# Transcribe using Whisper
result = whisper_model.transcribe(file_name, task="translate")
transcription = result['text']
print(f"Transcription: {transcription}")
return file_name

user_input = transcription
response_text = get_chatbot_response(user_input, selected_language)

OUTPUT_SAVE_PATH = "static/outputs"
def convert_to_audio(text, language):
os.makedirs(OUTPUT_SAVE_PATH, exist_ok=True)
tts = gTTS(text=response_text, lang=selected_language, tld='co.in')
tts = gTTS(text=text, lang=language, tld='co.in')
tts.save(f"{OUTPUT_SAVE_PATH}/final-output.mp3")

audio_file_path = url_for('static', filename='outputs/final-output.mp3')
return audio_file_path

# Route to render the HTML page with the recording UI
@app.route('/')
def index():
return render_template('record.html')

# Flask route to record and transcribe audio
@app.route('/process-audio', methods=['POST'])
def record_audio_endpoint():
# Print current time
print(f"Query start time: {get_current_time()}")

file_name = save_audio(request.files)
timestamped_print("Audio file saved")

transcription = translate_audio(file_name, asr_model, processor, selected_language)
timestamped_print("Audio translate_audio", transcription)

user_input = transcription
response_text = get_chatbot_response(user_input, selected_language)

audio_file_path = convert_to_audio(response_text, selected_language)
timestamped_print("converted in audio", audio_file_path)

# Print current time
print(f"Query end time: {get_current_time()}")

return jsonify({
"user_input": user_input,
"recorded_audio_path": file_name,
"response_text": response_text,
"model_id": model_id,
"audio_file_path": audio_file_path,
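For a quick manual check of the refactored endpoint, a sketch of a client call follows. It is not part of the commit: it assumes the app is running locally on Flask's default port; the `audio_data` field name and the response keys come from this diff, while the sample file path is illustrative and the way `selected_language` reaches the endpoint is not visible in the hunks above, so no language field is sent.

```python
# Manual smoke test for /process-audio -- a sketch, not part of this commit.
# Assumes the Flask app is running locally on the default port (5000).
# "audio_data" and the response keys are taken from the diff above;
# "sample.wav" is an illustrative path, and no language field is sent
# because how selected_language is supplied is not shown in these hunks.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:5000/process-audio",
        files={"audio_data": ("recording.wav", f, "audio/wav")},
    )
resp.raise_for_status()

payload = resp.json()
print(payload["user_input"])       # translated transcription of the upload
print(payload["response_text"])    # chatbot reply from get_chatbot_response
print(payload["audio_file_path"])  # URL of the gTTS output under static/outputs
```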
4 changes: 4 additions & 0 deletions static/js/recorder.js
@@ -9,6 +9,7 @@ const modelId = document.getElementById("modelId");
const userInputText = document.getElementById("userInputText");
const modelResponseText = document.getElementById("modelResponseText");
const modelResponsePlayer = document.getElementById("modelResponsePlayer");
const modelRequestPlayer = document.getElementById("modelRequestPlayer");

recordButton.addEventListener("click", async () => {
assistanceResponse.style.display = "none";
@@ -75,6 +76,9 @@ async function sendAudio(audioBlobOrFile) {
);
modelResponseText.innerHTML = marked.parse(jsonResponse.response_text);

modelRequestPlayer.src = jsonResponse.recorded_audio_path;
modelRequestPlayer.load();

modelResponsePlayer.src = jsonResponse.audio_file_path;
modelResponsePlayer.load();

6 changes: 6 additions & 0 deletions templates/record.html
@@ -35,12 +35,18 @@ <h3 style="margin-bottom: 5px">Model used:</h3>
<div>
<h3 style="margin-bottom: 5px">User Input:</h3>
<div id="userInputText"></div>
<audio controls id="modelRequestPlayer">
Your browser does not support the audio element.
</audio>
</div>
<div>
<h3 style="margin-bottom: 5px">Model Response:</h3>
<div id="modelResponseText"></div>
</div>
<!-- Audio play button -->


<h3> Chatbot response </h3>
<audio controls id="modelResponsePlayer">
Your browser does not support the audio element.
</audio>
5 changes: 2 additions & 3 deletions transcription.py
@@ -12,7 +12,6 @@ def load_asr_model(modelName):
model = WhisperForConditionalGeneration.from_pretrained(modelName)
return processor, model


def transcribe_audio(file_path, model, processor, language):
# Load the audio file using librosa
audio_array, sampling_rate = librosa.load(file_path, sr=16000)
@@ -43,14 +42,14 @@ def translate_audio(file_path, model, processor, language):

return transcription

def translate_with_whisper(file_path, model, processor, language):
def translate_with_base_whisper(file_path, model, processor, language):
# Transcribe using Whisper
result = whisper_model.transcribe(file_path, task="translate", language = language)
transcription = result['text']
print(f"Transcription: {transcription}")
return transcription

def transcribe_with_whisper(file_path, model, processor, language):
def transcribe_with_base_whisper(file_path, model, processor, language):
# Transcribe using Whisper
result = whisper_model.transcribe(file_path, task="transcribe", language = language)
transcription = result['text']
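The renamed helpers keep the fine-tuned model path (`translate_audio`) distinct from the base-Whisper fallbacks (`translate_with_base_whisper`, `transcribe_with_base_whisper`). Below is a minimal sketch of using the module on its own, based only on the signatures visible in this diff; the recording path and the "mr" language code are illustrative assumptions.

```python
# Sketch: exercising transcription.py directly, outside the Flask app.
# load_asr_model returns (processor, model) and translate_audio takes
# (file_path, model, processor, language), as shown in this diff; the
# recording path and the "mr" language code are illustrative assumptions.
from transcription import load_asr_model, translate_audio

processor, model = load_asr_model("pankaj-ag/whisper-small-mr-en-translation")
text = translate_audio("static/recordings/recording.wav", model, processor, "mr")
print(f"Translated text: {text}")
```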
