Merge pull request #7 from ColoredCow/feat/train-llama
Feat/train llama
pankaj-ag authored Nov 12, 2024
2 parents 9c7078a + 1e6b9c4 commit 7eb37d9
Showing 10 changed files with 318 additions and 108 deletions.
156 changes: 57 additions & 99 deletions app.py
@@ -1,19 +1,14 @@
from flask import Flask, jsonify, render_template, request, url_for
# import sounddevice as sd
# import scipy.io.wavfile
# import whisper
import whisper
import torch
from transformers import pipeline
from gtts import gTTS
import os
from datetime import datetime
import librosa

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from dotenv import load_dotenv # type: ignore
from dotenv import load_dotenv
import os
from huggingface_hub import login
from mlx_lm import load, generate # type: ignore
from transcription import load_asr_model, transcribe_audio, translate_with_whisper, transcribe_with_whisper, translate_audio
# from mlx_lm import load, generate

# Load the environment variables from the .env file
load_dotenv()
@@ -22,24 +17,35 @@
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
login(token=huggingface_token)

def get_current_time():
"""Get the current time"""
current_time = datetime.now()
formatted_time = current_time.strftime("%H:%M:%S")
return formatted_time

# Initialize Flask app
app = Flask(__name__)

MODEL_NAME = 'pankaj-ag/whisper-small-mr-en-translation'

# Load Whisper model
whisper_model = whisper.load_model("base")

processor, asr_model = load_asr_model(MODEL_NAME)
model_id = "coloredcow/paani-1b-instruct-marathi"

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
pipe = pipeline(
"text-generation",
model=model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
)

selected_language = 'mr'

selected_language = 'en'
language_configs = {
"en": {
"chatbot_instruction": "Please answer the following question in English:\n",
},
"hi": {
"chatbot_instruction": "कृपया निम्नलिखित प्रश्न का उत्तर हिंदी में दें और हाइलाइट्स के लिए विशेष कीवर्ड जैसे * बोल्ड आदि से बचें:\n",
"chatbot_instruction": "कृपया निम्नलिखित प्रश्न का उत्तर हिंदी में दें:\n",
},
"mr": {
"chatbot_instruction": "कृपया पुढील प्रश्नाचे उत्तर मराठीत द्या:\n",
@@ -49,83 +55,29 @@
},
}

OUTPUT_SAVE_PATH = "static/outputs"


def get_current_time():
"""Get the current time"""
current_time = datetime.now()
formatted_time = current_time.strftime("%H:%M:%S")
return formatted_time

def timestamped_print(*args, **kwargs):
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print("\n\n")
print( f"[{timestamp}]", *args, **kwargs)
print("\n\n")


def save_audio(files):
if 'audio_data' not in files:
return jsonify({"error": "No audio file uploaded"}), 400

audio_data = request.files['audio_data']
audio_bytes = audio_data.read()

RECORDING_SAVE_PATH = "static/recordings"
os.makedirs(RECORDING_SAVE_PATH, exist_ok=True)

audio_filename = os.path.join(RECORDING_SAVE_PATH, "recording.wav")
with open(audio_filename, "wb") as f:
f.write(audio_bytes)

file_name = f"{RECORDING_SAVE_PATH}/recording.wav"

return file_name

# def transcribe_audio(file_path, language):
# # Load the audio file using librosa
# audio_array, sampling_rate = librosa.load(file_path, sr=16000)

# # Preprocess audio with WhisperProcessor
# inputs = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt")
# input_features = inputs.input_features

# # Generate transcription using the fine-tuned model
# with torch.no_grad():
# # generated_tokens = whisper_model.generate(input_features)
# generated_tokens = whisper_model.generate(input_features, forced_decoder_ids=processor.get_decoder_prompt_ids(language="Marathi", task="translate"))
# transcription = processor.decode(generated_tokens[0], skip_special_tokens=True)

# return transcription

def get_chatbot_response(input_text, language):
instruction = language_configs[language]['chatbot_instruction']
prompt = instruction + input_text

if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
if model_id == 'meta-llama/Llama-3.2-1B-Instruct' or model_id == 'coloredcow/paani-1b-instruct-marathi' or model_id == 'pankaj-ag/fine_tuned_model':
print('inside model id check')
messages = [
# {"role": "system", "content": "You are a chatbot designed to help Indian farmers on any agriculture related questions they have. Be a helpful guide and friend to empower them take best decisions for their crops and growth. Keep your responses brief and short until asked for details."},
{"role": "user", "content": prompt},
]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
outputs = pipe(
messages,
max_new_tokens=256,
do_sample=False,
)

response = generate(model, tokenizer, prompt=prompt, verbose=True)

print("response generated", response)
return {"content": response}

response = outputs[0]["generated_text"][-1]
print("response from model......", response)
return response['content']

return None
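
For context, a minimal sketch of the chat-style pipeline call used above, outside the diff. It assumes the coloredcow/paani-1b-instruct-marathi model accepts chat messages (as the code implies); the question text is invented for illustration.

# Sketch only: shows why the code indexes outputs[0]["generated_text"][-1].
from transformers import pipeline

pipe = pipeline("text-generation", model="coloredcow/paani-1b-instruct-marathi")
messages = [{"role": "user", "content": "Please answer the following question in English:\nHow often should wheat be irrigated?"}]
outputs = pipe(messages, max_new_tokens=256, do_sample=False)

# With chat-style input, generated_text holds the whole conversation, so the
# final element is the newly generated assistant turn.
reply = outputs[0]["generated_text"][-1]
print(reply["content"])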


def convert_to_audio(text, language):
os.makedirs(OUTPUT_SAVE_PATH, exist_ok=True)
tts = gTTS(text=text, lang=language, tld='co.in')
tts.save(f"{OUTPUT_SAVE_PATH}/final-output.mp3")
audio_file_path = url_for('static', filename='outputs/final-output.mp3')
return audio_file_path

# Route to render the HTML page with the recording UI
@app.route('/')
def index():
@@ -137,39 +89,45 @@ def record_audio_endpoint():
# Print current time
print(f"Query start time: {get_current_time()}")

file_name = save_audio(request.files)
timestamped_print("Audio file saved")

transcription = translate_with_whisper(file_name, asr_model, processor, selected_language)
timestamped_print("Audio translate_with_whisper", transcription)
if 'audio_data' not in request.files:
return jsonify({"error": "No audio file uploaded"}), 400

transcription = translate_audio(file_name, asr_model, processor, selected_language)
timestamped_print("Audio translate_audio", transcription)
audio_data = request.files['audio_data']
audio_bytes = audio_data.read()

# transcription = transcribe_with_whisper(file_name, asr_model, processor, selected_language)
# timestamped_print("Audio transcribe_with_whisper", transcription)
RECORDING_SAVE_PATH = "static/recordings"
os.makedirs(RECORDING_SAVE_PATH, exist_ok=True)

# transcription = transcribe_audio(file_name, asr_model, processor, selected_language)
# timestamped_print("Audio transcribe", transcription)
audio_filename = os.path.join(RECORDING_SAVE_PATH, "recording.wav")
with open(audio_filename, "wb") as f:
f.write(audio_bytes)

file_name = f"{RECORDING_SAVE_PATH}/recording.wav"

response = get_chatbot_response(transcription, selected_language)
response_text = response['content']
timestamped_print("Answer generated", response_text)
# Transcribe using Whisper
result = whisper_model.transcribe(file_name, task="translate")
transcription = result['text']
print(f"Transcription: {transcription}")

user_input = transcription
response_text = get_chatbot_response(user_input, selected_language)

audio_file_path = convert_to_audio(response_text, selected_language)
timestamped_print("converted in audio", audio_file_path)
OUTPUT_SAVE_PATH = "static/outputs"
os.makedirs(OUTPUT_SAVE_PATH, exist_ok=True)
tts = gTTS(text=response_text, lang=selected_language, tld='co.in')
tts.save(f"{OUTPUT_SAVE_PATH}/final-output.mp3")

audio_file_path = url_for('static', filename='outputs/final-output.mp3')

# Print current time
print(f"Query end time: {get_current_time()}")

return jsonify({
"user_input": transcription,
"user_input": user_input,
"response_text": response_text,
"model_id": model_id,
"audio_file_path": audio_file_path,
})

if __name__ == "__main__":
app.run(debug=True)
app.run(debug=True)
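
As a usage sketch, the endpoint above expects a multipart upload under the field name audio_data and returns JSON with the translated transcription, the chatbot reply, and the generated audio path. The /record URL below is an assumption (the route decorator sits outside the visible hunk), as is the local host and port.

# Hypothetical client call against a local dev server.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:5000/record",  # hypothetical route; substitute the real one
        files={"audio_data": ("recording.wav", f, "audio/wav")},
    )

data = resp.json()
print(data["user_input"])       # Whisper translation of the uploaded audio
print(data["response_text"])    # chatbot answer in the selected language
print(data["audio_file_path"])  # gTTS output under static/outputs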
140 changes: 140 additions & 0 deletions fine-tuning/FineTuneLlama.py
@@ -0,0 +1,140 @@
import json
import os
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import Dataset
from typing import List, Optional

class FineTuneLlama:
def __init__(self, model_name: str, file_path: str = "data/training_data.json", output_dir: str = "./results", num_epochs: int = 3):
"""
Initializes the FineTuneLlama class.
Args:
model_name (str): The name of the pre-trained model to fine-tune.
file_path (str): Path to the training data file (relative to project).
output_dir (str): Directory to save the fine-tuned model and logs.
num_epochs (int): Number of training epochs.
"""
# Set up paths
base_dir = Path(__file__).resolve().parent
self.file_path = base_dir / file_path
self.output_dir = base_dir / output_dir

# Initialize attributes
self.model_name = model_name
self.num_epochs = num_epochs
print('all attributes initialized....')

# Prepare model, tokenizer, dataset, and training arguments
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
print('tokenizer initialized....')

# # Load the training data
# file_path = self.file_path
# with open(file_path, 'r', encoding='utf-8') as f:
# training_data = json.load(f)


# # Calculate token lengths for inputs and outputs
# input_lengths = [len(self.tokenizer.encode(entry["input"], truncation=False)) for entry in training_data]
# output_lengths = [len(self.tokenizer.encode(entry["output"], truncation=False)) for entry in training_data]

# # Summary statistics
# input_length_summary = {
# "max_input_length": max(input_lengths),
# "average_input_length": sum(input_lengths) / len(input_lengths),
# "min_input_length": min(input_lengths)
# }

# output_length_summary = {
# "max_output_length": max(output_lengths),
# "average_output_length": sum(output_lengths) / len(output_lengths),
# "min_output_length": min(output_lengths)
# }

# print("Input length summary:", input_length_summary)
# print("Output length summary:", output_length_summary)

self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
print('model initialized....')
self.train_dataset = self.prepare_dataset(self.file_path)
print('dataset loaded....')
print(self.train_dataset)
self.training_args = TrainingArguments(
output_dir=str(self.output_dir), # Convert Path to string
eval_strategy="no",
learning_rate=1e-5,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
gradient_accumulation_steps=8,
gradient_checkpointing=True,
num_train_epochs=self.num_epochs,
weight_decay=0.01,
logging_dir=str(base_dir / 'logs'),
logging_steps=10,
fp16=True,
)
print('training args configured....')
self.trainer = Trainer(
model=self.model,
args=self.training_args,
train_dataset=self.train_dataset,
tokenizer=self.tokenizer
)
print('trainer initialized....')

def load_data(self, file_path: Path) -> List[str]:
with open(file_path, 'r', encoding='utf-8') as file:
return json.load(file)

def prepare_dataset(self, file_path: Path) -> Dataset:
print('before train data.....')
train_data = self.load_data(file_path)
print('after train data.....')
train_dataset = Dataset.from_dict(
{"input": [entry["input"] for entry in train_data],
"output": [entry["output"] for entry in train_data]}
)
print('after train dataset.....')

def tokenize_function(examples):
print('inside tokenize_function.....')
max_length=350
input_encodings = self.tokenizer(examples["input"], padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
output_encodings = self.tokenizer(examples["output"], padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
# Print shapes for debugging
print("input_ids shape:", input_encodings['input_ids'].shape)
print("labels shape:", output_encodings['input_ids'].shape)
return {'input_ids': input_encodings['input_ids'], 'labels': output_encodings['input_ids']}

print('before train dataset return.....')
try:
return train_dataset.map(tokenize_function, batched=True)
except Exception as e:
print(f"An error occurred: {e}")

def start_training(self):
print("Starting training...")
self.trainer.train()
print("Training complete!")
self.model.save_pretrained(self.output_dir)
self.tokenizer.save_pretrained(self.output_dir)
print(f"Model and tokenizer saved to {self.output_dir}")


if __name__ == "__main__":
# Free up unused GPU memory
torch.cuda.empty_cache()
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "./fine_tuned_model"
file_path = "data/training_data_1.json"
output_dir = "./fine_tuned_model"

os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

fine_tune_model = FineTuneLlama(model_name, file_path, output_dir)
fine_tune_model.start_training()
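
For reference, load_data and prepare_dataset above expect the JSON file to be a flat list of records with "input" and "output" strings. A minimal sketch of producing such a file, with invented example rows:

import json

examples = [
    {"input": "When should wheat be sown in Maharashtra?",
     "output": "Wheat is normally sown in the rabi season, from late October to mid November."},
    {"input": "How often should a tomato crop be irrigated?",
     "output": "Roughly every three to five days in light soils, less often after rain."},
]

with open("data/training_data.json", "w", encoding="utf-8") as f:
    json.dump(examples, f, ensure_ascii=False, indent=2)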
