Commit

fix Ola Audio path for multi GPUs
Devininthelab committed Feb 28, 2025
1 parent eba3bcf commit 86120b9
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions lmms_eval/models/ola.py
@@ -176,7 +176,7 @@ def __init__(
         )

         self._config = self._model.config
-        self.model.to("cuda").eval().bfloat16()
+        self.model.to(self.device).eval().bfloat16()
         self.model.tie_weights()
         self.truncation = truncation
         self.batch_size_per_gpu = int(batch_size)
@@ -207,10 +207,9 @@ def __init__(
         elif accelerator.num_processes == 1 and device_map == "auto":
             eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
             self._rank = 0
-            self._word_size = 1
+            self._world_size = 1
         else:
             eval_logger.info(f"Using single device: {self._device}")
-            self.model.to(self._device)
             self._rank = 0
             self._world_size = 1
         self.accelerator = accelerator
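The two hunks above are the heart of the fix: every hardcoded "cuda" resolves to the current CUDA device, which is GPU 0 unless each process has explicitly selected its own, so under multi-GPU evaluation all ranks pile their tensors onto cuda:0 while each rank's model weights live on its assigned GPU. The commit replaces it with the rank-local self.device (also fixing the _word_size typo and dropping a now-redundant second model.to() in the single-device branch). A minimal sketch of the idea, assuming a standard Hugging Face Accelerate setup; names here are illustrative, not from ola.py:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    device = accelerator.device  # cuda:0 on rank 0, cuda:1 on rank 1, ...

    # Fragile under data parallelism: "cuda" can land on GPU 0 for every rank.
    # x = torch.zeros(1, dtype=torch.bfloat16, device="cuda")
    # Robust: always target the rank-local device.
    x = torch.zeros(1, dtype=torch.bfloat16, device=device)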
@@ -407,6 +406,7 @@ def process_audio(self, audio_array, sampling_rate):
         CHUNK_LIM = 480000
         import librosa

+        speech_wav = audio_array
         if sampling_rate != target_sr:
             speech_wav = librosa.resample(audio_array, orig_sr=sampling_rate, target_sr=target_sr).astype(np.float32)
         speechs = []
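The added speech_wav = audio_array default is a correctness fix, not a device fix: before this change, speech_wav was only bound inside the resampling branch, so audio that already arrived at the target sampling rate would raise an UnboundLocalError when speech_wav was used further down in process_audio. A minimal sketch of the pattern, with a hypothetical helper name:

    import librosa
    import numpy as np

    def resample_if_needed(wav: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
        out = wav  # default: already at target rate, pass through unchanged
        if sr != target_sr:
            out = librosa.resample(wav, orig_sr=sr, target_sr=target_sr).astype(np.float32)
        return out  # bound on every path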
@@ -485,13 +485,13 @@ def _collate(x):
                         eval_logger.info(f"Video {visuals} can not load, check the source")
                         continue
                     audio = self.extract_audio(visual)
-                    audio.write_audiofile("./video_audio.wav")
-                    video_audio_path = "./video_audio.wav"
+                    audio.write_audiofile(f"./video_audio_{self.rank}.wav")
+                    video_audio_path = f"./video_audio_{self.rank}.wav"
                     speech, speech_length, speech_chunk, speech_wav = self.load_audio(video_audio_path)
-                    speechs.append(speech.bfloat16().to("cuda"))
-                    speech_lengths.append(speech_length.to("cuda"))
-                    speech_chunks.append(speech_chunk.to("cuda"))
-                    speech_wavs.append(speech_wav.to("cuda"))
+                    speechs.append(speech.bfloat16().to(self.device))
+                    speech_lengths.append(speech_length.to(self.device))
+                    speech_chunks.append(speech_chunk.to(self.device))
+                    speech_wavs.append(speech_wav.to(self.device))
                     os.remove(video_audio_path)

                     # Process images of video
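This hunk is the commit's namesake fix: with several ranks sharing one working directory, a fixed ./video_audio.wav path is a race, since one process can overwrite the file, or os.remove() it, while another rank is still reading it. Suffixing the path with the process rank gives each worker a private scratch file. A sketch of the pattern; the helper name is hypothetical:

    def scratch_audio_path(rank: int) -> str:
        # rank 0 -> ./video_audio_0.wav, rank 1 -> ./video_audio_1.wav, ...
        return f"./video_audio_{rank}.wav"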
@@ -508,7 +508,7 @@ def _collate(x):
                     if frame_idx is None:
                         frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()

-                    video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda")
+                    video_processed = torch.cat(video_processed, dim=0).bfloat16().to(self.device)
                     video_processed = (video_processed, video_processed)

                     video_data = (video_processed, (384, 384), "video")
@@ -527,32 +527,32 @@ def _collate(x):
                     if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor):
                         image_highres_tensor = torch.stack(image_highres_tensor, dim=0)
                     if type(image_tensor) is list:
-                        image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor]
+                        image_tensor = [_image.bfloat16().to(self.device) for _image in image_tensor]
                     else:
-                        image_tensor = image_tensor.bfloat16().to("cuda")
+                        image_tensor = image_tensor.bfloat16().to(self.device)
                     if type(image_highres_tensor) is list:
-                        image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor]
+                        image_highres_tensor = [_image.bfloat16().to(self.device) for _image in image_highres_tensor]
                     else:
-                        image_highres_tensor = image_highres_tensor.bfloat16().to("cuda")
+                        image_highres_tensor = image_highres_tensor.bfloat16().to(self.device)

                     # Processing dummy audio, as required by model
-                    speechs.append(torch.zeros(1, 3000, 128).bfloat16().to("cuda"))
-                    speech_lengths.append(torch.LongTensor([3000]).to("cuda"))
-                    speech_wavs.append(torch.zeros([1, 480000]).to("cuda"))
-                    speech_chunks.append(torch.LongTensor([1]).to("cuda"))
+                    speechs.append(torch.zeros(1, 3000, 128).bfloat16().to(self.device))
+                    speech_lengths.append(torch.LongTensor([3000]).to(self.device))
+                    speech_wavs.append(torch.zeros([1, 480000]).to(self.device))
+                    speech_chunks.append(torch.LongTensor([1]).to(self.device))

                 elif isinstance(visual, dict) and "array" in visual:  # For Audio
                     if MODALITY is None:
                         MODALITY = "AUDIO"
                     mels, speech_length, speech_chunk, speech_wav = self.process_audio(visual["array"], visual["sampling_rate"])
-                    speechs.append(mels.bfloat16().to("cuda"))
-                    speech_lengths.append(speech_length.to("cuda"))
-                    speech_chunks.append(speech_chunk.to("cuda"))
-                    speech_wavs.append(speech_wav.to("cuda"))
+                    speechs.append(mels.bfloat16().to(self.device))
+                    speech_lengths.append(speech_length.to(self.device))
+                    speech_chunks.append(speech_chunk.to(self.device))
+                    speech_wavs.append(speech_wav.to(self.device))

                     # Processing dummy images, as required by model
-                    images.append(torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device="cuda", non_blocking=True))
-                    images_highres.append(torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device="cuda", non_blocking=True))
+                    images.append(torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device=self.device, non_blocking=True))
+                    images_highres.append(torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device=self.device, non_blocking=True))
                     image_sizes.append((224, 224))

             if not video_processed and MODALITY == "VIDEO":
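The dummy tensors in this hunk exist because the model's forward pass expects both speech and image inputs for every sample, so whichever modality is absent is filled with fixed-shape zeros, now allocated on the rank-local device like everything else. A standalone sketch of the placeholder shapes; the 480000-sample waveform presumably corresponds to 30 s at 16 kHz, matching CHUNK_LIM in process_audio:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dummy_speech = torch.zeros(1, 3000, 128, dtype=torch.bfloat16, device=device)  # mel frames x mel bins
    dummy_wav = torch.zeros(1, 480000, device=device)                              # raw waveform placeholder
    dummy_image = torch.zeros(1, 3, 224, 224, dtype=torch.bfloat16, device=device)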
@@ -601,11 +601,11 @@ def _collate(x):
             eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{prompt}\n")

             if MODALITY == "AUDIO":
-                input_ids = tokenizer_speech_token(prompt, self.tokenizer, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self._device)
+                input_ids = tokenizer_speech_token(prompt, self.tokenizer, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
             elif MODALITY == "IMAGE":
-                input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self._device)
+                input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
             elif MODALITY == "VIDEO":
-                input_ids = tokenizer_speech_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to("cuda")
+                input_ids = tokenizer_speech_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
             pad_token_ids = 151643
             attention_masks = input_ids.ne(pad_token_ids).long().to(self.device)
