From 3f878eac43e459f608fc6107f641f896766655ff Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 4 Apr 2024 07:49:37 +0000 Subject: [PATCH 01/25] Remove flash attention --- lmms_eval/models/llava.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index b3cb8a66d..3edd8d76b 100644 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -52,7 +52,7 @@ def __init__( batch_size: Optional[Union[int, str]] = 1, trust_remote_code: Optional[bool] = False, revision=None, - use_flash_attention_2=True, + use_flash_attention_2=False, # True, device_map="", conv_template="vicuna_v1", use_cache=True, From c8043455b655a2a169cc0e1f80beaa7758aec38b Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 4 Apr 2024 13:47:10 +0000 Subject: [PATCH 02/25] Fix LlavaHf integration - need fp16 in accelerate config! --- lmms_eval/models/__init__.py | 1 + lmms_eval/models/llava_hf.py | 215 +++++++++++++++++++++++++++++++++++ 2 files changed, 216 insertions(+) create mode 100644 lmms_eval/models/llava_hf.py diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py index d39ce73cf..9135311cc 100644 --- a/lmms_eval/models/__init__.py +++ b/lmms_eval/models/__init__.py @@ -2,6 +2,7 @@ AVAILABLE_MODELS = { "llava": "Llava", + "llava_hf": "LlavaHf", "qwen_vl": "Qwen_VL", "fuyu": "Fuyu", "gpt4v": "GPT4V", diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py new file mode 100644 index 000000000..1709c688c --- /dev/null +++ b/lmms_eval/models/llava_hf.py @@ -0,0 +1,215 @@ +import torch +import logging +from tqdm import tqdm +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model +from accelerate import Accelerator, DistributedType +from accelerate.state import AcceleratorState +from typing import List, Optional, Union, Tuple +from transformers import LlavaForConditionalGeneration, AutoProcessor + +import warnings + +warnings.filterwarnings("ignore") + +eval_logger = logging.getLogger("lmms-eval") + + +@register_model("llava_hf") +class LlavaHf(lmms): + """ + Llava HF Model + """ + + def __init__( + self, + pretrained: str = "llava-hf/llava-1.5-7b-hf", + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = torch.float16, + batch_size: Optional[Union[int, str]] = 1, + **kwargs, + ) -> None: + super().__init__() + # Do not use kwargs for now + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator = Accelerator() + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + else: + self._device = device + self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, torch_dtype=dtype) + self._image_processor = AutoProcessor.from_pretrained(pretrained) + self._tokenizer = self._image_processor.tokenizer + self._config = self._model.config + self._model.eval() + self._model.tie_weights() + self.batch_size_per_gpu = int(batch_size) + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." 
+ # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.model.to(self._device) + self._rank = 0 + self._word_size = 1 + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + # TODO + assert False, "We have not implemented this function for LlavaHf yet" + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. 
this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + for chunk in chunks: + contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) + task = task[0] + split = split[0] + visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] + visuals = self.flatten(visuals) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + + # Set default values for until and max_new_tokens + until = [self.tok_decode(self.eot_token_id)] + + # Update values from gen_kwargs if present + if "until" in gen_kwargs: + until = gen_kwargs.pop("until") + if isinstance(until, str): + until = [until] + elif not isinstance(until, list): + raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}") + assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now" + context = contexts[0] + context = f"USER: \n{context}\nASSISTANT:" + inputs = self._image_processor(images=visuals, text=context, return_tensors="pt").to(self._device, torch.float16) + + gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + try: + cont = self._model.generate( + **inputs, + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + except Exception as e: + eval_logger.error(f"Error {e} in generating") + cont = "" + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() + res.append(text_outputs) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res From a6e258b87310cbd42e50305eaada3fbdfee66249 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 4 Apr 2024 13:56:13 +0000 Subject: [PATCH 03/25] Fix parsing --- lmms_eval/models/llava_hf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 1709c688c..c0231ec11 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -204,12 +204,14 @@ def _collate(x): except Exception as e: eval_logger.error(f"Error {e} in generating") cont = "" - text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() + 
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]#.strip() + text_outputs = text_outputs.split("ASSISTANT:")[1].strip() res.append(text_outputs) self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) pbar.update(1) # reorder this group of results back to original unsorted form res = re_ords.get_original(res) + res = re_ords.get_foriginal(res) pbar.close() return res From 1ea7a8fd60f39db1b76f665a368a45de6be5b972 Mon Sep 17 00:00:00 2001 From: Li Bo Date: Fri, 5 Apr 2024 00:29:02 +0800 Subject: [PATCH 04/25] [WIP] adding mmbench dev evaluation (#75) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * WIP * Update GPT evaluation model name and sys prompt * 🛠️ Scale accuracy to percentage --- lmms_eval/tasks/mmbench/cc_utils.py | 40 ++- lmms_eval/tasks/mmbench/cn_utils.py | 44 ++- lmms_eval/tasks/mmbench/en_utils.py | 45 ++- lmms_eval/tasks/mmbench/mmbench.yaml | 6 +- lmms_eval/tasks/mmbench/mmbench_cc.yaml | 6 +- lmms_eval/tasks/mmbench/mmbench_cn.yaml | 5 +- lmms_eval/tasks/mmbench/mmbench_cn_dev.yaml | 3 + lmms_eval/tasks/mmbench/mmbench_en.yaml | 1 + lmms_eval/tasks/mmbench/mmbench_en_dev.yaml | 7 +- lmms_eval/tasks/mmbench/mmbench_evals.py | 309 +++++++++++++++++++- 10 files changed, 447 insertions(+), 19 deletions(-) diff --git a/lmms_eval/tasks/mmbench/cc_utils.py b/lmms_eval/tasks/mmbench/cc_utils.py index 7009e012e..abb24ab17 100644 --- a/lmms_eval/tasks/mmbench/cc_utils.py +++ b/lmms_eval/tasks/mmbench/cc_utils.py @@ -9,7 +9,7 @@ from lmms_eval.tasks.mmbench.mmbench_evals import MMBench_Evaluator from lmms_eval.tasks._task_utils.file_utils import generate_submission_file -with open(Path(__file__).parent / "mmbench_cn.yaml", "r") as f: +with open(Path(__file__).parent / "mmbench.yaml", "r") as f: raw_data = f.readlines() safe_data = [] for i, line in enumerate(raw_data): @@ -19,7 +19,18 @@ config = yaml.safe_load("".join(safe_data)) -mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"]) +GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"] +API_TYPE = os.getenv("API_TYPE", "openai") + +if API_TYPE == "openai": + API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") + API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") +elif API_TYPE == "azure": + API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") + API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") + + +mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"], API_KEY=API_KEY, API_URL=API_URL, model_version=GPT_EVAL_MODEL_NAME) def mmbench_doc_to_visual(doc): @@ -52,6 +63,14 @@ def mmbench_cn_cc_doc_to_text(doc, model_specific_prompt_kwargs=None): def mmbench_cn_cc_process_results(doc, results): model_response = results[0].strip() data = { + "gpt_eval_score": { + "index": doc["index"], + "question": doc["question"], + "answer": doc["answer"], + "prediction": model_response, + "source": doc["source"], + "category": doc["category"], + }, "submission": { "index": doc["index"], "question": doc["question"], @@ -59,14 +78,29 @@ def mmbench_cn_cc_process_results(doc, results): "prediction": model_response, "source": doc["source"], "category": doc["category"], - } + }, } option_candidate = ["A", "B", "C", "D", "E"] for c in option_candidate: data["submission"][c] = doc.get(c, "nan") + data["gpt_eval_score"][c] = doc.get(c, "nan") return data +def mmbench_cn_cc_aggregate_dev_results_eval(results, args): + 
print(f"============= MMBench-CN(CC) Detailed Results =============") + overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="openai") + file = generate_submission_file("mmbench_cn_cc_results.json", args) + details_info = { + "overall_acc": overall_acc, + "category_acc": category_acc, + "l2_category_acc": l2_category_acc, + } + with open(file, "w") as f: + json.dump(details_info, f) + return overall_acc * 100 + + def mmbench_cn_cc_aggregate_results(results, args): df = pd.DataFrame(results) file = generate_submission_file("mmbench_cn_cc_results.xlsx", args) diff --git a/lmms_eval/tasks/mmbench/cn_utils.py b/lmms_eval/tasks/mmbench/cn_utils.py index 812b9aa38..39a55f728 100644 --- a/lmms_eval/tasks/mmbench/cn_utils.py +++ b/lmms_eval/tasks/mmbench/cn_utils.py @@ -8,8 +8,9 @@ eval_logger = logging.getLogger("lmms-eval") from lmms_eval.tasks.mmbench.mmbench_evals import MMBench_Evaluator +from lmms_eval.tasks._task_utils.file_utils import generate_submission_file -with open(Path(__file__).parent / "mmbench_cn.yaml", "r") as f: +with open(Path(__file__).parent / "mmbench.yaml", "r") as f: raw_data = f.readlines() safe_data = [] for i, line in enumerate(raw_data): @@ -19,7 +20,18 @@ config = yaml.safe_load("".join(safe_data)) -mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"]) +GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"] +API_TYPE = os.getenv("API_TYPE", "openai") + +if API_TYPE == "openai": + API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") + API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") +elif API_TYPE == "azure": + API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") + API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") + + +mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"], API_KEY=API_KEY, API_URL=API_URL, model_version=GPT_EVAL_MODEL_NAME) def mmbench_doc_to_visual(doc): @@ -55,6 +67,17 @@ def mmbench_doc_to_text(doc, model_specific_prompt_kwargs=None): def mmbench_process_results(doc, results): model_response = results[0].strip() data = { + "gpt_eval_score": { + "index": doc["index"], + "question": doc["question"], + "answer": doc["answer"], + "prediction": model_response, + "hint": doc["hint"], + "source": doc["source"], + "split": doc["split"], + "category": doc["category"], + "L2-category": doc["L2-category"], + }, "submission": { "index": doc["index"], "question": doc["question"], @@ -65,14 +88,29 @@ def mmbench_process_results(doc, results): "split": doc["split"], "category": doc["category"], "L2-category": doc["L2-category"], - } + }, } option_candidate = ["A", "B", "C", "D", "E"] for c in option_candidate: data["submission"][c] = doc.get(c, "nan") + data["gpt_eval_score"][c] = doc.get(c, "nan") return data +def mmbench_aggregate_dev_results_eval(results, args): + print(f"============= MMBench-CN(Dev) Detailed Results =============") + overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="openai") + file = generate_submission_file("mmbench_cn_dev_results.json", args) + details_info = { + "overall_acc": overall_acc, + "category_acc": category_acc, + "l2_category_acc": l2_category_acc, + } + with open(file, "w") as f: + json.dump(details_info, f) + return overall_acc * 100 + + def mmbench_aggregate_dev_results(results, args): df = pd.DataFrame(results) excel_write_path = generate_submission_file("mmbench_cn_dev_results.xlsx", args) 
diff --git a/lmms_eval/tasks/mmbench/en_utils.py b/lmms_eval/tasks/mmbench/en_utils.py index 26e260006..1ddccbb68 100644 --- a/lmms_eval/tasks/mmbench/en_utils.py +++ b/lmms_eval/tasks/mmbench/en_utils.py @@ -9,7 +9,7 @@ from lmms_eval.tasks.mmbench.mmbench_evals import MMBench_Evaluator from lmms_eval.tasks._task_utils.file_utils import generate_submission_file -with open(Path(__file__).parent / "mmbench_en.yaml", "r") as f: +with open(Path(__file__).parent / "mmbench.yaml", "r") as f: raw_data = f.readlines() safe_data = [] for i, line in enumerate(raw_data): @@ -19,7 +19,18 @@ config = yaml.safe_load("".join(safe_data)) -mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"]) +GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"] +API_TYPE = os.getenv("API_TYPE", "openai") + +if API_TYPE == "openai": + API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") + API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") +elif API_TYPE == "azure": + API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken") + API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY") + + +mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"], API_KEY=API_KEY, API_URL=API_URL, model_version=GPT_EVAL_MODEL_NAME) def mmbench_doc_to_visual(doc): @@ -55,6 +66,17 @@ def mmbench_doc_to_text(doc, model_specific_prompt_kwargs=None): def mmbench_process_results(doc, results): model_response = results[0].strip() data = { + "gpt_eval_score": { + "index": doc["index"], + "question": doc["question"], + "answer": doc["answer"], + "prediction": model_response, + "hint": doc["hint"], + "source": doc["source"], + "split": doc["split"], + "category": doc["category"], + "L2-category": doc["L2-category"], + }, "submission": { "index": doc["index"], "question": doc["question"], @@ -65,15 +87,30 @@ def mmbench_process_results(doc, results): "split": doc["split"], "category": doc["category"], "L2-category": doc["L2-category"], - } + }, } option_candidate = ["A", "B", "C", "D", "E"] for c in option_candidate: data["submission"][c] = doc.get(c, "nan") + data["gpt_eval_score"][c] = doc.get(c, "nan") return data -def mmbench_aggregate_dev_results(results, args): +def mmbench_aggregate_dev_results_eval(results, args): + print(f"============= MMBench-EN(Dev) Detailed Results =============") + overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="openai") + file = generate_submission_file("mmbench_en_dev_results.json", args) + details_info = { + "overall_acc": overall_acc, + "category_acc": category_acc, + "l2_category_acc": l2_category_acc, + } + with open(file, "w") as f: + json.dump(details_info, f) + return overall_acc * 100 + + +def mmbench_aggregate_dev_results_submission(results, args): df = pd.DataFrame(results) excel_write_path = generate_submission_file("mmbench_en_dev_results.xlsx", args) with pd.ExcelWriter(excel_write_path) as writer: diff --git a/lmms_eval/tasks/mmbench/mmbench.yaml b/lmms_eval/tasks/mmbench/mmbench.yaml index 45c2ed0db..821065eea 100644 --- a/lmms_eval/tasks/mmbench/mmbench.yaml +++ b/lmms_eval/tasks/mmbench/mmbench.yaml @@ -4,4 +4,8 @@ task: - mmbench_en_test - mmbench_cn_dev - mmbench_cn_test - - mmbench_cn_cc \ No newline at end of file + - mmbench_cn_cc +metadata: + version: 0.0 + sys_prompt: "There are several options:" + gpt_eval_model_name: "gpt-3.5-turbo-0613" \ No newline at end of file diff --git 
a/lmms_eval/tasks/mmbench/mmbench_cc.yaml b/lmms_eval/tasks/mmbench/mmbench_cc.yaml index 238aa10c9..4a0d58950 100644 --- a/lmms_eval/tasks/mmbench/mmbench_cc.yaml +++ b/lmms_eval/tasks/mmbench/mmbench_cc.yaml @@ -16,12 +16,14 @@ generation_kwargs: do_sample: false process_results: !function cc_utils.mmbench_cn_cc_process_results metric_list: + - metric: gpt_eval_score + aggregation: !function cc_utils.mmbench_cn_cc_aggregate_dev_results_eval + higher_is_better: true - metric: submission aggregation: !function cc_utils.mmbench_cn_cc_aggregate_results metadata: version: 0.0 - gpt_eval_model_name: "gpt-3.5-turbo" - quick_extract: true + gpt_eval_model_name: "gpt-3.5-turbo-0613" model_specific_prompt_kwargs: default: diff --git a/lmms_eval/tasks/mmbench/mmbench_cn.yaml b/lmms_eval/tasks/mmbench/mmbench_cn.yaml index 6232531c4..9da764cc2 100644 --- a/lmms_eval/tasks/mmbench/mmbench_cn.yaml +++ b/lmms_eval/tasks/mmbench/mmbench_cn.yaml @@ -5,6 +5,5 @@ task: - mmbench_cn_cc metadata: version: 0.0 - gpt_eval_model_name: "gpt-3.5-turbo" - quick_extract: true - sys_prompt: "有如下几个选项:" + gpt_eval_model_name: "gpt-3.5-turbo-0613" + sys_prompt: "有如下几个选项:" \ No newline at end of file diff --git a/lmms_eval/tasks/mmbench/mmbench_cn_dev.yaml b/lmms_eval/tasks/mmbench/mmbench_cn_dev.yaml index 3d7b9d98b..b3aaa545e 100644 --- a/lmms_eval/tasks/mmbench/mmbench_cn_dev.yaml +++ b/lmms_eval/tasks/mmbench/mmbench_cn_dev.yaml @@ -1,6 +1,9 @@ task: "mmbench_cn_dev" test_split: "dev" metric_list: + - metric: gpt_eval_score + aggregation: !function cn_utils.mmbench_aggregate_dev_results_eval + higher_is_better: true - metric: submission higher_is_better: true aggregation: !function cn_utils.mmbench_aggregate_dev_results diff --git a/lmms_eval/tasks/mmbench/mmbench_en.yaml b/lmms_eval/tasks/mmbench/mmbench_en.yaml index 9fa757cc3..0a5c55807 100644 --- a/lmms_eval/tasks/mmbench/mmbench_en.yaml +++ b/lmms_eval/tasks/mmbench/mmbench_en.yaml @@ -5,3 +5,4 @@ task: metadata: version: 0.0 sys_prompt: "There are several options:" + gpt_eval_model_name: "gpt-3.5-turbo-0613" diff --git a/lmms_eval/tasks/mmbench/mmbench_en_dev.yaml b/lmms_eval/tasks/mmbench/mmbench_en_dev.yaml index b4f4a2e9f..f9e6babff 100644 --- a/lmms_eval/tasks/mmbench/mmbench_en_dev.yaml +++ b/lmms_eval/tasks/mmbench/mmbench_en_dev.yaml @@ -2,6 +2,9 @@ task: "mmbench_en_dev" test_split: dev include: _default_template_mmbench_en_yaml metric_list: - - metric: submission - aggregation: !function en_utils.mmbench_aggregate_dev_results + - metric: gpt_eval_score + aggregation: !function en_utils.mmbench_aggregate_dev_results_eval higher_is_better: true + - metric: submission + aggregation: !function en_utils.mmbench_aggregate_dev_results_submission + higher_is_better: true \ No newline at end of file diff --git a/lmms_eval/tasks/mmbench/mmbench_evals.py b/lmms_eval/tasks/mmbench/mmbench_evals.py index bdae49d16..7868157ec 100644 --- a/lmms_eval/tasks/mmbench/mmbench_evals.py +++ b/lmms_eval/tasks/mmbench/mmbench_evals.py @@ -1,9 +1,26 @@ +import os.path as osp +import time +import random as rd +import string +from collections import defaultdict +import requests + +import math +import numpy as np import pandas as pd +from tqdm import tqdm + +import logging + +eval_logger = logging.getLogger("lmms-eval") class MMBench_Evaluator: - def __init__(self, sys_prompt="There are several options:"): + def __init__(self, sys_prompt="There are several options:", API_KEY="", API_URL="", model_version="gpt-3.5-turbo-0613"): self.sys_prompt = sys_prompt + self.model_version = 
model_version + self.API_KEY = API_KEY + self.API_URL = API_URL def create_options_prompt(self, row_data, option_candidate): available_keys = set(row_data.keys()) & set(option_candidate) @@ -14,3 +31,293 @@ def create_options_prompt(self, row_data, option_candidate): if pd.notna(item) and item != "nan": options_prompt += f"{key}. {item}\n" return options_prompt.rstrip("\n"), sorted_options + + # Prompt Building + def build_option_str(self, option_list): + chars = string.ascii_uppercase + s = "There are several options: \n" + for c, opt in zip(chars, option_list): + if not pd.isna(opt): + s += f"{c}. {opt}\n" + else: + return s + return s + + def extract_options(self, item): + options = [] + for c in "ABCD": + if c in item and not pd.isna(item[c]): + options.append(item[c]) + else: + return options + return options + + def build_choices(self, item): + ret = {} + for ch in "ABCD": + if not pd.isna(item[ch]): + ret[ch] = item[ch] + return ret + + def build_prompt(self, question, options, prediction): + tmpl = ( + "You are an AI assistant who will help me to match an answer " + "with several options of a single-choice question. " + "You are provided with a question, several options, and an answer, " + "and you need to find which option is most similar to the answer. " + "If the meaning of all options are significantly different " + "from the answer, output E. " + "Your should output a single uppercase character in A, B, C, D " + "(if they are valid options), and E. \n" + "Example 1: \n" + "Question: What is the main object in image?\nOptions: A. teddy bear " + "B. rabbit C. cat D. dog\nAnswer: a cute teddy bear\nYour output: A\n" + "Example 2: \n" + "Question: What is the main object in image?\nOptions: A. teddy bear " + "B. rabbit C. cat D. dog\nAnswer: Spider\nYour output: E\n" + "Example 3: \n" + "Question: {}?\nOptions: {}\nAnswer: {}\nYour output: " + ) + return tmpl.format(question, options, prediction) + + # Prefetch Answers + def can_infer_option(self, answer, num_choice=5): + choices = string.ascii_uppercase[:num_choice] + if "Failed to obtain answer via API" in answer: + return False + + def count(splits, choices="ABCD", prefix="", suffix=""): + cnt = 0 + for c in choices: + if prefix + c + suffix in splits: + cnt += 1 + return cnt + + splits = [x.strip() for x in answer.split()] + if count(splits, choices) == 1: + for ch in choices: + if "A" in splits and len(splits) > 3: + eval_logger.info(f"A might be a quantifier in the string: {answer}.") + break + if ch in splits: + return ch + tups = [("", "."), ("", ","), ("", ":"), ("", ")"), ("", ")."), ("(", ")"), ("(", ")."), (":", ""), (":", ","), (":", "."), (":", ")"), (":", ").")] + for tup in tups: + if count(splits, choices, prefix=tup[0], suffix=tup[1]) == 1: + for ch in choices: + if tup[0] + ch + tup[1] in splits: + return ch + return False + + def can_infer_text(self, answer, choices): + answer = answer.lower() + assert isinstance(choices, dict) + for k in choices: + assert k in "ABCD" + choices[k] = str(choices[k]).lower() + cands = [] + for k in choices: + if choices[k] in answer: + cands.append(k) + if len(cands) == 1: + return cands[0] + return False + + def can_infer(self, answer, choices): + copt = self.can_infer_option(answer) + return copt if copt else self.can_infer_text(answer, choices) + + def prefetch_answer(self, item): + choices = self.build_choices(item) + return self.can_infer(item["prediction"], choices) + + def _post_request(self, payload): + headers = { + "Authorization": f"Bearer {self.API_KEY}", + 
"Content-Type": "application/json", + } + response = requests.post(self.API_URL, headers=headers, json=payload, timeout=30) + response.raise_for_status() + return response.json() + + def get_chat_response(self, prompt, temperature=0, max_tokens=256, n=1, patience=5, sleep_time=3): + messages = [ + {"role": "user", "content": prompt}, + ] + payload = {"model": self.model_version, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "n": n} + + while patience > 0: + patience -= 1 + try: + response = self._post_request(payload) + if n == 1: + prediction = response["choices"][0]["message"]["content"].strip() + if prediction and prediction != "": + return prediction + else: + prediction = [choice["message"]["content"].strip() for choice in response["choices"]] + if prediction and prediction[0] != "": + return prediction + + except Exception as e: + eval_logger.info(f"Attempt {patience + 1} failed with error: {e}") + if sleep_time > 0: + time.sleep(sleep_time) + + return "Failed to obtain answer via API" + + def extract_answer_from_item(self, item): + options = self.extract_options(item) + option_str = self.build_option_str(options) + + prompt = self.build_prompt(item["question"], option_str, item["prediction"]) + retry = 3 + choices = self.build_choices(item) + + ret = self.can_infer(item["prediction"], choices) + if ret: + return ret, item["prediction"] + + while retry: + ans = self.get_chat_response(prompt) + if "Failed to obtain answer via API" in ans: + msg = "GPT API failed to answer. " + eval_logger.info(msg) + retry -= 1 + else: + ret = self.can_infer(ans, choices) + if ret: + return ret, ans + else: + eval_logger.info(f'GPT output includes 0 / >1 letter in "ABCD": {ans}') + retry -= 1 + + if retry == 0: + num_options = sum([ch in item for ch in "ABCD"]) + if num_options >= 2: + chars = string.ascii_uppercase[:num_options] + chars = chars + "E" + num_options += 1 + tmp = rd.randint(0, num_options - 1) + return chars[tmp], "Failed to predict, thus randomly generate one. 
" + + # Extract answer from multiple rolling records + def eval_sub_data(self, sub_data, answer_map): + lt = len(sub_data) + GT, PRED = [], [] + for i in range(lt): + item = sub_data.iloc[i] + idx = item["index"] + GT.append(answer_map[idx]) + PRED.append(self.prefetch_answer(item)) + if PRED[-1] and (GT[-1] != PRED[-1]): + return 0 + + for i in range(lt): + if PRED[i]: + continue + else: + ret, _ = self.extract_answer_from_item(sub_data.iloc[i]) + PRED[i] = ret + if PRED[i] != GT[i]: + return 0 + return 1 + + def calculate_hit_rates(self, data): + overall_hit_rate = data["hit"].mean() + + category_hit_rate = {} + if "category" in data.columns: + # Category-based hit rate + category_hit_rate = data.groupby("category")["hit"].mean().to_dict() + + # l2-category based hit rate + l2_category_hit_rate = {} + if "l2-category" in data.columns: + l2_category_hit_rate = data.groupby("l2-category")["hit"].mean().to_dict() + + return overall_hit_rate, category_hit_rate, l2_category_hit_rate + + # Evaluate Results + def eval_result(self, results, eval_method): + rd.seed(2680) + assert eval_method == "openai" + # Set a large retry number to avoid failure + # model = OpenAI('gpt-3.5-turbo-0613', retry=99) + + # double_log(f'Evaluating {eval_file}', fout) + + # result_file = eval_file.replace('.xlsx', f'_{eval_method}_result.pkl') + result = {} + # if osp.exists(result_file): + # result = load(result_file) + + # data = load(eval_file) + data = pd.DataFrame(results) + data = data.sort_values(by="index") + data["prediction"] = [str(x) for x in data["prediction"]] + for k in data.keys(): + data[k.lower() if k not in "ABCD" else k] = data.pop(k) + + # meta = load(meta_file) + + data_main = data[data["index"] < int(1e6)] + + data_main["hit"] = 0 + cate_map = {i: c for i, c in zip(data["index"], data["category"])} + answer_map = {i: c for i, c in zip(data["index"], data["answer"])} + if "l2-category" in data.columns: + l2_cate_map = {i: c for i, c in zip(data["index"], data["l2-category"])} + + lt = len(data_main) + hit, tot = 0, 0 + + for i in range(lt): + # Dealing with the normal part + item_main = data_main.iloc[i] + idx = item_main["index"] + + if idx in result: + correct = result[idx] + assert correct in [0, 1] + hit += correct + tot += 1 + continue + + sub_data = data[data["index"] % int(1e6) == idx] + ret = self.eval_sub_data(sub_data, answer_map) + result[idx] = ret + hit += ret + tot += 1 + + data_main.loc[data_main["index"] == idx, "hit"] = ret + # if (i + 1) % 100 == 0: + # eval_logger.info(f"Evaluating: {i + 1}/{lt}, Acc: {hit / tot * 100: .2f}%. ") + + indices = data_main["index"] + data_main = data_main.set_index("index") + data_main["category"] = [cate_map[i] if not math.isnan(i) else "uncategorized" for i in indices] + if "l2-category" in data_main.columns: + data_main["l2-category"] = [l2_cate_map[i] if not math.isnan(i) else "uncategorized" for i in indices] + + overall_hit_rate, category_hit_rate, l2_category_hit_rate = self.calculate_hit_rates(data_main) + + if "category" in data_main.columns: + print(f"Category Acc. (dev):") + for category_key in category_hit_rate: + if category_key == "split": + continue + + category_percentage = category_hit_rate[category_key] * 100 + print(f"\t{category_key}: {category_percentage:.3f}") + + if "l2-category" in data_main.columns: + print(f"L2-category Acc. 
(dev):") + for l2_category_key in l2_category_hit_rate: + if l2_category_key == "split": + continue + + l2_category_percentage = l2_category_hit_rate[l2_category_key] * 100 + print(f"\t{l2_category_key}: {l2_category_percentage:.3f}") + + return overall_hit_rate, category_hit_rate, l2_category_hit_rate From 86732ebf7a406504ee822fe188329bb1005c7f25 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 5 Apr 2024 08:34:18 +0000 Subject: [PATCH 05/25] Add chat template --- lmms_eval/models/llava_hf.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index c0231ec11..4f7d6e075 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -26,6 +26,7 @@ class LlavaHf(lmms): def __init__( self, pretrained: str = "llava-hf/llava-1.5-7b-hf", + revision=None, device: Optional[str] = "cuda", dtype: Optional[Union[str, torch.dtype]] = torch.float16, batch_size: Optional[Union[int, str]] = 1, @@ -40,8 +41,8 @@ def __init__( self._device = torch.device(f"cuda:{accelerator.local_process_index}") else: self._device = device - self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, torch_dtype=dtype) - self._image_processor = AutoProcessor.from_pretrained(pretrained) + self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype) + self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision) self._tokenizer = self._image_processor.tokenizer self._config = self._model.config self._model.eval() @@ -180,8 +181,20 @@ def _collate(x): raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}") assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now" context = contexts[0] - context = f"USER: \n{context}\nASSISTANT:" - inputs = self._image_processor(images=visuals, text=context, return_tensors="pt").to(self._device, torch.float16) + + # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt. 
+ if "" not in context: + context = f"\n{context}" + if self.tokenizer.chat_template is not None: + messages = [{"role": "user", "content": context}] + text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + else: + text = f"USER: {context}\nASSISTANT:" + + if self.accelerator.is_main_process and doc_id[0] % 100 == 0: + eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") + + inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, torch.float16) gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] if "max_new_tokens" not in gen_kwargs: @@ -204,14 +217,18 @@ def _collate(x): except Exception as e: eval_logger.error(f"Error {e} in generating") cont = "" - text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]#.strip() + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0] text_outputs = text_outputs.split("ASSISTANT:")[1].strip() + + if self.accelerator.is_main_process and doc_id[0] % 100 == 0: + eval_logger.info(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n") + res.append(text_outputs) self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) pbar.update(1) # reorder this group of results back to original unsorted form res = re_ords.get_original(res) - res = re_ords.get_foriginal(res) + res = re_ords.get_original(res) pbar.close() return res From c4955df8c99a3594983245c3a9823940d090c4b9 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 5 Apr 2024 13:03:17 +0000 Subject: [PATCH 06/25] Revert --- lmms_eval/models/llava.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index 3edd8d76b..b3cb8a66d 100644 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -52,7 +52,7 @@ def __init__( batch_size: Optional[Union[int, str]] = 1, trust_remote_code: Optional[bool] = False, revision=None, - use_flash_attention_2=False, # True, + use_flash_attention_2=True, device_map="", conv_template="vicuna_v1", use_cache=True, From 78fc6175b384797c3728fb8d390bf90e318bfd14 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 8 Apr 2024 13:19:19 +0000 Subject: [PATCH 07/25] Remove duplicate code --- lmms_eval/models/llava_hf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 4f7d6e075..926168cba 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -228,7 +228,6 @@ def _collate(x): pbar.update(1) # reorder this group of results back to original unsorted form res = re_ords.get_original(res) - res = re_ords.get_original(res) pbar.close() return res From ca20b8cd8d9eab606ee6474dc384e2c41802d27e Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 8 Apr 2024 15:44:46 +0000 Subject: [PATCH 08/25] Add chat templates --- lmms_eval/constants.py | 2 + lmms_eval/models/llava_hf.py | 81 +++++++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 16 deletions(-) create mode 100644 lmms_eval/constants.py diff --git a/lmms_eval/constants.py b/lmms_eval/constants.py new file mode 100644 index 000000000..675c30fbc --- /dev/null +++ b/lmms_eval/constants.py @@ -0,0 +1,2 @@ +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }}\n{% else %}ASSISTANT: {{ message['content'] }}\n{% endif %}{% endfor %}" \ No newline at end of file diff --git 
a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 926168cba..fd08c1e44 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -5,6 +5,7 @@ from lmms_eval.api.instance import Instance from lmms_eval.api.model import lmms from lmms_eval.api.registry import register_model +from lmms_eval.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_CHAT_TEMPLATE from accelerate import Accelerator, DistributedType from accelerate.state import AcceleratorState from typing import List, Optional, Union, Tuple @@ -20,7 +21,8 @@ @register_model("llava_hf") class LlavaHf(lmms): """ - Llava HF Model + Llava Model for Hugging Face Transformers + https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava """ def __init__( @@ -28,8 +30,10 @@ def __init__( pretrained: str = "llava-hf/llava-1.5-7b-hf", revision=None, device: Optional[str] = "cuda", - dtype: Optional[Union[str, torch.dtype]] = torch.float16, + dtype: Optional[Union[str, torch.dtype]] = "auto", batch_size: Optional[Union[int, str]] = 1, + device_map="", + chat_template: Optional[str] = None, **kwargs, ) -> None: super().__init__() @@ -37,18 +41,23 @@ def __init__( assert kwargs == {}, f"Unexpected kwargs: {kwargs}" accelerator = Accelerator() - if accelerator.num_processes > 1: + if accelerator.num_processes > 1 and device_map == "": self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" else: - self._device = device - self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype) + self._device = torch.device(device) + self.device_map = device_map + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + self.dtype = dtype + self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=self.dtype, device_map=self.device_map) self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision) + # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips + self._image_processor.tokenizer.padding_side = "left" self._tokenizer = self._image_processor.tokenizer self._config = self._model.config - self._model.eval() - self._model.tie_weights() self.batch_size_per_gpu = int(batch_size) - if accelerator.num_processes > 1: + if accelerator.num_processes > 1 and device_map == "": assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." 
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works @@ -69,7 +78,12 @@ def __init__( eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") self._rank = self.accelerator.local_process_index self._world_size = self.accelerator.num_processes + elif accelerator.num_processes == 1 and device_map == "auto": + eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism") + self._rank = 0 + self._word_size = 1 else: + eval_logger.info(f"Using single device: {self._device}") self.model.to(self._device) self._rank = 0 self._word_size = 1 @@ -129,8 +143,42 @@ def tok_decode(self, tokens): return self.tokenizer.decode(tokens) def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: - # TODO - assert False, "We have not implemented this function for LlavaHf yet" + res = [] + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") + + for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: + # encode, pad, and truncate contexts for this batch + if type(doc_to_target) == str: + continuation = doc_to_target + else: + continuation = doc_to_target(self.task_dict[task][split][doc_id]) + visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] + visuals = self.flatten(visuals) + formatted_contexts = [f"{contexts}\n"] + formatted_continuation = [f"{contexts}\n{continuation}"] + model_inputs = self.processor(text=formatted_continuation, images=visuals, device=self.device) + for k, v in model_inputs.items(): + model_inputs[k] = v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else [vv.to(self.device, non_blocking=True) for vv in v] + + for index in range(len(model_inputs["image_patches"])): + model_inputs["image_patches"][index] = model_inputs["image_patches"][index].to(dtype=next(self.model.parameters()).dtype) + + labels = model_inputs["input_ids"].clone() + contxt_id = self.processor(text=formatted_contexts, return_tensors="pt")["input_ids"] + labels[: len(contxt_id)] = -100 + with torch.inference_mode(): + outputs = self.model(**model_inputs, labels=labels) + loss = outputs["loss"] + logits = outputs["logits"] + greedy_tokens = logits.argmax(dim=-1) + cont_toks = model_inputs["input_ids"][:, contxt_id.shape[1] :] # [1, seq] + greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : model_inputs["input_ids"].shape[1]] # [1, seq] + max_equal = (greedy_tokens == cont_toks).all() + res.append((float(loss.item()), bool(max_equal))) + pbar.update(1) + + pbar.close() + return res def flatten(self, input): new_list = [] @@ -183,18 +231,19 @@ def _collate(x): context = contexts[0] # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt. 
- if "" not in context: - context = f"\n{context}" + if DEFAULT_IMAGE_TOKEN not in context: + context = f"{DEFAULT_IMAGE_TOKEN}\n{context}" + messages = [{"role": "user", "content": context}] if self.tokenizer.chat_template is not None: - messages = [{"role": "user", "content": context}] text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) else: - text = f"USER: {context}\nASSISTANT:" + self.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE + text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") - inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, torch.float16) + inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.dtype) gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] if "max_new_tokens" not in gen_kwargs: @@ -218,7 +267,7 @@ def _collate(x): eval_logger.error(f"Error {e} in generating") cont = "" text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0] - text_outputs = text_outputs.split("ASSISTANT:")[1].strip() + text_outputs = text_outputs.split("ASSISTANT:")[-1].strip() if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n") From 46e1f684c46f8d6bfa038360ffce913c96e51c42 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 8 Apr 2024 19:41:21 +0000 Subject: [PATCH 09/25] Fix chat templates --- lmms_eval/constants.py | 4 +++- lmms_eval/models/llava_hf.py | 18 ++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/lmms_eval/constants.py b/lmms_eval/constants.py index 675c30fbc..6c79b7049 100644 --- a/lmms_eval/constants.py +++ b/lmms_eval/constants.py @@ -1,2 +1,4 @@ DEFAULT_IMAGE_TOKEN = "" -DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }}\n{% else %}ASSISTANT: {{ message['content'] }}\n{% endif %}{% endfor %}" \ No newline at end of file + +# Default chat templates +VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}" \ No newline at end of file diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index fd08c1e44..1fab60059 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -5,7 +5,7 @@ from lmms_eval.api.instance import Instance from lmms_eval.api.model import lmms from lmms_eval.api.registry import register_model -from lmms_eval.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_CHAT_TEMPLATE +from lmms_eval.constants import DEFAULT_IMAGE_TOKEN, VICUNA_CHAT_TEMPLATE from accelerate import Accelerator, DistributedType from accelerate.state import AcceleratorState from typing import List, Optional, Union, Tuple @@ -49,14 +49,14 @@ def __init__( self.device_map = device_map if isinstance(dtype, str) and dtype != "auto": dtype = getattr(torch, dtype) - self.dtype = dtype - self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=self.dtype, device_map=self.device_map) + self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map) self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision) # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips self._image_processor.tokenizer.padding_side = "left" self._tokenizer = self._image_processor.tokenizer self._config = self._model.config self.batch_size_per_gpu = int(batch_size) + self.chat_template = chat_template if accelerator.num_processes > 1 and device_map == "": assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." 
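A quick aside on the VICUNA_CHAT_TEMPLATE introduced above: it is plain Jinja2, so its output can be previewed without loading a tokenizer. The sketch below assumes jinja2 is installed (transformers also uses it to render chat templates); the example message is hypothetical.

    from jinja2 import Template

    VICUNA_CHAT_TEMPLATE = (
        "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial "
        "intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
        "USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} "
        "{% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}"
        "{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
    )
    prompt = Template(VICUNA_CHAT_TEMPLATE).render(
        messages=[{"role": "user", "content": "<image>\nWhat is shown in this image?"}],
        add_generation_prompt=True,
        eos_token="</s>",
    )
    print(prompt)
    # -> "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>
    # What is shown in this image? ASSISTANT:"
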
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model @@ -234,16 +234,22 @@ def _collate(x): if DEFAULT_IMAGE_TOKEN not in context: context = f"{DEFAULT_IMAGE_TOKEN}\n{context}" messages = [{"role": "user", "content": context}] - if self.tokenizer.chat_template is not None: + # Apply chat template if provided + if self.chat_template is not None: + self.tokenizer.chat_template = self.chat_template text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # Else default to the tokenizer's template + elif self.tokenizer.chat_template is not None: + text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # Finally default to the Vicuna chat template else: - self.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE + self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") - inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.dtype) + inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self._model.dtype) gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] if "max_new_tokens" not in gen_kwargs: From 2e96d4ca4b7b2f0a3cee475a1f190262626b1878 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 8 Apr 2024 20:04:52 +0000 Subject: [PATCH 10/25] Tidy --- lmms_eval/models/llava_hf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 1fab60059..ce7355d91 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -233,15 +233,13 @@ def _collate(x): # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt. 
if DEFAULT_IMAGE_TOKEN not in context: context = f"{DEFAULT_IMAGE_TOKEN}\n{context}" + # Apply chat template messages = [{"role": "user", "content": context}] - # Apply chat template if provided if self.chat_template is not None: self.tokenizer.chat_template = self.chat_template text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - # Else default to the tokenizer's template elif self.tokenizer.chat_template is not None: text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - # Finally default to the Vicuna chat template else: self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) From bd00350f185f8fa52325f8726e2b2588f25a9d4b Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 9 Apr 2024 12:36:42 +0000 Subject: [PATCH 11/25] Add log likelihood --- lmms_eval/models/llava_hf.py | 58 +++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index ce7355d91..284f57b4b 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -21,18 +21,21 @@ @register_model("llava_hf") class LlavaHf(lmms): """ - Llava Model for Hugging Face Transformers - https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava + Llava Model for Hugging Face Transformers: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava + + Adapted from the InstructBLIP model in lmms_eval/models/instructblip.py """ def __init__( self, pretrained: str = "llava-hf/llava-1.5-7b-hf", - revision=None, - device: Optional[str] = "cuda", + revision: str = "main", + device: str = "cuda", dtype: Optional[Union[str, torch.dtype]] = "auto", - batch_size: Optional[Union[int, str]] = 1, - device_map="", + batch_size: Union[int, str] = 1, + trust_remote_code: Optional[bool] = False, + attn_implementation: Optional[str] = "flash_attention_2", + device_map: str ="", chat_template: Optional[str] = None, **kwargs, ) -> None: @@ -49,8 +52,8 @@ def __init__( self.device_map = device_map if isinstance(dtype, str) and dtype != "auto": dtype = getattr(torch, dtype) - self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map) - self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision) + self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation) + self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code) # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips self._image_processor.tokenizer.padding_side = "left" self._tokenizer = self._image_processor.tokenizer @@ -146,7 +149,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: res = [] pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") - for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: + for context, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: # encode, pad, and truncate contexts for this batch if type(doc_to_target) == str: continuation = doc_to_target 
@@ -154,18 +157,35 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: continuation = doc_to_target(self.task_dict[task][split][doc_id]) visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] visuals = self.flatten(visuals) - formatted_contexts = [f"{contexts}\n"] - formatted_continuation = [f"{contexts}\n{continuation}"] - model_inputs = self.processor(text=formatted_continuation, images=visuals, device=self.device) - for k, v in model_inputs.items(): - model_inputs[k] = v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else [vv.to(self.device, non_blocking=True) for vv in v] - for index in range(len(model_inputs["image_patches"])): - model_inputs["image_patches"][index] = model_inputs["image_patches"][index].to(dtype=next(self.model.parameters()).dtype) + image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals) + image_tokens = " ".join(image_tokens) + context = f"{image_tokens}\n{context}" + # Apply chat template + messages = [{"role": "user", "content": context}, {"role": "assistant", "content": continuation}] + if self.chat_template is not None: + self.tokenizer.chat_template = self.chat_template + prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True) + prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + elif self.tokenizer.chat_template is not None: + prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True) + prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + else: + self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE + prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True) + prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + + formatted_contexts = [prompt] + formatted_continuation = [prompt_and_continuation] + model_inputs = self._image_processor(text=formatted_continuation, images=visuals).to(self._device, self._model.dtype) labels = model_inputs["input_ids"].clone() - contxt_id = self.processor(text=formatted_contexts, return_tensors="pt")["input_ids"] + contxt_id = self._image_processor(text=formatted_contexts, return_tensors="pt")["input_ids"] labels[: len(contxt_id)] = -100 + + if self.accelerator.is_main_process and doc_id % 100 == 0: + eval_logger.info(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n") + with torch.inference_mode(): outputs = self.model(**model_inputs, labels=labels) loss = outputs["loss"] @@ -240,9 +260,11 @@ def _collate(x): text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) elif self.tokenizer.chat_template is not None: text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + eval_logger.warning("Using the tokenizer's chat template to format the prompt.") else: self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + eval_logger.warning("No chat template provided or set in the tokenizer. 
Using the default Vicuna chat template.") if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") @@ -279,7 +301,7 @@ def _collate(x): res.append(text_outputs) self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) pbar.update(1) - # reorder this group of results back to original unsorted form + # reorder this group of results back to original unsorted form res = re_ords.get_original(res) pbar.close() From d410bea2cd39649cc6cb78f57cc71e12c48a20c8 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 9 Apr 2024 12:38:04 +0000 Subject: [PATCH 12/25] Style --- lmms_eval/constants.py | 2 +- lmms_eval/models/llava_hf.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/lmms_eval/constants.py b/lmms_eval/constants.py index 6c79b7049..b7b0283e8 100644 --- a/lmms_eval/constants.py +++ b/lmms_eval/constants.py @@ -1,4 +1,4 @@ DEFAULT_IMAGE_TOKEN = "" # Default chat templates -VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}" \ No newline at end of file +VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}" diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 284f57b4b..e5d580229 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -35,7 +35,7 @@ def __init__( batch_size: Union[int, str] = 1, trust_remote_code: Optional[bool] = False, attn_implementation: Optional[str] = "flash_attention_2", - device_map: str ="", + device_map: str = "", chat_template: Optional[str] = None, **kwargs, ) -> None: @@ -59,7 +59,7 @@ def __init__( self._tokenizer = self._image_processor.tokenizer self._config = self._model.config self.batch_size_per_gpu = int(batch_size) - self.chat_template = chat_template + self.chat_template = chat_template if accelerator.num_processes > 1 and device_map == "": assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." 
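Since the Vicuna fallback above is a plain Jinja string, its effect on a prompt can be checked without loading any model weights. A small sketch that renders it with jinja2 directly (the message content is made up; tokenizer.apply_chat_template compiles the same template internally):

# Render the Vicuna fallback template to see the prompt format it produces.
from jinja2 import Template

VICUNA_CHAT_TEMPLATE = (
    "{% for message in messages %}{% if loop.index0 == 0 %}"
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    "USER: {{ message['content'] }} "
    "{% elif message['role'] == 'user' %}USER: {{ message['content'] }} "
    "{% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}"
    "{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
)

messages = [{"role": "user", "content": "<image>\nWhat is shown in this image?"}]
prompt = Template(VICUNA_CHAT_TEMPLATE).render(
    messages=messages, add_generation_prompt=True, eos_token="</s>"
)
print(prompt)
# Output (wrapped here for readability):
# A chat between a curious user and an artificial intelligence assistant. The assistant
# gives helpful, detailed, and polite answers to the user's questions. USER: <image>
# What is shown in this image? ASSISTANT: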
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model @@ -158,11 +158,10 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] visuals = self.flatten(visuals) - image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals) image_tokens = " ".join(image_tokens) context = f"{image_tokens}\n{context}" - # Apply chat template + # Apply chat template messages = [{"role": "user", "content": context}, {"role": "assistant", "content": continuation}] if self.chat_template is not None: self.tokenizer.chat_template = self.chat_template @@ -253,7 +252,7 @@ def _collate(x): # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt. if DEFAULT_IMAGE_TOKEN not in context: context = f"{DEFAULT_IMAGE_TOKEN}\n{context}" - # Apply chat template + # Apply chat template messages = [{"role": "user", "content": context}] if self.chat_template is not None: self.tokenizer.chat_template = self.chat_template @@ -268,7 +267,7 @@ def _collate(x): if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") - + inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self._model.dtype) gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] From 80a1e7d3270db49b8c378e3e2bbc430bdf6380f3 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 9 Apr 2024 12:46:15 +0000 Subject: [PATCH 13/25] Remove loggine --- lmms_eval/models/llava_hf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index e5d580229..59968d8e9 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -259,11 +259,9 @@ def _collate(x): text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) elif self.tokenizer.chat_template is not None: text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - eval_logger.warning("Using the tokenizer's chat template to format the prompt.") else: self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - eval_logger.warning("No chat template provided or set in the tokenizer. 
Using the default Vicuna chat template.") if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") From 7c7b9699af057ca0d750e7cb1bc8e3731c22e852 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 9 Apr 2024 13:18:40 +0000 Subject: [PATCH 14/25] Fix llava loglikelihood --- lmms_eval/models/llava.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index b3cb8a66d..2bd6f4414 100644 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -198,7 +198,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: else: image = None - prompts_input = contexts[0] + prompts_input = contexts[0] if isinstance(contexts, list) else contexts if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input: """ @@ -209,7 +209,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: """ image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals) image_tokens = " ".join(image_tokens) - prompts_input = image_tokens + "\n" + contexts[0] + prompts_input = image_tokens + "\n" + (contexts[0] if isinstance(contexts, list) else contexts) conv = conv_templates[self.conv_template].copy() conv.append_message(conv.roles[0], prompts_input) From 3f0cad909dd7f4ccfe05beaa17782e200e2c5bc6 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 9 Apr 2024 14:02:53 +0000 Subject: [PATCH 15/25] Tidy up model calls --- lmms_eval/models/llava_hf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 59968d8e9..cdcfa83d7 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -37,6 +37,7 @@ def __init__( attn_implementation: Optional[str] = "flash_attention_2", device_map: str = "", chat_template: Optional[str] = None, + use_cache: bool = True, **kwargs, ) -> None: super().__init__() @@ -60,6 +61,7 @@ def __init__( self._config = self._model.config self.batch_size_per_gpu = int(batch_size) self.chat_template = chat_template + self.use_cache = use_cache if accelerator.num_processes > 1 and device_map == "": assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." 
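The use_cache flag introduced here is threaded straight into model.generate in the hunk further down. Outside the harness, the same generation path looks roughly like the following sketch, assuming the llava-hf/llava-1.5-7b-hf checkpoint and a CUDA device are available; the image URL and prompt are illustrative only:

# Rough standalone sketch of the generation path; values are illustrative.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

device = "cuda:0"
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", torch_dtype=torch.float16
).to(device)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor.tokenizer.padding_side = "left"  # left padding for batched generation

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
prompt = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    "USER: <image>\nWhat is shown in this image? ASSISTANT:"
)

inputs = processor(images=[image], text=prompt, return_tensors="pt").to(device, model.dtype)
with torch.inference_mode():
    output_ids = model.generate(
        **inputs,
        do_sample=False,      # greedy decoding when temperature is 0
        num_beams=1,
        max_new_tokens=128,
        use_cache=True,       # the flag added in this patch
    )
# Drop the prompt tokens so only the newly generated text is decoded.
new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
print(processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0])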
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model @@ -177,7 +179,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: formatted_contexts = [prompt] formatted_continuation = [prompt_and_continuation] - model_inputs = self._image_processor(text=formatted_continuation, images=visuals).to(self._device, self._model.dtype) + model_inputs = self._image_processor(text=formatted_continuation, images=visuals).to(self._device, self.model.dtype) labels = model_inputs["input_ids"].clone() contxt_id = self._image_processor(text=formatted_contexts, return_tensors="pt")["input_ids"] labels[: len(contxt_id)] = -100 @@ -266,7 +268,7 @@ def _collate(x): if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") - inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self._model.dtype) + inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype) gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] if "max_new_tokens" not in gen_kwargs: @@ -278,13 +280,14 @@ def _collate(x): if "num_beams" not in gen_kwargs: gen_kwargs["num_beams"] = 1 try: - cont = self._model.generate( + cont = self.model.generate( **inputs, do_sample=True if gen_kwargs["temperature"] > 0 else False, temperature=gen_kwargs["temperature"], top_p=gen_kwargs["top_p"], num_beams=gen_kwargs["num_beams"], max_new_tokens=gen_kwargs["max_new_tokens"], + use_cache=self.use_cache, ) except Exception as e: eval_logger.error(f"Error {e} in generating") From 6c2720320f58d2e706ab03bae97a65747fd264f3 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 9 Apr 2024 14:35:46 +0000 Subject: [PATCH 16/25] Add cmd --- lmms_eval/models/llava_hf.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index cdcfa83d7..2b6d88210 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -24,6 +24,16 @@ class LlavaHf(lmms): Llava Model for Hugging Face Transformers: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava Adapted from the InstructBLIP model in lmms_eval/models/instructblip.py + + Example usage: + + accelerate launch --num_processes=8 -m lmms_eval \ + --model llava_hf \ + --model_args pretrained=llava-hf/llava-1.5-7b-hf \ + --tasks mme \ + --batch_size 1 \ + --output_path ./logs/ \ + --log_samples """ def __init__( From c531bf0302d79389851b2830fbd7621b7dfb2ce7 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 9 Apr 2024 15:45:48 +0000 Subject: [PATCH 17/25] Split logging --- lmms_eval/models/llava_hf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 2b6d88210..4ec875567 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -195,6 +195,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: labels[: len(contxt_id)] = -100 if self.accelerator.is_main_process and doc_id % 100 == 0: + eval_logger.info(f"Prompt for doc ID {doc_id}:\n\n{formatted_contexts[0]}\n") eval_logger.info(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n") with torch.inference_mode(): From f351728281165d5ab9fb283c182a4cbf764f0d6f Mon Sep 17 00:00:00 2001 From: Hunter Heidenreich Date: Tue, 9 Apr 2024 21:33:20 +0000 Subject: [PATCH 18/25] Draft of 
localization eval, using RefCOCO as first target --- lmms_eval/api/task.py | 21 +- .../refcoco/_default_template_bbox_rec_yaml | 34 +++ .../tasks/refcoco/refcoco_bbox_rec_test.yaml | 4 + .../tasks/refcoco/refcoco_bbox_rec_testA.yaml | 4 + .../tasks/refcoco/refcoco_bbox_rec_testB.yaml | 4 + .../tasks/refcoco/refcoco_bbox_rec_val.yaml | 4 + lmms_eval/tasks/refcoco/utils_rec.py | 221 ++++++++++++++++++ 7 files changed, 280 insertions(+), 12 deletions(-) create mode 100644 lmms_eval/tasks/refcoco/_default_template_bbox_rec_yaml create mode 100644 lmms_eval/tasks/refcoco/refcoco_bbox_rec_test.yaml create mode 100644 lmms_eval/tasks/refcoco/refcoco_bbox_rec_testA.yaml create mode 100644 lmms_eval/tasks/refcoco/refcoco_bbox_rec_testB.yaml create mode 100644 lmms_eval/tasks/refcoco/refcoco_bbox_rec_val.yaml create mode 100644 lmms_eval/tasks/refcoco/utils_rec.py diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index 3262937c8..57305a531 100644 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -687,12 +687,15 @@ def download(self, dataset_kwargs=None) -> None: download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS, **dataset_kwargs if dataset_kwargs is not None else {}, ) - self.dataset_no_image = datasets.load_dataset( - path=self.DATASET_PATH, - name=self.DATASET_NAME, - download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS, - **dataset_kwargs if dataset_kwargs is not None else {}, - ) + if self.config.process_docs is not None: + for split in self.dataset: + if split in [ + self.config.training_split, self.config.validation_split, self.config.test_split, self.config.fewshot_split + ]: + self.dataset[split] = self.config.process_docs(self.dataset[split]) + + # copy dataset, remove image features + self.dataset_no_image = self.dataset.copy() for doc_name in self.dataset_no_image: remove_cols = [] features = self.dataset_no_image[doc_name].features @@ -725,20 +728,14 @@ def has_test_docs(self) -> bool: def training_docs(self) -> datasets.Dataset: if self.has_training_docs(): - if self.config.process_docs is not None: - return self.config.process_docs(self.dataset[self.config.training_split]) return self.dataset[self.config.training_split] def validation_docs(self) -> datasets.Dataset: if self.has_validation_docs(): - if self.config.process_docs is not None: - return self.config.process_docs(self.dataset[self.config.validation_split]) return self.dataset[self.config.validation_split] def test_docs(self) -> datasets.Dataset: if self.has_test_docs(): - if self.config.process_docs is not None: - return self.config.process_docs(self.dataset[self.config.test_split]) return self.dataset[self.config.test_split] def fewshot_docs(self): diff --git a/lmms_eval/tasks/refcoco/_default_template_bbox_rec_yaml b/lmms_eval/tasks/refcoco/_default_template_bbox_rec_yaml new file mode 100644 index 000000000..3b5ca100b --- /dev/null +++ b/lmms_eval/tasks/refcoco/_default_template_bbox_rec_yaml @@ -0,0 +1,34 @@ +dataset_path: lmms-lab/RefCOCO +output_type: generate_until +process_docs: !function utils_rec.refcoco_bbox_rec_preprocess_dataset +doc_to_visual: !function utils_rec.refcoco_bbox_rec_doc_to_visual +doc_to_text: !function utils_rec.refcoco_bbox_rec_doc_to_text +doc_to_target: "bbox" +generation_kwargs: + until: + - "ASSISTANT:" +process_results: !function utils_rec.refcoco_bbox_rec_process_result +metric_list: + - metric: refcoco_IoU + aggregation : !function utils_rec.refcoco_bbox_rec_iou + higher_is_better : true + - metric: refcoco_ACC@0.1 + aggregation : !function 
utils_rec.refcoco_bbox_rec_acc01 + higher_is_better : true + - metric: refcoco_ACC@0.3 + aggregation : !function utils_rec.refcoco_bbox_rec_acc03 + higher_is_better : true + - metric: refcoco_ACC@0.5 + aggregation : !function utils_rec.refcoco_bbox_rec_acc05 + higher_is_better : true + - metric: refcoco_ACC@0.7 + aggregation : !function utils_rec.refcoco_bbox_rec_acc07 + higher_is_better : true + - metric: refcoco_ACC@0.9 + aggregation : !function utils_rec.refcoco_bbox_rec_acc09 + higher_is_better : true + - metric: refcoco_Center_ACC + aggregation : !function utils_rec.refcoco_bbox_rec_center_acc + higher_is_better : true +metadata: + version: '0.0' \ No newline at end of file diff --git a/lmms_eval/tasks/refcoco/refcoco_bbox_rec_test.yaml b/lmms_eval/tasks/refcoco/refcoco_bbox_rec_test.yaml new file mode 100644 index 000000000..896ed4ac8 --- /dev/null +++ b/lmms_eval/tasks/refcoco/refcoco_bbox_rec_test.yaml @@ -0,0 +1,4 @@ +group: refcoco_bbox_rec +task: refcoco_bbox_rec_test +test_split: test +include: _default_template_bbox_rec_yaml diff --git a/lmms_eval/tasks/refcoco/refcoco_bbox_rec_testA.yaml b/lmms_eval/tasks/refcoco/refcoco_bbox_rec_testA.yaml new file mode 100644 index 000000000..191268a6d --- /dev/null +++ b/lmms_eval/tasks/refcoco/refcoco_bbox_rec_testA.yaml @@ -0,0 +1,4 @@ +group: refcoco_bbox_rec +task: refcoco_bbox_rec_testA +test_split: testA +include: _default_template_bbox_rec_yaml diff --git a/lmms_eval/tasks/refcoco/refcoco_bbox_rec_testB.yaml b/lmms_eval/tasks/refcoco/refcoco_bbox_rec_testB.yaml new file mode 100644 index 000000000..39b290713 --- /dev/null +++ b/lmms_eval/tasks/refcoco/refcoco_bbox_rec_testB.yaml @@ -0,0 +1,4 @@ +group: refcoco_bbox_rec +task: refcoco_bbox_rec_testB +test_split: testB +include: _default_template_bbox_rec_yaml diff --git a/lmms_eval/tasks/refcoco/refcoco_bbox_rec_val.yaml b/lmms_eval/tasks/refcoco/refcoco_bbox_rec_val.yaml new file mode 100644 index 000000000..f5da6c5e2 --- /dev/null +++ b/lmms_eval/tasks/refcoco/refcoco_bbox_rec_val.yaml @@ -0,0 +1,4 @@ +group: refcoco_bbox_rec +task: refcoco_bbox_rec_val +test_split: val +include: _default_template_bbox_rec_yaml diff --git a/lmms_eval/tasks/refcoco/utils_rec.py b/lmms_eval/tasks/refcoco/utils_rec.py new file mode 100644 index 000000000..ec91bf9cb --- /dev/null +++ b/lmms_eval/tasks/refcoco/utils_rec.py @@ -0,0 +1,221 @@ +import re +import logging +from datasets import Dataset + +eval_logger = logging.getLogger("lmms-eval") + +COCO_REC_METRICS = ["IoU", "ACC@0.1", "ACC@0.3", "ACC@0.5", "ACC@0.7", "ACC@0.9", "Center_ACC"] + + +def refcoco_bbox_rec_preprocess_dataset(dataset: Dataset): + # PIL image stored in dataset['image'] + # add `image_width` and `image_height` to the dataset + dataset = dataset.map(lambda x: {"image_width": x["image"].width, "image_height": x["image"].height}) + + # Original bbox format (top x, top y, width, height) + # Convert to (top-left x, top-left y, bottom-right x, bottom-right y) + # Normalize the bounding box coordinates to be between 0 and 1 + # using the image width and height + dataset = dataset.map( + lambda x: {"bbox": [x["bbox"][0] / x["image_width"], + x["bbox"][1] / x["image_height"], + (x["bbox"][0] + x["bbox"][2]) / x["image_width"], + (x["bbox"][1] + x["bbox"][3]) / x["image_height"]]} + ) + + # currently, the dataset has `answer` as a list of strings + # each answer should be its own row + # we will explode the dataset to have one row per answer + # duplicate the other columns + def explode_answers(example): + answers = example.pop('answer') + 
return [{'answer': answer, **example} for answer in answers] + + # Apply the function to each element, collecting the results + exploded_rows = [] + for example in dataset: + exploded_rows.extend(explode_answers(example)) + + # Create a new dataset from the exploded rows + new_dataset = Dataset.from_list(exploded_rows) + print(f"Exploded dataset from {len(dataset)} to {len(new_dataset)} rows") + + return new_dataset + + +def refcoco_bbox_rec_doc_to_visual(doc): + # Image is presented as is + image = doc["image"].convert("RGB") + return [image.convert("RGB")] + + +def refcoco_bbox_rec_doc_to_text(doc): + assert isinstance(doc['answer'], str), "Answer must be a string" + return "Bounding box coordinates are specified in the format (top-left x, top-left y, bottom-right x, bottom-right y). All values are floating point numbers bounded between 0 and 1. Please provide the bounding box coordinate of the region this sentence describes: " + doc['answer'] + + +def parse_float_sequence_within(input_str): + """ + Extract the first sequence of four floating-point numbers within square brackets from a string. + + Args: + input_str (str): A string that may contain a sequence of four floats within square brackets. + + Returns: + list: A list of four floats if the pattern is found, or a list of four zeros if the pattern is not found. + """ + # Define the regex pattern to find the first instance of four floats within square brackets + pattern = r'\[\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\]' + + # Use re.search to find the first match of the pattern in the input string + match = re.search(pattern, input_str) + + # If a match is found, convert the captured groups into a list of floats + if match: + return [float(match.group(i)) for i in range(1, 5)] + + # If the input does not contain the pattern, return the null float sequence + return [0, 0, 0, 0] + + +def refcoco_bbox_rec_process_result(doc, result): + """ + Args: + doc: a instance of the eval dataset + results: [pred] + Returns: + a dictionary with key: metric name, value: metric value + """ + pred = result[0] if len(result) > 0 else "" + pred = parse_float_sequence_within(pred) + ann_id = doc["question_id"] + data_dict = {"answer": doc["answer"], "pred": pred, "ann_id": ann_id, 'bbox': doc['bbox']} + return {f"refcoco_{metric}": data_dict for metric in COCO_REC_METRICS} + + +def compute_iou(box1, box2): + """ + Compute the Intersection over Union (IoU) of two bounding boxes. + + Parameters: + - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + + Returns: + - float: IoU of box1 and box2. + """ + # Determine the coordinates of the intersection rectangle + x_left = max(box1[0], box2[0]) + y_top = max(box1[1], box2[1]) + x_right = min(box1[2], box2[2]) + y_bottom = min(box1[3], box2[3]) + + # Compute the area of intersection + intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top) + + # Compute the area of both bounding boxes + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + + # Compute the area of the union + union_area = box1_area + box2_area - intersection_area + + # Compute the Intersection over Union + iou = intersection_area / union_area + + return iou + + +def compute_accuracy(box1, box2, threshold=0.5): + """ + Compute the accuracy of two bounding boxes based on a specified threshold. 
+ + Parameters: + - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - threshold (float): Threshold for the IoU to consider the prediction correct. + + Returns: + - float: Accuracy of the prediction based on the IoU threshold. + """ + iou = compute_iou(box1, box2) + return iou >= threshold + + +def compute_center_accuracy(box1, box2): + """ + Compute if the center point of box 2 is within box 1. + + Parameters: + - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + + Returns: + - bool: True if the center point of box 2 is within box 1, False otherwise. + """ + # Compute the center point of box 2 + center_x = (box2[0] + box2[2]) / 2 + center_y = (box2[1] + box2[3]) / 2 + + # Check if the center point is within box 1 + return box1[0] <= center_x <= box1[2] and box1[1] <= center_y <= box1[3] + + +def refcoco_bbox_rec_aggregation_result(results, metric): + """ + Aggregate the results of the RefCOCO evaluation task using the specified metric. + + Args: + - results (list of dict): List of result dictionaries. + - metric (str): Metric to use for aggregation. + + Returns: + - dict: Dictionary containing the aggregated results for the specified metric. + """ + scorers = { + 'IoU': compute_iou, + 'ACC@0.1': lambda x, y: compute_accuracy(x, y, 0.1), + 'ACC@0.3': lambda x, y: compute_accuracy(x, y, 0.3), + 'ACC@0.5': lambda x, y: compute_accuracy(x, y, 0.5), + 'ACC@0.7': lambda x, y: compute_accuracy(x, y, 0.7), + 'ACC@0.9': lambda x, y: compute_accuracy(x, y, 0.9), + 'Center_ACC': compute_center_accuracy + } + results_dict = {metric: []} + for result in results: + # Extract the ground truth and predicted bounding boxes + gt_bbox = result['bbox'] + pred_bbox = result['pred'] + # Compute the specified metric between the ground truth and predicted bounding boxes + score = scorers[metric](gt_bbox, pred_bbox) + results_dict[metric].append(score) + results_dict[metric] = sum(results_dict[metric]) / len(results_dict[metric]) + print(f"Aggregated {metric} score: {results_dict[metric]}") + return results_dict[metric] + + +def refcoco_bbox_rec_iou(results): + return refcoco_bbox_rec_aggregation_result(results, "IoU") + + +def refcoco_bbox_rec_acc01(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.1") + +def refcoco_bbox_rec_acc03(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.3") + + +def refcoco_bbox_rec_acc05(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.5") + + +def refcoco_bbox_rec_acc07(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.7") + + +def refcoco_bbox_rec_acc09(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.9") + + +def refcoco_bbox_rec_center_acc(results): + return refcoco_bbox_rec_aggregation_result(results, "Center_ACC") From 1d0ddc80b396f4792b3745aee509031ced1484b0 Mon Sep 17 00:00:00 2001 From: Fanyi Pu Date: Wed, 10 Apr 2024 15:11:13 +0800 Subject: [PATCH 19/25] Fix percentage calculation in make_table function --- lmms_eval/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lmms_eval/utils.py b/lmms_eval/utils.py index 649241a5c..d6c085536 100644 --- a/lmms_eval/utils.py +++ b/lmms_eval/utils.py @@ -411,6 +411,8 @@ def make_table(result_dict, column: str = "results"): points = "N/A" if v is not None: + if 0 <= v <= 1: + v *= 100 points = "%.4f" % v if m + "_stderr" + "," + f in dic: From 
6d08fe8713adefd29e1d248053d4927703ba81d7 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 10 Apr 2024 13:40:28 +0000 Subject: [PATCH 20/25] Refactor --- lmms_eval/constants.py | 4 ---- lmms_eval/models/llava_hf.py | 6 +++++- 2 files changed, 5 insertions(+), 5 deletions(-) delete mode 100644 lmms_eval/constants.py diff --git a/lmms_eval/constants.py b/lmms_eval/constants.py deleted file mode 100644 index b7b0283e8..000000000 --- a/lmms_eval/constants.py +++ /dev/null @@ -1,4 +0,0 @@ -DEFAULT_IMAGE_TOKEN = "" - -# Default chat templates -VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}" diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 4ec875567..e5a88da60 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -5,7 +5,6 @@ from lmms_eval.api.instance import Instance from lmms_eval.api.model import lmms from lmms_eval.api.registry import register_model -from lmms_eval.constants import DEFAULT_IMAGE_TOKEN, VICUNA_CHAT_TEMPLATE from accelerate import Accelerator, DistributedType from accelerate.state import AcceleratorState from typing import List, Optional, Union, Tuple @@ -17,6 +16,11 @@ eval_logger = logging.getLogger("lmms-eval") +DEFAULT_IMAGE_TOKEN = "" + +# Default chat for llava-hf/llava-1.5 models: https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0 +VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}" + @register_model("llava_hf") class LlavaHf(lmms): From 65123db131917b16de51ec7865d31fb2dd1b2478 Mon Sep 17 00:00:00 2001 From: Hunter Heidenreich Date: Wed, 10 Apr 2024 14:08:52 +0000 Subject: [PATCH 21/25] Other RefCOCO instances for grounding/REC --- .../refcoco+/_default_template_bbox_rec_yaml | 34 +++ .../refcoco+/refcoco+_bbox_rec_testA.yaml | 4 + .../refcoco+/refcoco+_bbox_rec_testB.yaml | 4 + .../tasks/refcoco+/refcoco+_bbox_rec_val.yaml | 4 + lmms_eval/tasks/refcoco+/utils_rec.py | 221 ++++++++++++++++++ .../refcocog/_default_template_bbox_rec_yaml | 34 +++ .../refcocog/refcocog_bbox_rec_test.yaml | 4 + .../tasks/refcocog/refcocog_bbox_rec_val.yaml | 4 + lmms_eval/tasks/refcocog/utils_rec.py | 221 ++++++++++++++++++ 9 files changed, 530 insertions(+) create mode 100644 lmms_eval/tasks/refcoco+/_default_template_bbox_rec_yaml create mode 100644 lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testA.yaml create mode 100644 lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testB.yaml create mode 100644 lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_val.yaml create mode 100644 lmms_eval/tasks/refcoco+/utils_rec.py create mode 100644 lmms_eval/tasks/refcocog/_default_template_bbox_rec_yaml create mode 100644 lmms_eval/tasks/refcocog/refcocog_bbox_rec_test.yaml create mode 100644 lmms_eval/tasks/refcocog/refcocog_bbox_rec_val.yaml create mode 100644 lmms_eval/tasks/refcocog/utils_rec.py diff --git a/lmms_eval/tasks/refcoco+/_default_template_bbox_rec_yaml b/lmms_eval/tasks/refcoco+/_default_template_bbox_rec_yaml new file mode 100644 index 000000000..c369baee2 --- /dev/null +++ b/lmms_eval/tasks/refcoco+/_default_template_bbox_rec_yaml @@ -0,0 +1,34 @@ +dataset_path: lmms-lab/RefCOCOPlus +output_type: generate_until +process_docs: !function utils_rec.refcoco_bbox_rec_preprocess_dataset +doc_to_visual: !function utils_rec.refcoco_bbox_rec_doc_to_visual +doc_to_text: !function utils_rec.refcoco_bbox_rec_doc_to_text +doc_to_target: "bbox" +generation_kwargs: + until: + - "ASSISTANT:" +process_results: !function utils_rec.refcoco_bbox_rec_process_result +metric_list: + - metric: refcoco_IoU + aggregation : !function utils_rec.refcoco_bbox_rec_iou + higher_is_better : true + - metric: refcoco_ACC@0.1 + aggregation : !function utils_rec.refcoco_bbox_rec_acc01 + higher_is_better : true + - metric: refcoco_ACC@0.3 + aggregation : !function utils_rec.refcoco_bbox_rec_acc03 + higher_is_better : true + - metric: refcoco_ACC@0.5 + aggregation : !function utils_rec.refcoco_bbox_rec_acc05 + higher_is_better : true + - metric: refcoco_ACC@0.7 + aggregation : !function utils_rec.refcoco_bbox_rec_acc07 + higher_is_better : true + - metric: refcoco_ACC@0.9 + aggregation : !function utils_rec.refcoco_bbox_rec_acc09 + higher_is_better : true + - metric: refcoco_Center_ACC + aggregation : !function utils_rec.refcoco_bbox_rec_center_acc + higher_is_better : true +metadata: + version: '0.0' \ No newline at end of file diff --git a/lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testA.yaml b/lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testA.yaml new file mode 100644 index 000000000..0ebb6c0c6 --- /dev/null +++ b/lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testA.yaml @@ -0,0 +1,4 @@ +group: refcoco+_bbox_rec +task: refcoco+_bbox_rec_testA +include: _default_template_bbox_rec_yaml +test_split: 
testA diff --git a/lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testB.yaml b/lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testB.yaml new file mode 100644 index 000000000..b347bce61 --- /dev/null +++ b/lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_testB.yaml @@ -0,0 +1,4 @@ +group: refcoco+_bbox_rec +task: refcoco+_bbox_rec_testB +include: _default_template_bbox_rec_yaml +test_split: testB diff --git a/lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_val.yaml b/lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_val.yaml new file mode 100644 index 000000000..890f588b0 --- /dev/null +++ b/lmms_eval/tasks/refcoco+/refcoco+_bbox_rec_val.yaml @@ -0,0 +1,4 @@ +group: refcoco+_bbox_rec +task: refcoco+_bbox_rec_val +include: _default_template_bbox_rec_yaml +test_split: val diff --git a/lmms_eval/tasks/refcoco+/utils_rec.py b/lmms_eval/tasks/refcoco+/utils_rec.py new file mode 100644 index 000000000..ec91bf9cb --- /dev/null +++ b/lmms_eval/tasks/refcoco+/utils_rec.py @@ -0,0 +1,221 @@ +import re +import logging +from datasets import Dataset + +eval_logger = logging.getLogger("lmms-eval") + +COCO_REC_METRICS = ["IoU", "ACC@0.1", "ACC@0.3", "ACC@0.5", "ACC@0.7", "ACC@0.9", "Center_ACC"] + + +def refcoco_bbox_rec_preprocess_dataset(dataset: Dataset): + # PIL image stored in dataset['image'] + # add `image_width` and `image_height` to the dataset + dataset = dataset.map(lambda x: {"image_width": x["image"].width, "image_height": x["image"].height}) + + # Original bbox format (top x, top y, width, height) + # Convert to (top-left x, top-left y, bottom-right x, bottom-right y) + # Normalize the bounding box coordinates to be between 0 and 1 + # using the image width and height + dataset = dataset.map( + lambda x: {"bbox": [x["bbox"][0] / x["image_width"], + x["bbox"][1] / x["image_height"], + (x["bbox"][0] + x["bbox"][2]) / x["image_width"], + (x["bbox"][1] + x["bbox"][3]) / x["image_height"]]} + ) + + # currently, the dataset has `answer` as a list of strings + # each answer should be its own row + # we will explode the dataset to have one row per answer + # duplicate the other columns + def explode_answers(example): + answers = example.pop('answer') + return [{'answer': answer, **example} for answer in answers] + + # Apply the function to each element, collecting the results + exploded_rows = [] + for example in dataset: + exploded_rows.extend(explode_answers(example)) + + # Create a new dataset from the exploded rows + new_dataset = Dataset.from_list(exploded_rows) + print(f"Exploded dataset from {len(dataset)} to {len(new_dataset)} rows") + + return new_dataset + + +def refcoco_bbox_rec_doc_to_visual(doc): + # Image is presented as is + image = doc["image"].convert("RGB") + return [image.convert("RGB")] + + +def refcoco_bbox_rec_doc_to_text(doc): + assert isinstance(doc['answer'], str), "Answer must be a string" + return "Bounding box coordinates are specified in the format (top-left x, top-left y, bottom-right x, bottom-right y). All values are floating point numbers bounded between 0 and 1. Please provide the bounding box coordinate of the region this sentence describes: " + doc['answer'] + + +def parse_float_sequence_within(input_str): + """ + Extract the first sequence of four floating-point numbers within square brackets from a string. + + Args: + input_str (str): A string that may contain a sequence of four floats within square brackets. + + Returns: + list: A list of four floats if the pattern is found, or a list of four zeros if the pattern is not found. 
+ """ + # Define the regex pattern to find the first instance of four floats within square brackets + pattern = r'\[\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\]' + + # Use re.search to find the first match of the pattern in the input string + match = re.search(pattern, input_str) + + # If a match is found, convert the captured groups into a list of floats + if match: + return [float(match.group(i)) for i in range(1, 5)] + + # If the input does not contain the pattern, return the null float sequence + return [0, 0, 0, 0] + + +def refcoco_bbox_rec_process_result(doc, result): + """ + Args: + doc: a instance of the eval dataset + results: [pred] + Returns: + a dictionary with key: metric name, value: metric value + """ + pred = result[0] if len(result) > 0 else "" + pred = parse_float_sequence_within(pred) + ann_id = doc["question_id"] + data_dict = {"answer": doc["answer"], "pred": pred, "ann_id": ann_id, 'bbox': doc['bbox']} + return {f"refcoco_{metric}": data_dict for metric in COCO_REC_METRICS} + + +def compute_iou(box1, box2): + """ + Compute the Intersection over Union (IoU) of two bounding boxes. + + Parameters: + - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + + Returns: + - float: IoU of box1 and box2. + """ + # Determine the coordinates of the intersection rectangle + x_left = max(box1[0], box2[0]) + y_top = max(box1[1], box2[1]) + x_right = min(box1[2], box2[2]) + y_bottom = min(box1[3], box2[3]) + + # Compute the area of intersection + intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top) + + # Compute the area of both bounding boxes + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + + # Compute the area of the union + union_area = box1_area + box2_area - intersection_area + + # Compute the Intersection over Union + iou = intersection_area / union_area + + return iou + + +def compute_accuracy(box1, box2, threshold=0.5): + """ + Compute the accuracy of two bounding boxes based on a specified threshold. + + Parameters: + - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - threshold (float): Threshold for the IoU to consider the prediction correct. + + Returns: + - float: Accuracy of the prediction based on the IoU threshold. + """ + iou = compute_iou(box1, box2) + return iou >= threshold + + +def compute_center_accuracy(box1, box2): + """ + Compute if the center point of box 2 is within box 1. + + Parameters: + - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + + Returns: + - bool: True if the center point of box 2 is within box 1, False otherwise. + """ + # Compute the center point of box 2 + center_x = (box2[0] + box2[2]) / 2 + center_y = (box2[1] + box2[3]) / 2 + + # Check if the center point is within box 1 + return box1[0] <= center_x <= box1[2] and box1[1] <= center_y <= box1[3] + + +def refcoco_bbox_rec_aggregation_result(results, metric): + """ + Aggregate the results of the RefCOCO evaluation task using the specified metric. + + Args: + - results (list of dict): List of result dictionaries. + - metric (str): Metric to use for aggregation. + + Returns: + - dict: Dictionary containing the aggregated results for the specified metric. 
+ """ + scorers = { + 'IoU': compute_iou, + 'ACC@0.1': lambda x, y: compute_accuracy(x, y, 0.1), + 'ACC@0.3': lambda x, y: compute_accuracy(x, y, 0.3), + 'ACC@0.5': lambda x, y: compute_accuracy(x, y, 0.5), + 'ACC@0.7': lambda x, y: compute_accuracy(x, y, 0.7), + 'ACC@0.9': lambda x, y: compute_accuracy(x, y, 0.9), + 'Center_ACC': compute_center_accuracy + } + results_dict = {metric: []} + for result in results: + # Extract the ground truth and predicted bounding boxes + gt_bbox = result['bbox'] + pred_bbox = result['pred'] + # Compute the specified metric between the ground truth and predicted bounding boxes + score = scorers[metric](gt_bbox, pred_bbox) + results_dict[metric].append(score) + results_dict[metric] = sum(results_dict[metric]) / len(results_dict[metric]) + print(f"Aggregated {metric} score: {results_dict[metric]}") + return results_dict[metric] + + +def refcoco_bbox_rec_iou(results): + return refcoco_bbox_rec_aggregation_result(results, "IoU") + + +def refcoco_bbox_rec_acc01(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.1") + +def refcoco_bbox_rec_acc03(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.3") + + +def refcoco_bbox_rec_acc05(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.5") + + +def refcoco_bbox_rec_acc07(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.7") + + +def refcoco_bbox_rec_acc09(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.9") + + +def refcoco_bbox_rec_center_acc(results): + return refcoco_bbox_rec_aggregation_result(results, "Center_ACC") diff --git a/lmms_eval/tasks/refcocog/_default_template_bbox_rec_yaml b/lmms_eval/tasks/refcocog/_default_template_bbox_rec_yaml new file mode 100644 index 000000000..ce95812b9 --- /dev/null +++ b/lmms_eval/tasks/refcocog/_default_template_bbox_rec_yaml @@ -0,0 +1,34 @@ +dataset_path: lmms-lab/RefCOCOg +output_type: generate_until +process_docs: !function utils_rec.refcoco_bbox_rec_preprocess_dataset +doc_to_visual: !function utils_rec.refcoco_bbox_rec_doc_to_visual +doc_to_text: !function utils_rec.refcoco_bbox_rec_doc_to_text +doc_to_target: "bbox" +generation_kwargs: + until: + - "ASSISTANT:" +process_results: !function utils_rec.refcoco_bbox_rec_process_result +metric_list: + - metric: refcoco_IoU + aggregation : !function utils_rec.refcoco_bbox_rec_iou + higher_is_better : true + - metric: refcoco_ACC@0.1 + aggregation : !function utils_rec.refcoco_bbox_rec_acc01 + higher_is_better : true + - metric: refcoco_ACC@0.3 + aggregation : !function utils_rec.refcoco_bbox_rec_acc03 + higher_is_better : true + - metric: refcoco_ACC@0.5 + aggregation : !function utils_rec.refcoco_bbox_rec_acc05 + higher_is_better : true + - metric: refcoco_ACC@0.7 + aggregation : !function utils_rec.refcoco_bbox_rec_acc07 + higher_is_better : true + - metric: refcoco_ACC@0.9 + aggregation : !function utils_rec.refcoco_bbox_rec_acc09 + higher_is_better : true + - metric: refcoco_Center_ACC + aggregation : !function utils_rec.refcoco_bbox_rec_center_acc + higher_is_better : true +metadata: + version: '0.0' \ No newline at end of file diff --git a/lmms_eval/tasks/refcocog/refcocog_bbox_rec_test.yaml b/lmms_eval/tasks/refcocog/refcocog_bbox_rec_test.yaml new file mode 100644 index 000000000..2f4359798 --- /dev/null +++ b/lmms_eval/tasks/refcocog/refcocog_bbox_rec_test.yaml @@ -0,0 +1,4 @@ +group: refcocog_bbox_rec +task: refcocog_bbox_rec_test +include: _default_template_bbox_rec_yaml +test_split: test diff --git 
a/lmms_eval/tasks/refcocog/refcocog_bbox_rec_val.yaml b/lmms_eval/tasks/refcocog/refcocog_bbox_rec_val.yaml new file mode 100644 index 000000000..5e19397ae --- /dev/null +++ b/lmms_eval/tasks/refcocog/refcocog_bbox_rec_val.yaml @@ -0,0 +1,4 @@ +group: refcocog_bbox_rec +task: refcocog_bbox_rec_val +include: _default_template_bbox_rec_yaml +test_split: val diff --git a/lmms_eval/tasks/refcocog/utils_rec.py b/lmms_eval/tasks/refcocog/utils_rec.py new file mode 100644 index 000000000..ec91bf9cb --- /dev/null +++ b/lmms_eval/tasks/refcocog/utils_rec.py @@ -0,0 +1,221 @@ +import re +import logging +from datasets import Dataset + +eval_logger = logging.getLogger("lmms-eval") + +COCO_REC_METRICS = ["IoU", "ACC@0.1", "ACC@0.3", "ACC@0.5", "ACC@0.7", "ACC@0.9", "Center_ACC"] + + +def refcoco_bbox_rec_preprocess_dataset(dataset: Dataset): + # PIL image stored in dataset['image'] + # add `image_width` and `image_height` to the dataset + dataset = dataset.map(lambda x: {"image_width": x["image"].width, "image_height": x["image"].height}) + + # Original bbox format (top x, top y, width, height) + # Convert to (top-left x, top-left y, bottom-right x, bottom-right y) + # Normalize the bounding box coordinates to be between 0 and 1 + # using the image width and height + dataset = dataset.map( + lambda x: {"bbox": [x["bbox"][0] / x["image_width"], + x["bbox"][1] / x["image_height"], + (x["bbox"][0] + x["bbox"][2]) / x["image_width"], + (x["bbox"][1] + x["bbox"][3]) / x["image_height"]]} + ) + + # currently, the dataset has `answer` as a list of strings + # each answer should be its own row + # we will explode the dataset to have one row per answer + # duplicate the other columns + def explode_answers(example): + answers = example.pop('answer') + return [{'answer': answer, **example} for answer in answers] + + # Apply the function to each element, collecting the results + exploded_rows = [] + for example in dataset: + exploded_rows.extend(explode_answers(example)) + + # Create a new dataset from the exploded rows + new_dataset = Dataset.from_list(exploded_rows) + print(f"Exploded dataset from {len(dataset)} to {len(new_dataset)} rows") + + return new_dataset + + +def refcoco_bbox_rec_doc_to_visual(doc): + # Image is presented as is + image = doc["image"].convert("RGB") + return [image.convert("RGB")] + + +def refcoco_bbox_rec_doc_to_text(doc): + assert isinstance(doc['answer'], str), "Answer must be a string" + return "Bounding box coordinates are specified in the format (top-left x, top-left y, bottom-right x, bottom-right y). All values are floating point numbers bounded between 0 and 1. Please provide the bounding box coordinate of the region this sentence describes: " + doc['answer'] + + +def parse_float_sequence_within(input_str): + """ + Extract the first sequence of four floating-point numbers within square brackets from a string. + + Args: + input_str (str): A string that may contain a sequence of four floats within square brackets. + + Returns: + list: A list of four floats if the pattern is found, or a list of four zeros if the pattern is not found. 
+ """ + # Define the regex pattern to find the first instance of four floats within square brackets + pattern = r'\[\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\]' + + # Use re.search to find the first match of the pattern in the input string + match = re.search(pattern, input_str) + + # If a match is found, convert the captured groups into a list of floats + if match: + return [float(match.group(i)) for i in range(1, 5)] + + # If the input does not contain the pattern, return the null float sequence + return [0, 0, 0, 0] + + +def refcoco_bbox_rec_process_result(doc, result): + """ + Args: + doc: a instance of the eval dataset + results: [pred] + Returns: + a dictionary with key: metric name, value: metric value + """ + pred = result[0] if len(result) > 0 else "" + pred = parse_float_sequence_within(pred) + ann_id = doc["question_id"] + data_dict = {"answer": doc["answer"], "pred": pred, "ann_id": ann_id, 'bbox': doc['bbox']} + return {f"refcoco_{metric}": data_dict for metric in COCO_REC_METRICS} + + +def compute_iou(box1, box2): + """ + Compute the Intersection over Union (IoU) of two bounding boxes. + + Parameters: + - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + + Returns: + - float: IoU of box1 and box2. + """ + # Determine the coordinates of the intersection rectangle + x_left = max(box1[0], box2[0]) + y_top = max(box1[1], box2[1]) + x_right = min(box1[2], box2[2]) + y_bottom = min(box1[3], box2[3]) + + # Compute the area of intersection + intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top) + + # Compute the area of both bounding boxes + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + + # Compute the area of the union + union_area = box1_area + box2_area - intersection_area + + # Compute the Intersection over Union + iou = intersection_area / union_area + + return iou + + +def compute_accuracy(box1, box2, threshold=0.5): + """ + Compute the accuracy of two bounding boxes based on a specified threshold. + + Parameters: + - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - threshold (float): Threshold for the IoU to consider the prediction correct. + + Returns: + - float: Accuracy of the prediction based on the IoU threshold. + """ + iou = compute_iou(box1, box2) + return iou >= threshold + + +def compute_center_accuracy(box1, box2): + """ + Compute if the center point of box 2 is within box 1. + + Parameters: + - box1 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + - box2 (list of float): Bounding box [x_min, y_min, x_max, y_max]. + + Returns: + - bool: True if the center point of box 2 is within box 1, False otherwise. + """ + # Compute the center point of box 2 + center_x = (box2[0] + box2[2]) / 2 + center_y = (box2[1] + box2[3]) / 2 + + # Check if the center point is within box 1 + return box1[0] <= center_x <= box1[2] and box1[1] <= center_y <= box1[3] + + +def refcoco_bbox_rec_aggregation_result(results, metric): + """ + Aggregate the results of the RefCOCO evaluation task using the specified metric. + + Args: + - results (list of dict): List of result dictionaries. + - metric (str): Metric to use for aggregation. + + Returns: + - dict: Dictionary containing the aggregated results for the specified metric. 
+ """ + scorers = { + 'IoU': compute_iou, + 'ACC@0.1': lambda x, y: compute_accuracy(x, y, 0.1), + 'ACC@0.3': lambda x, y: compute_accuracy(x, y, 0.3), + 'ACC@0.5': lambda x, y: compute_accuracy(x, y, 0.5), + 'ACC@0.7': lambda x, y: compute_accuracy(x, y, 0.7), + 'ACC@0.9': lambda x, y: compute_accuracy(x, y, 0.9), + 'Center_ACC': compute_center_accuracy + } + results_dict = {metric: []} + for result in results: + # Extract the ground truth and predicted bounding boxes + gt_bbox = result['bbox'] + pred_bbox = result['pred'] + # Compute the specified metric between the ground truth and predicted bounding boxes + score = scorers[metric](gt_bbox, pred_bbox) + results_dict[metric].append(score) + results_dict[metric] = sum(results_dict[metric]) / len(results_dict[metric]) + print(f"Aggregated {metric} score: {results_dict[metric]}") + return results_dict[metric] + + +def refcoco_bbox_rec_iou(results): + return refcoco_bbox_rec_aggregation_result(results, "IoU") + + +def refcoco_bbox_rec_acc01(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.1") + +def refcoco_bbox_rec_acc03(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.3") + + +def refcoco_bbox_rec_acc05(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.5") + + +def refcoco_bbox_rec_acc07(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.7") + + +def refcoco_bbox_rec_acc09(results): + return refcoco_bbox_rec_aggregation_result(results, "ACC@0.9") + + +def refcoco_bbox_rec_center_acc(results): + return refcoco_bbox_rec_aggregation_result(results, "Center_ACC") From d615bd75b15444860ae54861d83f240af032ee83 Mon Sep 17 00:00:00 2001 From: Hunter Heidenreich Date: Wed, 10 Apr 2024 14:13:24 +0000 Subject: [PATCH 22/25] Update README --- README.md | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 9c9bafe44..fbf069368 100644 --- a/README.md +++ b/README.md @@ -201,14 +201,21 @@ We also provide the raw data exported from Weights & Biases for the detailed res - OKVQA Validation 2014 (ok_vqa_val2014) - POPE (pope) - RefCOCO (refcoco) - - refcoco_seg_test - - refcoco_seg_val - - refcoco_seg_testA - - refcoco_seg_testB - - refcoco_bbox_test - - refcoco_bbox_val - - refcoco_bbox_testA - - refcoco_bbox_testB + - refcoco_seg + - refcoco_seg_test + - refcoco_seg_val + - refcoco_seg_testA + - refcoco_seg_testB + - refcoco_bbox + - refcoco_bbox_test + - refcoco_bbox_val + - refcoco_bbox_testA + - refcoco_bbox_testB + - refcoco_bbox_rec + - refcoco_bbox_rec_test + - refcoco_bbox_rec_val + - refcoco_bbox_rec_testA + - refcoco_bbox_rec_testB - RefCOCO+ (refcoco+) - refcoco+_seg - refcoco+_seg_val @@ -218,11 +225,20 @@ We also provide the raw data exported from Weights & Biases for the detailed res - refcoco+_bbox_val - refcoco+_bbox_testA - refcoco+_bbox_testB + - refcoco+_bbox_rec + - refcoco+_bbox_rec_val + - refcoco+_bbox_rec_testA + - refcoco+_bbox_rec_testB - RefCOCOg (refcocog) - - refcocog_seg_test - - refcocog_seg_val - - refcocog_bbox_test - - refcocog_bbox_val + - refcocog_seg + - refcocog_seg_test + - refcocog_seg_val + - refcocog_bbox + - refcocog_bbox_test + - refcocog_bbox_val + - refcocog_bbox_rec + - refcocog_bbox_rec_test + - refcocog_bbox_rec_val - ScienceQA (scienceqa_full) - ScienceQA Full (scienceqa) - ScienceQA IMG (scienceqa_img) From f67b26b9e7b549073bfc18103d66c2c03e5da255 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 12 Apr 2024 09:22:35 +0000 Subject: 
[PATCH 23/25] Fix typing for nullables in llava-hf --- lmms_eval/models/llava_hf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index e5a88da60..89c38a9bc 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -46,9 +46,9 @@ def __init__( revision: str = "main", device: str = "cuda", dtype: Optional[Union[str, torch.dtype]] = "auto", - batch_size: Union[int, str] = 1, + batch_size: int = 1, trust_remote_code: Optional[bool] = False, - attn_implementation: Optional[str] = "flash_attention_2", + attn_implementation: Optional[str] = None, device_map: str = "", chat_template: Optional[str] = None, use_cache: bool = True, @@ -280,7 +280,7 @@ def _collate(x): self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - if self.accelerator.is_main_process and doc_id[0] % 100 == 0: + if self.accelerator.is_main_process and doc_id[0] % 1 == 0: eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype) From 639f90c13d1376e4d5023a4b7644a87605b5bf85 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 12 Apr 2024 09:34:38 +0000 Subject: [PATCH 24/25] Revert --- lmms_eval/models/llava_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 89c38a9bc..effb94fb4 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -280,7 +280,7 @@ def _collate(x): self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - if self.accelerator.is_main_process and doc_id[0] % 1 == 0: + if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype) From 49a23db924e6f1c68b3e5d80d0c1783e54b2e4a1 Mon Sep 17 00:00:00 2001 From: Fanyi Pu Date: Tue, 16 Apr 2024 00:36:25 +0800 Subject: [PATCH 25/25] fix a bug for single gpu --- lmms_eval/models/llava.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index 2bd6f4414..a735f82c6 100644 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -53,7 +53,7 @@ def __init__( trust_remote_code: Optional[bool] = False, revision=None, use_flash_attention_2=True, - device_map="", + device_map="auto", conv_template="vicuna_v1", use_cache=True, truncate_context=False, # whether to truncate the context in generation, set it False for LLaVA-1.6
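The REC metrics introduced for RefCOCO, RefCOCO+ and RefCOCOg above all reduce to simple arithmetic on normalized [x_min, y_min, x_max, y_max] boxes. As a quick sanity check of the helpers defined in utils_rec.py, a short sketch with made-up box values (it assumes this branch of lmms-eval is installed so the new module is importable):

# Sanity-check the REC helpers with made-up boxes.
from lmms_eval.tasks.refcoco.utils_rec import (
    compute_center_accuracy,
    compute_iou,
    parse_float_sequence_within,
)

# A model response containing a bracketed box, plus a made-up ground-truth box,
# both in the normalized [x_min, y_min, x_max, y_max] format the prompt asks for.
prediction = "The region is located at [0.10, 0.20, 0.50, 0.60]."
pred_box = parse_float_sequence_within(prediction)  # -> [0.1, 0.2, 0.5, 0.6]
gt_box = [0.15, 0.25, 0.55, 0.65]

print(compute_iou(gt_box, pred_box))              # ~0.62: passes ACC@0.5, fails ACC@0.7
print(compute_center_accuracy(gt_box, pred_box))  # True: pred center lies inside the GT box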