From 3f878eac43e459f608fc6107f641f896766655ff Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Thu, 4 Apr 2024 07:49:37 +0000
Subject: [PATCH 01/17] Remove flash attention

---
 lmms_eval/models/llava.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index b3cb8a66d..3edd8d76b 100644
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -52,7 +52,7 @@ def __init__(
         batch_size: Optional[Union[int, str]] = 1,
         trust_remote_code: Optional[bool] = False,
         revision=None,
-        use_flash_attention_2=True,
+        use_flash_attention_2=False, # True,
         device_map="",
         conv_template="vicuna_v1",
         use_cache=True,

From c8043455b655a2a169cc0e1f80beaa7758aec38b Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Thu, 4 Apr 2024 13:47:10 +0000
Subject: [PATCH 02/17] Fix LlavaHf integration - need fp16 in accelerate config!

---
 lmms_eval/models/__init__.py |   1 +
 lmms_eval/models/llava_hf.py | 215 +++++++++++++++++++++++++++++++++++
 2 files changed, 216 insertions(+)
 create mode 100644 lmms_eval/models/llava_hf.py

diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py
index d39ce73cf..9135311cc 100644
--- a/lmms_eval/models/__init__.py
+++ b/lmms_eval/models/__init__.py
@@ -2,6 +2,7 @@
 AVAILABLE_MODELS = {
     "llava": "Llava",
+    "llava_hf": "LlavaHf",
     "qwen_vl": "Qwen_VL",
     "fuyu": "Fuyu",
     "gpt4v": "GPT4V",

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
new file mode 100644
index 000000000..1709c688c
--- /dev/null
+++ b/lmms_eval/models/llava_hf.py
@@ -0,0 +1,215 @@
+import torch
+import logging
+from tqdm import tqdm
+from lmms_eval import utils
+from lmms_eval.api.instance import Instance
+from lmms_eval.api.model import lmms
+from lmms_eval.api.registry import register_model
+from accelerate import Accelerator, DistributedType
+from accelerate.state import AcceleratorState
+from typing import List, Optional, Union, Tuple
+from transformers import LlavaForConditionalGeneration, AutoProcessor
+
+import warnings
+
+warnings.filterwarnings("ignore")
+
+eval_logger = logging.getLogger("lmms-eval")
+
+
+@register_model("llava_hf")
+class LlavaHf(lmms):
+    """
+    Llava HF Model
+    """
+
+    def __init__(
+        self,
+        pretrained: str = "llava-hf/llava-1.5-7b-hf",
+        device: Optional[str] = "cuda",
+        dtype: Optional[Union[str, torch.dtype]] = torch.float16,
+        batch_size: Optional[Union[int, str]] = 1,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        # Do not use kwargs for now
+        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
+
+        accelerator = Accelerator()
+        if accelerator.num_processes > 1:
+            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+        else:
+            self._device = device
+        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, torch_dtype=dtype)
+        self._image_processor = AutoProcessor.from_pretrained(pretrained)
+        self._tokenizer = self._image_processor.tokenizer
+        self._config = self._model.config
+        self._model.eval()
+        self._model.tie_weights()
+        self.batch_size_per_gpu = int(batch_size)
+        if accelerator.num_processes > 1:
+            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+            # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
+            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
+            # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
+            if accelerator.distributed_type == DistributedType.DEEPSPEED:
+                kwargs = {
+                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
+                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
+                }
+                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
+                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
+            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
+                self._model = accelerator.prepare(self.model)
+            else:
+                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
+            self.accelerator = accelerator
+            if self.accelerator.is_local_main_process:
+                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
+            self._rank = self.accelerator.local_process_index
+            self._world_size = self.accelerator.num_processes
+        else:
+            self.model.to(self._device)
+            self._rank = 0
+            self._word_size = 1
+
+    @property
+    def config(self):
+        # return the associated transformers.AutoConfig for the given pretrained model.
+        return self._config
+
+    @property
+    def tokenizer(self):
+        return self._tokenizer
+
+    @property
+    def model(self):
+        # returns the model, unwrapping it if using Accelerate
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self._model)
+        else:
+            return self._model
+
+    @property
+    def eot_token_id(self):
+        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_length(self):
+        return self._max_length
+
+    @property
+    def batch_size(self):
+        return self.batch_size_per_gpu
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def rank(self):
+        return self._rank
+
+    @property
+    def world_size(self):
+        return self._world_size
+
+    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
+        """ """
+        add_special_tokens = False if add_special_tokens is None else add_special_tokens
+        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
+        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
+        if left_truncate_len:
+            encoding = encoding[-left_truncate_len:]
+        return encoding
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        # TODO
+        assert False, "We have not implemented this function for LlavaHf yet"
+
+    def flatten(self, input):
+        new_list = []
+        for i in input:
+            for j in i:
+                new_list.append(j)
+        return new_list
+
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        res = []
+
+        def _collate(x):
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be over not underestimates, which is more useful for planning
+            # - to know the size of a batch when going through the list, you know the first one is always the batch
+            #   padded context length. this is useful to simplify the batching logic and more importantly to make
+            #   automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+            toks = self.tok_encode(x[0])
+            return -len(toks), x[0]
+
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
+        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
+        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
+        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
+        for chunk in chunks:
+            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
+            task = task[0]
+            split = split[0]
+            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
+            visuals = self.flatten(visuals)
+            # we assume all gen kwargs in the batch are the same
+            # this is safe to assume because the `grouper` object ensures it.
+            gen_kwargs = all_gen_kwargs[0]
+
+            # Set default values for until and max_new_tokens
+            until = [self.tok_decode(self.eot_token_id)]
+
+            # Update values from gen_kwargs if present
+            if "until" in gen_kwargs:
+                until = gen_kwargs.pop("until")
+                if isinstance(until, str):
+                    until = [until]
+                elif not isinstance(until, list):
+                    raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
+            assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"
+            context = contexts[0]
+            context = f"USER: <image>\n{context}\nASSISTANT:"
+            inputs = self._image_processor(images=visuals, text=context, return_tensors="pt").to(self._device, torch.float16)
+
+            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
+            if "max_new_tokens" not in gen_kwargs:
+                gen_kwargs["max_new_tokens"] = 1024
+            if "temperature" not in gen_kwargs:
+                gen_kwargs["temperature"] = 0
+            if "top_p" not in gen_kwargs:
+                gen_kwargs["top_p"] = None
+            if "num_beams" not in gen_kwargs:
+                gen_kwargs["num_beams"] = 1
+            try:
+                cont = self._model.generate(
+                    **inputs,
+                    do_sample=True if gen_kwargs["temperature"] > 0 else False,
+                    temperature=gen_kwargs["temperature"],
+                    top_p=gen_kwargs["top_p"],
+                    num_beams=gen_kwargs["num_beams"],
+                    max_new_tokens=gen_kwargs["max_new_tokens"],
+                )
+            except Exception as e:
+                eval_logger.error(f"Error {e} in generating")
+                cont = ""
+            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip()
+            res.append(text_outputs)
+            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
+            pbar.update(1)
+        # reorder this group of results back to original unsorted form
+        res = re_ords.get_original(res)
+
+        pbar.close()
+        return res

From a6e258b87310cbd42e50305eaada3fbdfee66249 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Thu, 4 Apr 2024 13:56:13 +0000
Subject: [PATCH 03/17] Fix parsing

---
 lmms_eval/models/llava_hf.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index 1709c688c..c0231ec11 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -204,12 +204,14 @@ def _collate(x):
             except Exception as e:
                 eval_logger.error(f"Error {e} in generating")
                 cont = ""
-            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip()
+            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]#.strip()
+            text_outputs = text_outputs.split("ASSISTANT:")[1].strip()
             res.append(text_outputs)
             self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
             pbar.update(1)
         # reorder this group of results back to original unsorted form
         res = re_ords.get_original(res)
+        res = re_ords.get_foriginal(res)
         pbar.close()
         return res

From 86732ebf7a406504ee822fe188329bb1005c7f25 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Fri, 5 Apr 2024 08:34:18 +0000
Subject: [PATCH 04/17] Add chat template

---
 lmms_eval/models/llava_hf.py | 29 +++++++++++++++++++++++------
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index c0231ec11..4f7d6e075 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -26,6 +26,7 @@ class LlavaHf(lmms):
     def __init__(
         self,
         pretrained: str = "llava-hf/llava-1.5-7b-hf",
+        revision=None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = torch.float16,
         batch_size: Optional[Union[int, str]] = 1,
@@ -40,8 +41,8 @@ def __init__(
             self._device = torch.device(f"cuda:{accelerator.local_process_index}")
         else:
             self._device = device
-        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, torch_dtype=dtype)
-        self._image_processor = AutoProcessor.from_pretrained(pretrained)
+        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype)
+        self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision)
         self._tokenizer = self._image_processor.tokenizer
         self._config = self._model.config
         self._model.eval()
@@ -180,8 +181,20 @@ def _collate(x):
                 raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
             assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"
             context = contexts[0]
-            context = f"USER: <image>\n{context}\nASSISTANT:"
-            inputs = self._image_processor(images=visuals, text=context, return_tensors="pt").to(self._device, torch.float16)
+
+            # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
+ if "" not in context: + context = f"\n{context}" + if self.tokenizer.chat_template is not None: + messages = [{"role": "user", "content": context}] + text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + else: + text = f"USER: {context}\nASSISTANT:" + + if self.accelerator.is_main_process and doc_id[0] % 100 == 0: + eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") + + inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, torch.float16) gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] if "max_new_tokens" not in gen_kwargs: @@ -204,14 +217,18 @@ def _collate(x): except Exception as e: eval_logger.error(f"Error {e} in generating") cont = "" - text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]#.strip() + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0] text_outputs = text_outputs.split("ASSISTANT:")[1].strip() + + if self.accelerator.is_main_process and doc_id[0] % 100 == 0: + eval_logger.info(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n") + res.append(text_outputs) self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) pbar.update(1) # reorder this group of results back to original unsorted form res = re_ords.get_original(res) - res = re_ords.get_foriginal(res) + res = re_ords.get_original(res) pbar.close() return res From c4955df8c99a3594983245c3a9823940d090c4b9 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 5 Apr 2024 13:03:17 +0000 Subject: [PATCH 05/17] Revert --- lmms_eval/models/llava.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index 3edd8d76b..b3cb8a66d 100644 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -52,7 +52,7 @@ def __init__( batch_size: Optional[Union[int, str]] = 1, trust_remote_code: Optional[bool] = False, revision=None, - use_flash_attention_2=False, # True, + use_flash_attention_2=True, device_map="", conv_template="vicuna_v1", use_cache=True, From 78fc6175b384797c3728fb8d390bf90e318bfd14 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 8 Apr 2024 13:19:19 +0000 Subject: [PATCH 06/17] Remove duplicate code --- lmms_eval/models/llava_hf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index 4f7d6e075..926168cba 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -228,7 +228,6 @@ def _collate(x): pbar.update(1) # reorder this group of results back to original unsorted form res = re_ords.get_original(res) - res = re_ords.get_original(res) pbar.close() return res From ca20b8cd8d9eab606ee6474dc384e2c41802d27e Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 8 Apr 2024 15:44:46 +0000 Subject: [PATCH 07/17] Add chat templates --- lmms_eval/constants.py | 2 + lmms_eval/models/llava_hf.py | 81 +++++++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 16 deletions(-) create mode 100644 lmms_eval/constants.py diff --git a/lmms_eval/constants.py b/lmms_eval/constants.py new file mode 100644 index 000000000..675c30fbc --- /dev/null +++ b/lmms_eval/constants.py @@ -0,0 +1,2 @@ +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }}\n{% else %}ASSISTANT: {{ message['content'] }}\n{% endif %}{% endfor %}" \ No newline at end of file diff --git 
diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index 926168cba..fd08c1e44 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -5,6 +5,7 @@
 from lmms_eval.api.instance import Instance
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
+from lmms_eval.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_CHAT_TEMPLATE
 from accelerate import Accelerator, DistributedType
 from accelerate.state import AcceleratorState
 from typing import List, Optional, Union, Tuple
@@ -20,7 +21,8 @@
 @register_model("llava_hf")
 class LlavaHf(lmms):
     """
-    Llava HF Model
+    Llava Model for Hugging Face Transformers
+    https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava
     """
 
     def __init__(
@@ -28,8 +30,10 @@ def __init__(
         pretrained: str = "llava-hf/llava-1.5-7b-hf",
         revision=None,
         device: Optional[str] = "cuda",
-        dtype: Optional[Union[str, torch.dtype]] = torch.float16,
+        dtype: Optional[Union[str, torch.dtype]] = "auto",
         batch_size: Optional[Union[int, str]] = 1,
+        device_map="",
+        chat_template: Optional[str] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -37,18 +41,23 @@ def __init__(
         assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
 
         accelerator = Accelerator()
-        if accelerator.num_processes > 1:
+        if accelerator.num_processes > 1 and device_map == "":
             self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+            self.device_map = f"cuda:{accelerator.local_process_index}"
         else:
-            self._device = device
-        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype)
+            self._device = torch.device(device)
+            self.device_map = device_map
+        if isinstance(dtype, str) and dtype != "auto":
+            dtype = getattr(torch, dtype)
+        self.dtype = dtype
+        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=self.dtype, device_map=self.device_map)
         self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision)
+        # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
+        self._image_processor.tokenizer.padding_side = "left"
         self._tokenizer = self._image_processor.tokenizer
         self._config = self._model.config
-        self._model.eval()
-        self._model.tie_weights()
         self.batch_size_per_gpu = int(batch_size)
-        if accelerator.num_processes > 1:
+        if accelerator.num_processes > 1 and device_map == "":
             assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
             # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
             # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
             # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
             if accelerator.distributed_type == DistributedType.DEEPSPEED:
                 kwargs = {
@@ -69,7 +78,12 @@ def __init__(
                 eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes
+        elif accelerator.num_processes == 1 and device_map == "auto":
+            eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism")
+            self._rank = 0
+            self._word_size = 1
         else:
+            eval_logger.info(f"Using single device: {self._device}")
             self.model.to(self._device)
             self._rank = 0
             self._word_size = 1
@@ -129,8 +143,42 @@
     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens)
 
     def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
-        # TODO
-        assert False, "We have not implemented this function for LlavaHf yet"
+        res = []
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
+
+        for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
+            # encode, pad, and truncate contexts for this batch
+            if type(doc_to_target) == str:
+                continuation = doc_to_target
+            else:
+                continuation = doc_to_target(self.task_dict[task][split][doc_id])
+            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
+            visuals = self.flatten(visuals)
+            formatted_contexts = [f"{contexts}\n"]
+            formatted_continuation = [f"{contexts}\n{continuation}"]
+            model_inputs = self.processor(text=formatted_continuation, images=visuals, device=self.device)
+            for k, v in model_inputs.items():
+                model_inputs[k] = v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else [vv.to(self.device, non_blocking=True) for vv in v]
+
+            for index in range(len(model_inputs["image_patches"])):
+                model_inputs["image_patches"][index] = model_inputs["image_patches"][index].to(dtype=next(self.model.parameters()).dtype)
+
+            labels = model_inputs["input_ids"].clone()
+            contxt_id = self.processor(text=formatted_contexts, return_tensors="pt")["input_ids"]
+            labels[: len(contxt_id)] = -100
+            with torch.inference_mode():
+                outputs = self.model(**model_inputs, labels=labels)
+            loss = outputs["loss"]
+            logits = outputs["logits"]
+            greedy_tokens = logits.argmax(dim=-1)
+            cont_toks = model_inputs["input_ids"][:, contxt_id.shape[1] :]  # [1, seq]
+            greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : model_inputs["input_ids"].shape[1]]  # [1, seq]
+            max_equal = (greedy_tokens == cont_toks).all()
+            res.append((float(loss.item()), bool(max_equal)))
+            pbar.update(1)
+
+        pbar.close()
+        return res
 
     def flatten(self, input):
         new_list = []
@@ -183,18 +231,19 @@ def _collate(x):
                 raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
             assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"
             context = contexts[0]
 
             # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
- if "" not in context: - context = f"\n{context}" + if DEFAULT_IMAGE_TOKEN not in context: + context = f"{DEFAULT_IMAGE_TOKEN}\n{context}" + messages = [{"role": "user", "content": context}] if self.tokenizer.chat_template is not None: - messages = [{"role": "user", "content": context}] text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) else: - text = f"USER: {context}\nASSISTANT:" + self.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE + text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n") - inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, torch.float16) + inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.dtype) gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] if "max_new_tokens" not in gen_kwargs: @@ -218,7 +267,7 @@ def _collate(x): eval_logger.error(f"Error {e} in generating") cont = "" text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0] - text_outputs = text_outputs.split("ASSISTANT:")[1].strip() + text_outputs = text_outputs.split("ASSISTANT:")[-1].strip() if self.accelerator.is_main_process and doc_id[0] % 100 == 0: eval_logger.info(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n") From 46e1f684c46f8d6bfa038360ffce913c96e51c42 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 8 Apr 2024 19:41:21 +0000 Subject: [PATCH 08/17] Fix chat templates --- lmms_eval/constants.py | 4 +++- lmms_eval/models/llava_hf.py | 18 ++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/lmms_eval/constants.py b/lmms_eval/constants.py index 675c30fbc..6c79b7049 100644 --- a/lmms_eval/constants.py +++ b/lmms_eval/constants.py @@ -1,2 +1,4 @@ DEFAULT_IMAGE_TOKEN = "" -DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }}\n{% else %}ASSISTANT: {{ message['content'] }}\n{% endif %}{% endfor %}" \ No newline at end of file + +# Default chat templates +VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}" \ No newline at end of file diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py index fd08c1e44..1fab60059 100644 --- a/lmms_eval/models/llava_hf.py +++ b/lmms_eval/models/llava_hf.py @@ -5,7 +5,7 @@ from lmms_eval.api.instance import Instance from lmms_eval.api.model import lmms from lmms_eval.api.registry import register_model -from lmms_eval.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_CHAT_TEMPLATE +from lmms_eval.constants import DEFAULT_IMAGE_TOKEN, VICUNA_CHAT_TEMPLATE from accelerate import Accelerator, DistributedType from accelerate.state import AcceleratorState from typing import List, Optional, Union, Tuple @@ -49,14 +49,14 @@ def __init__( self.device_map = device_map if isinstance(dtype, str) and dtype != "auto": dtype = getattr(torch, dtype) - self.dtype = dtype - self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=self.dtype, device_map=self.device_map) + self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map) self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision) # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips self._image_processor.tokenizer.padding_side = "left" self._tokenizer = self._image_processor.tokenizer self._config = self._model.config self.batch_size_per_gpu = int(batch_size) + self.chat_template = chat_template if accelerator.num_processes > 1 and device_map == "": assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." 
             # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
@@ -234,16 +234,22 @@ def _collate(x):
             if DEFAULT_IMAGE_TOKEN not in context:
                 context = f"{DEFAULT_IMAGE_TOKEN}\n{context}"
             messages = [{"role": "user", "content": context}]
-            if self.tokenizer.chat_template is not None:
+            # Apply chat template if provided
+            if self.chat_template is not None:
+                self.tokenizer.chat_template = self.chat_template
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            # Else default to the tokenizer's template
+            elif self.tokenizer.chat_template is not None:
+                text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            # Finally default to the Vicuna chat template
             else:
-                self.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
+                self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
             if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                 eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
-            inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.dtype)
+            inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self._model.dtype)
 
             gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
             if "max_new_tokens" not in gen_kwargs:

From 2e96d4ca4b7b2f0a3cee475a1f190262626b1878 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Mon, 8 Apr 2024 20:04:52 +0000
Subject: [PATCH 09/17] Tidy

---
 lmms_eval/models/llava_hf.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index 1fab60059..ce7355d91 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -233,15 +231,13 @@ def _collate(x):
             # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
             if DEFAULT_IMAGE_TOKEN not in context:
                 context = f"{DEFAULT_IMAGE_TOKEN}\n{context}"
+            # Apply chat template
             messages = [{"role": "user", "content": context}]
-            # Apply chat template if provided
             if self.chat_template is not None:
                 self.tokenizer.chat_template = self.chat_template
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-            # Else default to the tokenizer's template
             elif self.tokenizer.chat_template is not None:
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-            # Finally default to the Vicuna chat template
             else:
                 self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

From bd00350f185f8fa52325f8726e2b2588f25a9d4b Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 9 Apr 2024 12:36:42 +0000
Subject: [PATCH 10/17] Add log likelihood

---
 lmms_eval/models/llava_hf.py | 58 +++++++++++++++++++++++++-----------
 1 file changed, 40 insertions(+), 18 deletions(-)

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index ce7355d91..284f57b4b 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -21,18 +21,21 @@
 @register_model("llava_hf")
 class LlavaHf(lmms):
     """
-    Llava Model for Hugging Face Transformers
-    https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava
+    Llava Model for Hugging Face Transformers: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava
+
+    Adapted from the InstructBLIP model in lmms_eval/models/instructblip.py
     """
 
     def __init__(
         self,
         pretrained: str = "llava-hf/llava-1.5-7b-hf",
-        revision=None,
-        device: Optional[str] = "cuda",
+        revision: str = "main",
+        device: str = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
-        batch_size: Optional[Union[int, str]] = 1,
-        device_map="",
+        batch_size: Union[int, str] = 1,
+        trust_remote_code: Optional[bool] = False,
+        attn_implementation: Optional[str] = "flash_attention_2",
+        device_map: str ="",
         chat_template: Optional[str] = None,
         **kwargs,
     ) -> None:
@@ -49,8 +52,8 @@ def __init__(
             self.device_map = device_map
         if isinstance(dtype, str) and dtype != "auto":
             dtype = getattr(torch, dtype)
-        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map)
-        self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision)
+        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
+        self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code)
         # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
         self._image_processor.tokenizer.padding_side = "left"
         self._tokenizer = self._image_processor.tokenizer
@@ -146,7 +149,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         res = []
         pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
 
-        for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
+        for context, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
             # encode, pad, and truncate contexts for this batch
             if type(doc_to_target) == str:
                 continuation = doc_to_target
@@ -154,18 +157,35 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
                 continuation = doc_to_target(self.task_dict[task][split][doc_id])
             visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
             visuals = self.flatten(visuals)
-            formatted_contexts = [f"{contexts}\n"]
-            formatted_continuation = [f"{contexts}\n{continuation}"]
-            model_inputs = self.processor(text=formatted_continuation, images=visuals, device=self.device)
-            for k, v in model_inputs.items():
-                model_inputs[k] = v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else [vv.to(self.device, non_blocking=True) for vv in v]
 
-            for index in range(len(model_inputs["image_patches"])):
-                model_inputs["image_patches"][index] = model_inputs["image_patches"][index].to(dtype=next(self.model.parameters()).dtype)
+            image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
+            image_tokens = " ".join(image_tokens)
+            context = f"{image_tokens}\n{context}"
+            # Apply chat template
+            messages = [{"role": "user", "content": context}, {"role": "assistant", "content": continuation}]
+            if self.chat_template is not None:
+                self.tokenizer.chat_template = self.chat_template
+                prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
+                prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+            elif self.tokenizer.chat_template is not None:
+                prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
+                prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+            else:
+                self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
+                prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
+                prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+
+            formatted_contexts = [prompt]
+            formatted_continuation = [prompt_and_continuation]
+            model_inputs = self._image_processor(text=formatted_continuation, images=visuals).to(self._device, self._model.dtype)
             labels = model_inputs["input_ids"].clone()
-            contxt_id = self.processor(text=formatted_contexts, return_tensors="pt")["input_ids"]
+            contxt_id = self._image_processor(text=formatted_contexts, return_tensors="pt")["input_ids"]
             labels[: len(contxt_id)] = -100
+
+            if self.accelerator.is_main_process and doc_id % 100 == 0:
+                eval_logger.info(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n")
+
             with torch.inference_mode():
                 outputs = self.model(**model_inputs, labels=labels)
             loss = outputs["loss"]
@@ -240,9 +260,11 @@ def _collate(x):
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
             elif self.tokenizer.chat_template is not None:
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+                eval_logger.warning("Using the tokenizer's chat template to format the prompt.")
             else:
                 self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+                eval_logger.warning("No chat template provided or set in the tokenizer. Using the default Vicuna chat template.")
 
             if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                 eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
@@ -279,7 +301,7 @@ def _collate(x):
             res.append(text_outputs)
             self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
             pbar.update(1)
-            # reorder this group of results back to original unsorted form
+        # reorder this group of results back to original unsorted form
         res = re_ords.get_original(res)
 
         pbar.close()
         return res

From d410bea2cd39649cc6cb78f57cc71e12c48a20c8 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 9 Apr 2024 12:38:04 +0000
Subject: [PATCH 11/17] Style

---
 lmms_eval/constants.py       |  2 +-
 lmms_eval/models/llava_hf.py | 11 +++++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/lmms_eval/constants.py b/lmms_eval/constants.py
index 6c79b7049..b7b0283e8 100644
--- a/lmms_eval/constants.py
+++ b/lmms_eval/constants.py
@@ -1,4 +1,4 @@
 DEFAULT_IMAGE_TOKEN = "<image>"
 
 # Default chat templates
-VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
\ No newline at end of file
+VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index 284f57b4b..e5d580229 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -35,7 +35,7 @@ def __init__(
         batch_size: Union[int, str] = 1,
         trust_remote_code: Optional[bool] = False,
         attn_implementation: Optional[str] = "flash_attention_2",
-        device_map: str ="",
+        device_map: str = "",
         chat_template: Optional[str] = None,
         **kwargs,
     ) -> None:
@@ -59,7 +59,7 @@ def __init__(
         self._tokenizer = self._image_processor.tokenizer
         self._config = self._model.config
         self.batch_size_per_gpu = int(batch_size)
-        self.chat_template = chat_template 
+        self.chat_template = chat_template
         if accelerator.num_processes > 1 and device_map == "":
             assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
             # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
@@ -158,11 +158,10 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
             visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
             visuals = self.flatten(visuals)
 
-
             image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
             image_tokens = " ".join(image_tokens)
             context = f"{image_tokens}\n{context}"
-            # Apply chat template 
+            # Apply chat template
             messages = [{"role": "user", "content": context}, {"role": "assistant", "content": continuation}]
             if self.chat_template is not None:
                 self.tokenizer.chat_template = self.chat_template
@@ -252,7 +252,7 @@ def _collate(x):
             # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
             if DEFAULT_IMAGE_TOKEN not in context:
                 context = f"{DEFAULT_IMAGE_TOKEN}\n{context}"
-            # Apply chat template 
+            # Apply chat template
             messages = [{"role": "user", "content": context}]
             if self.chat_template is not None:
                 self.tokenizer.chat_template = self.chat_template
@@ -267,7 +267,7 @@ def _collate(x):
 
             if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                 eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
-            
+
             inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self._model.dtype)
 
             gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]

From 80a1e7d3270db49b8c378e3e2bbc430bdf6380f3 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 9 Apr 2024 12:46:15 +0000
Subject: [PATCH 12/17] Remove loggine

---
 lmms_eval/models/llava_hf.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index e5d580229..59968d8e9 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -259,11 +259,9 @@ def _collate(x):
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
             elif self.tokenizer.chat_template is not None:
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-                eval_logger.warning("Using the tokenizer's chat template to format the prompt.")
             else:
                 self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-                eval_logger.warning("No chat template provided or set in the tokenizer. Using the default Vicuna chat template.")
 
             if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                 eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")

From 7c7b9699af057ca0d750e7cb1bc8e3731c22e852 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 9 Apr 2024 13:18:40 +0000
Subject: [PATCH 13/17] Fix llava loglikelihood

---
 lmms_eval/models/llava.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index b3cb8a66d..2bd6f4414 100644
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -198,7 +198,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         else:
             image = None
 
-        prompts_input = contexts[0]
+        prompts_input = contexts[0] if isinstance(contexts, list) else contexts
 
         if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input:
             """
@@ -209,7 +209,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
             """
             image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
             image_tokens = " ".join(image_tokens)
-            prompts_input = image_tokens + "\n" + contexts[0]
+            prompts_input = image_tokens + "\n" + (contexts[0] if isinstance(contexts, list) else contexts)
 
         conv = conv_templates[self.conv_template].copy()
         conv.append_message(conv.roles[0], prompts_input)

From 3f0cad909dd7f4ccfe05beaa17782e200e2c5bc6 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 9 Apr 2024 14:02:53 +0000
Subject: [PATCH 14/17] Tidy up model calls

---
 lmms_eval/models/llava_hf.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index 59968d8e9..cdcfa83d7 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -37,6 +37,7 @@ def __init__(
         attn_implementation: Optional[str] = "flash_attention_2",
         device_map: str = "",
         chat_template: Optional[str] = None,
+        use_cache: bool = True,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -60,6 +61,7 @@ def __init__(
         self._config = self._model.config
         self.batch_size_per_gpu = int(batch_size)
         self.chat_template = chat_template
+        self.use_cache = use_cache
         if accelerator.num_processes > 1 and device_map == "":
             assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
             # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
@@ -177,7 +179,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
 
             formatted_contexts = [prompt]
             formatted_continuation = [prompt_and_continuation]
-            model_inputs = self._image_processor(text=formatted_continuation, images=visuals).to(self._device, self._model.dtype)
+            model_inputs = self._image_processor(text=formatted_continuation, images=visuals).to(self._device, self.model.dtype)
             labels = model_inputs["input_ids"].clone()
             contxt_id = self._image_processor(text=formatted_contexts, return_tensors="pt")["input_ids"]
             labels[: len(contxt_id)] = -100
@@ -266,7 +268,7 @@ def _collate(x):
             if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                 eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
 
-            inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self._model.dtype)
+            inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype)
 
             gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
             if "max_new_tokens" not in gen_kwargs:
@@ -278,13 +280,14 @@ def _collate(x):
             if "num_beams" not in gen_kwargs:
                 gen_kwargs["num_beams"] = 1
             try:
-                cont = self._model.generate(
+                cont = self.model.generate(
                     **inputs,
                     do_sample=True if gen_kwargs["temperature"] > 0 else False,
                     temperature=gen_kwargs["temperature"],
                     top_p=gen_kwargs["top_p"],
                     num_beams=gen_kwargs["num_beams"],
                     max_new_tokens=gen_kwargs["max_new_tokens"],
+                    use_cache=self.use_cache,
                 )
             except Exception as e:
                 eval_logger.error(f"Error {e} in generating")

From 6c2720320f58d2e706ab03bae97a65747fd264f3 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 9 Apr 2024 14:35:46 +0000
Subject: [PATCH 15/17] Add cmd

---
 lmms_eval/models/llava_hf.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index cdcfa83d7..2b6d88210 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -24,6 +24,16 @@ class LlavaHf(lmms):
     Llava Model for Hugging Face Transformers: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava
 
     Adapted from the InstructBLIP model in lmms_eval/models/instructblip.py
+
+    Example usage:
+
+    accelerate launch --num_processes=8 -m lmms_eval \
+        --model llava_hf \
+        --model_args pretrained=llava-hf/llava-1.5-7b-hf \
+        --tasks mme \
+        --batch_size 1 \
+        --output_path ./logs/ \
+        --log_samples
     """
 
     def __init__(

From c531bf0302d79389851b2830fbd7621b7dfb2ce7 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 9 Apr 2024 15:45:48 +0000
Subject: [PATCH 16/17] Split logging

---
 lmms_eval/models/llava_hf.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index 2b6d88210..4ec875567 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -195,6 +195,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
             labels[: len(contxt_id)] = -100
 
             if self.accelerator.is_main_process and doc_id % 100 == 0:
+                eval_logger.info(f"Prompt for doc ID {doc_id}:\n\n{formatted_contexts[0]}\n")
                 eval_logger.info(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n")
 
             with torch.inference_mode():

From 6d08fe8713adefd29e1d248053d4927703ba81d7 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Wed, 10 Apr 2024 13:40:28 +0000
Subject: [PATCH 17/17] Refactor

---
 lmms_eval/constants.py      | 4 ----
 lmms_eval/models/llava_hf.py | 6 +++++-
 2 files changed, 5 insertions(+), 5 deletions(-)
 delete mode 100644 lmms_eval/constants.py

diff --git a/lmms_eval/constants.py b/lmms_eval/constants.py
deleted file mode 100644
index b7b0283e8..000000000
--- a/lmms_eval/constants.py
+++ /dev/null
@@ -1,4 +0,0 @@
-DEFAULT_IMAGE_TOKEN = "<image>"
-
-# Default chat templates
-VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"

diff --git a/lmms_eval/models/llava_hf.py b/lmms_eval/models/llava_hf.py
index 4ec875567..e5a88da60 100644
--- a/lmms_eval/models/llava_hf.py
+++ b/lmms_eval/models/llava_hf.py
@@ -5,7 +5,6 @@
 from lmms_eval.api.instance import Instance
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
-from lmms_eval.constants import DEFAULT_IMAGE_TOKEN, VICUNA_CHAT_TEMPLATE
 from accelerate import Accelerator, DistributedType
 from accelerate.state import AcceleratorState
 from typing import List, Optional, Union, Tuple
@@ -17,6 +16,11 @@
 
 eval_logger = logging.getLogger("lmms-eval")
 
+DEFAULT_IMAGE_TOKEN = "<image>"
+
+# Default chat for llava-hf/llava-1.5 models: https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0
+VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
+
 
 @register_model("llava_hf")
 class LlavaHf(lmms):
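
For reference, the sketch below is not part of the patch series: it only illustrates how the Vicuna fallback template that ends up in `lmms_eval/models/llava_hf.py` renders a prompt. The checkpoint id is the default from the patches, but the question text and the stand-alone script itself are hypothetical, assuming a transformers version with `apply_chat_template` support (>= 4.39 as referenced above).

```python
# Minimal sketch of the chat-template fallback used by LlavaHf (illustrative only).
from transformers import AutoProcessor

# Copied verbatim from the VICUNA_CHAT_TEMPLATE constant added in PATCH 17/17.
VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
tokenizer = processor.tokenizer

# Same precedence as LlavaHf.generate_until: use the tokenizer's own template if it
# ships one, otherwise fall back to the Vicuna template.
if tokenizer.chat_template is None:
    tokenizer.chat_template = VICUNA_CHAT_TEMPLATE

# The <image> token is prepended when missing, mirroring the MME handling above.
messages = [{"role": "user", "content": "<image>\nWhat is shown in this image?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# With the Vicuna fallback, the rendered prompt ends in "USER: <image>\n... ASSISTANT:",
# which is why generate_until splits the decoded output on "ASSISTANT:".
```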