diff --git a/lmms_eval/models/gpt4v.py b/lmms_eval/models/gpt4v.py
index d2ec2025d..68d6d20b9 100644
--- a/lmms_eval/models/gpt4v.py
+++ b/lmms_eval/models/gpt4v.py
@@ -12,6 +12,7 @@
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
 from lmms_eval import utils
+from lmms_eval.api.samplers import Context
 
 from PIL import Image
@@ -65,17 +66,19 @@ def generate_until(self, requests) -> List[str]:
 
         for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
             # encode, pad, and truncate contexts for this batch
-            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
-            visuals = self.flatten(visuals)
-            imgs = []
-            for visual in visuals:
-                img = self.encode_image(visual)
-                imgs.append(img)
+            # visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
+            # visuals = contexts.get_visions()
+            # visuals = self.flatten(visuals)
+            # imgs = []
+            # for visual in visuals:
+            #     img = self.encode_image(visual)
+            #     imgs.append(img)
 
             payload = {"model": "gpt-4-vision-preview", "messages": []}
             response_json = {"role": "user", "content": []}
             # When there is no image token in the context, append the image to the text
-            if self.image_token not in contexts:
+            image_token_in_context = contexts.already_have_image_token(self.image_token)
+            if image_token_in_context:
                 payload["messages"].append(deepcopy(response_json))
                 payload["messages"][0]["content"].append({"type": "text", "text": contexts})
                 for img in imgs:
diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index 80134af78..3c2f9e9f0 100644
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -73,6 +73,12 @@ def __init__(
         self.device_map = device_map
 
         self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, use_flash_attention_2=use_flash_attention_2)
+        if self._image_processor is None:
+            vision_tower = self._model.get_vision_tower()
+            if not vision_tower.is_loaded:
+                vision_tower.load_model()
+            vision_tower.to(device=device, dtype=torch.float16)
+            self._image_processor = vision_tower.image_processor
         self._config = self._model.config
         self.model.eval()
         self.model.tie_weights()
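
For reference, the gpt4v.py changes assume the `Context` object imported from `lmms_eval.api.samplers` exposes `get_visions()` and `already_have_image_token()`, and that it stringifies to the prompt text (it is passed directly as the `"text"` field of the payload). That interface is not shown in this diff, so the sketch below is only an assumption reconstructed from the call sites above; every attribute and method body is hypothetical, not the actual lmms_eval implementation.

# Minimal sketch of the Context interface assumed by the gpt4v.py changes above.
# Hypothetical: reconstructed from the call sites, not from lmms_eval source.
from typing import Any, List, Optional


class Context:
    def __init__(self, text: str, visions: Optional[List[Any]] = None):
        self._text = text                 # prompt text for this request
        self._visions = visions or []     # attached visuals (e.g. PIL images)

    def get_visions(self) -> List[Any]:
        # Return the visual inputs attached to this context.
        return self._visions

    def already_have_image_token(self, image_token: str) -> bool:
        # True if the prompt text already contains the image placeholder token.
        return image_token in self._text

    def __str__(self) -> str:
        # The diff passes `contexts` directly as the payload text, so the
        # object is expected to stringify to the prompt.
        return self._text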