diff --git a/README.md b/README.md
index cc5f80794..5f0997a94 100644
--- a/README.md
+++ b/README.md
@@ -269,7 +269,7 @@ We also thank:
 @misc{lmms_eval2024,
     title={LMMs-Eval: Accelerating the Development of Large Multimoal Models},
     url={https://github.com/EvolvingLMMs-Lab/lmms-eval},
-    author={Bo Li*, Peiyuan Zhang*, Kaicheng Zhang*, Fanyi Pu*, Xinrun Du, Yuhao Dong, Haotian Liu, Yuanhan Zhang, Ge Zhang, Chunyuan Li and Ziwei Liu},
+    author={Bo Li*, Peiyuan Zhang*, Kaichen Zhang*, Fanyi Pu*, Xinrun Du, Yuhao Dong, Haotian Liu, Yuanhan Zhang, Ge Zhang, Chunyuan Li and Ziwei Liu},
     publisher = {Zenodo},
     version = {v0.1.0},
     month={March},
diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py
index 1f45a85e3..0c6661450 100644
--- a/lmms_eval/__main__.py
+++ b/lmms_eval/__main__.py
@@ -274,7 +274,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
         hash_input = f"{args.model_args}".encode("utf-8")
         hash_output = hashlib.sha256(hash_input).hexdigest()[:6]
         path = Path(args.output_path)
-        path = path.expanduser().resolve().joinpath(f"{args.model}").joinpath(f"model_args_{hash_output}").joinpath(f"{datetime_str}_{args.log_samples_suffix}")
+        path = path.expanduser().resolve().joinpath(f"{datetime_str}_{args.log_samples_suffix}_{args.model}_model_args_{hash_output}")
         args.output_path = path

     elif args.log_samples and not args.output_path:
diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index c0869d48e..b3cb8a66d 100644
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -1,14 +1,19 @@
 import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+
 import logging
 import copy
 from tqdm import tqdm
+from datetime import timedelta
+
 from lmms_eval import utils
 from lmms_eval.api.instance import Instance
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
 from lmms_eval.utils import stop_sequences_criteria
-from accelerate import Accelerator, DistributedType
+from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
 from accelerate.state import AcceleratorState
 from typing import List, Optional, Union, Tuple
 import warnings

@@ -47,7 +52,8 @@ def __init__(
         batch_size: Optional[Union[int, str]] = 1,
         trust_remote_code: Optional[bool] = False,
         revision=None,
-        use_flash_attention_2=False,
+        use_flash_attention_2=True,
+        device_map="",
         conv_template="vicuna_v1",
         use_cache=True,
         truncate_context=False,  # whether to truncate the context in generation, set it False for LLaVA-1.6
@@ -57,17 +63,16 @@ def __init__(
         # Do not use kwargs for now
         assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

-        accelerator = Accelerator()
-        if accelerator.num_processes > 1:
+        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
+        if accelerator.num_processes > 1 and device_map == "":
             self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+            self.device_map = f"cuda:{accelerator.local_process_index}"
         else:
-            self._device = device
-        (
-            self._tokenizer,
-            self._model,
-            self._image_processor,
-            self._max_length,
-        ) = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self._device)
+            self._device = torch.device(device)
+            self.device_map = device_map
+
+        self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, use_flash_attention_2=use_flash_attention_2)
         self._config = self._model.config
         self.model.eval()
         self.model.tie_weights()
@@ -77,7 +82,7 @@ def __init__(
         self.use_cache = use_cache
         self.truncate_context = truncate_context
         # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
-        if accelerator.num_processes > 1:
+        if accelerator.num_processes > 1 and device_map == "":
             assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
             # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
             # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
@@ -89,6 +94,7 @@ def __init__(
                 }
                 AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                 eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
+
             if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                 self._model = accelerator.prepare(self.model)
             else:
@@ -98,10 +104,15 @@ def __init__(
                 eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes
+        elif accelerator.num_processes == 1 and device_map == "auto":
+            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
+            self._rank = 0
+            self._word_size = 1
         else:
+            eval_logger.info(f"Using single device: {self._device}")
             self.model.to(self._device)
             self._rank = 0
-            self._word_size = 1
+            self._world_size = 1

     @property
     def config(self):
diff --git a/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml b/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml
index c07c4b5c9..535202795 100644
--- a/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml
+++ b/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml
@@ -32,8 +32,8 @@ metric_list:
     higher_is_better: true
 metadata:
   version: 0.0
-  gpt_eval_model_name: "gpt-4-0314"
+  gpt_eval_model_name: "gpt-4-0613"
 model_specific_prompt_kwargs:
   default:
     pre_prompt: ""
-    post_prompt: ""
\ No newline at end of file
+    post_prompt: ""
diff --git a/pyproject.toml b/pyproject.toml
index 70fcbba59..de42bb346 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ dependencies = [
     "openai>=1.0.0",
     "pycocoevalcap",
     "tqdm-multiprocess",
-    "transformers>=4.31.0",
+    "transformers==4.37.2",
     "zstandard",
     "pillow",
     "pyyaml",