[Reka and Fix] move gpt eval to process_results stage. (EvolvingLMMs-Lab#108)

* Refactor activitynetqa_generation.yaml for improved metric aggregation

* Refactor JSON dumping in cli_evaluate_single for non-ASCII characters

* Refactor video_detail_description/utils.py for improved cache directory handling

* Refactor gpt_eval function to handle non-ASCII characters in data_dict

* Refactor pywsd import error message for lmms-eval nextqa module

* Refactor lmms_eval/models/llava_vid.py for consistency in variable naming

* Refactor activitynetqa default template YAML for improved prompt formatting

* Refactor lmms_eval/tasks/activitynetqa/utils.py for consistent capitalization of questions

* Refactor llava_vid.py for improved scaling factor calculation
Luodian authored May 31, 2024
1 parent 243ba5f commit ba3de94
Showing 10 changed files with 256 additions and 180 deletions.
4 changes: 2 additions & 2 deletions lmms_eval/__main__.py
@@ -323,8 +323,8 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
filename = args.output_path.joinpath(f"{task_name}.json")
# Structure the data with 'args' and 'logs' keys
data_to_dump = {"args": vars(args), "model_configs": config, "logs": sorted(samples[task_name], key=lambda x: x["doc_id"])} # Convert Namespace to dict
samples_dumped = json.dumps(data_to_dump, indent=4, default=_handle_non_serializable)
filename.open("w").write(samples_dumped)
samples_dumped = json.dumps(data_to_dump, indent=4, default=_handle_non_serializable, ensure_ascii=False)
filename.open("w", encoding="utf-8").write(samples_dumped)
eval_logger.info(f"Saved samples to {filename}")

return results, samples
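For context (an illustrative snippet, not part of this commit): the two replaced lines above change how non-ASCII text survives the round trip to disk. By default json.dumps escapes every non-ASCII character, and an unqualified open() falls back to the platform encoding; the combination below is what the new code relies on.

import json

record = {"question": "电影里发生了什么?", "answer": "在下雨"}

# Default behavior: non-ASCII characters are escaped to \uXXXX sequences.
print(json.dumps(record))

# ensure_ascii=False keeps the characters as-is, so the file must be opened
# with an explicit UTF-8 encoding to avoid platform-dependent encode errors.
with open("sample.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(record, indent=4, ensure_ascii=False))
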
32 changes: 13 additions & 19 deletions lmms_eval/models/llava_vid.py
@@ -15,7 +15,6 @@
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.utils import stop_sequences_criteria
from lmms_eval.models.model_utils.load_video import read_video_pyav

eval_logger = logging.getLogger("lmms-eval")

@@ -67,7 +66,6 @@ def __init__(
mm_spatial_pool_out_channels: int = 1024,
mm_spatial_pool_mode: str = "average",
overwrite: bool = True,
video_decode_backend: str = "decord",
**kwargs,
) -> None:
super().__init__()
@@ -87,7 +85,6 @@ def __init__(

self.pretrained = pretrained
self.model_name = get_model_name_from_path(pretrained)
self.video_decode_backend = video_decode_backend
# self._config = AutoConfig.from_pretrained(self.pretrained)
self.overwrite = overwrite
self.mm_resampler_type = mm_resampler_type
@@ -107,18 +104,18 @@ def __init__(

cfg_pretrained = AutoConfig.from_pretrained(self.pretrained)

if "224" in cfg_pretrained.mm_vision_tower:
# suppose the length of text tokens is around 1000, from bo's report
least_token_number = self.max_frames_num * (16 // self.mm_spatial_pool_stride) ** 2 + 1000
else:
least_token_number = self.max_frames_num * (24 // self.mm_spatial_pool_stride) ** 2 + 1000

scaling_factor = math.ceil(least_token_number / 4096)
if scaling_factor >= 2:
print(float(scaling_factor))
overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
overwrite_config["max_sequence_length"] = 4096 * scaling_factor
overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor
if cfg_pretrained.architectures[0] == "LlavaLlamaForCausalLM": # Ugly code, only used in vicuna that needs ROPE
if "224" in cfg_pretrained.mm_vision_tower:
least_token_number = self.max_frames_num * (16 // self.mm_spatial_pool_stride) ** 2 + 1000
else:
least_token_number = self.max_frames_num * (24 // self.mm_spatial_pool_stride) ** 2 + 1000

scaling_factor = math.ceil(least_token_number / 4096)
if scaling_factor >= 2:
overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
overwrite_config["max_sequence_length"] = 4096 * scaling_factor
overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor

self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, self.model_name, device_map=self.device_map, overwrite_config=overwrite_config)
else:
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(
@@ -316,10 +313,7 @@ def generate_until(self, requests) -> List[str]:
videos = []
try:
for visual in visuals:
if self.video_decode_backend == "decord":
video = self.load_video(visual, self.max_frames_num)
elif self.video_decode_backend == "pyav":
video = read_video_pyav(visual, num_frm=self.max_frames_num)
video = self.load_video(visual, self.max_frames_num)
video = self._image_processor.preprocess(video, return_tensors="pt")["pixel_values"].half().cuda()
videos.append(video)
except Exception as e:
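For context (an illustrative calculation, not part of this commit): the reworked RoPE-scaling branch above sizes the context window from the visual token budget. With assumed values of 32 frames, a spatial pooling stride of 2, and a non-224 vision tower (24x24 patches per frame), the arithmetic works out as follows.

import math

max_frames_num = 32            # assumed for illustration
mm_spatial_pool_stride = 2     # assumed for illustration
text_token_budget = 1000       # rough text-token allowance used in the diff

tokens_per_frame = (24 // mm_spatial_pool_stride) ** 2                        # 144 visual tokens per frame
least_token_number = max_frames_num * tokens_per_frame + text_token_budget   # 32 * 144 + 1000 = 5608

scaling_factor = math.ceil(least_token_number / 4096)                        # ceil(5608 / 4096) = 2
# Since scaling_factor >= 2, the overwrite_config would receive:
#   rope_scaling = {"factor": 2.0, "type": "linear"}
#   max_sequence_length = tokenizer_model_max_length = 4096 * 2 = 8192
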
127 changes: 127 additions & 0 deletions lmms_eval/models/reka.py
@@ -0,0 +1,127 @@
from PIL import Image
from io import BytesIO
from copy import deepcopy
import numpy as np
import os
import base64
from typing import List, Tuple
from tqdm import tqdm
import requests as url_requests
import time
import logging

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval import utils

from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState

NUM_SECONDS_TO_SLEEP = 30
eval_logger = logging.getLogger("lmms-eval")

try:
import reka
from decord import VideoReader, cpu

reka.API_KEY = os.getenv("REKA_API_KEY", "YOUR_API_KEY")
except Exception as e:
eval_logger.error(f"{e}")
pass


@register_model("reka")
class Reka(lmms):
def __init__(
self,
model_version: str = "reka-edge",
modality: str = "image",
max_frames_for_video: int = 10,
timeout: int = 120,
**kwargs,
) -> None:
super().__init__()
self.model_version = model_version
self.modality = modality
self.max_frames_for_video = max_frames_for_video
self.timeout = timeout

accelerator = Accelerator()
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self.accelerator = accelerator
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes

self.device = self.accelerator.device

def encode_media(self, media_path):
img = Image.open(media_path)
output_buffer = BytesIO()
img.save(output_buffer, format="PNG")
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode("utf-8")

return f"data:image/jpeg;base64,{base64_str}"

def generate_until(self, requests) -> List[str]:
res = []
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]

conversations_history = []
media_urls = []
for visual in visuals:
if self.modality == "image":
media_url = self.encode_media(visual)
else:
raise NotImplementedError

conversations_history.append({"type": "human", "text": contexts, "media_url": media_url})

if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1

for attempt in range(5):
try:
response = reka.chat(
conversations_history=conversations_history,
model=self.model_version,
request_output_len=gen_kwargs["max_new_tokens"],
temperature=gen_kwargs["temperature"],
)
content = response["text"].strip()
break # If successful, break out of the loop

except Exception as e:
eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}")
if attempt < 5 - 1: # If we have retries left, sleep and then continue to next attempt
time.sleep(NUM_SECONDS_TO_SLEEP)
else: # If this was the last attempt, log and return empty
eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}")
eval_logger.error(f"Response: {response}")
content = ""

res.append(content)
pbar.update(1)
pbar.close()
return res

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
# TODO
assert False, "Reka not support loglikelihood"
2 changes: 1 addition & 1 deletion lmms_eval/tasks/activitynetqa/_default_template_yaml
@@ -6,7 +6,7 @@ dataset_kwargs:
model_specific_prompt_kwargs:
default:
pre_prompt: ""
post_prompt: ""
post_prompt: " Answer the question using a single word or phrase."

metadata:
version: 0.0
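For context (a hypothetical sketch, not part of this commit): the real utils.activitynetqa_doc_to_text is not shown on this page, but prompt kwargs like the pair above are typically applied by wrapping the question text; note that the new post_prompt begins with a leading space so plain concatenation stays well-formed.

# Hypothetical sketch; the actual utils.activitynetqa_doc_to_text is not shown in this diff.
def doc_to_text(doc, model_specific_prompt_kwargs=None):
    kwargs = model_specific_prompt_kwargs or {"pre_prompt": "", "post_prompt": ""}
    question = doc["question"].capitalize()   # the commit also capitalizes questions consistently
    return kwargs["pre_prompt"] + question + kwargs["post_prompt"]
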
10 changes: 7 additions & 3 deletions lmms_eval/tasks/activitynetqa/activitynetqa_generation.yaml
@@ -5,11 +5,15 @@ output_type: generate_until
doc_to_visual: !function utils.activitynetqa_doc_to_visual
doc_to_text: !function utils.activitynetqa_doc_to_text
doc_to_target: !function utils.activitynetqa_doc_to_answer
process_results: !function utils.activitynetqa_process_results
process_results: !function utils.activitynetqa_process_results # gpt eval here for each QA pair
metric_list:
- metric: submission
aggregation: !function utils.activitynetqa_aggregate
- metric: gpt_eval_score
aggregation: !function utils.activitynetqa_aggregate_score # parse scores from each QA pair
higher_is_better: true
- metric: gpt_eval_accuracy
aggregation: !function utils.activitynetqa_aggregate_accuracy # parse accuracy from each QA pair
higher_is_better: true

include: _default_template_yaml

generation_kwargs:
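For context (a hypothetical sketch, not part of this commit): the matching utils.py changes are in the unloaded part of this diff, but per the commit title the GPT judge now runs once per QA pair inside process_results, and the aggregation hooks above only average what was stored. All names other than those referenced in the YAML are assumptions.

# Hypothetical sketch; the real utils.activitynetqa_process_results is not shown on this page.
def activitynetqa_process_results(doc, results):
    pred = results[0]
    judgement = gpt_eval(doc["question"], doc["answer"], pred)   # assumed per-pair judge call
    return {
        "gpt_eval_score": {"score": judgement["score"]},
        "gpt_eval_accuracy": {"correct": judgement["pred"] == "yes"},
    }

def activitynetqa_aggregate_score(results):
    return sum(r["score"] for r in results) / max(len(results), 1)

def activitynetqa_aggregate_accuracy(results):
    return sum(r["correct"] for r in results) / max(len(results), 1)
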
(Diffs for the remaining changed files were not loaded on this page.)
