Skip to content

Commit

Permalink
add vila
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Apr 3, 2024
1 parent d5ac624 commit 89513b7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
17 changes: 10 additions & 7 deletions lmms_eval/models/gpt4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval import utils
from lmms_eval.api.samplers import Context

from PIL import Image

Expand Down Expand Up @@ -65,17 +66,19 @@ def generate_until(self, requests) -> List[str]:

for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
# encode, pad, and truncate contexts for this batch
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
visuals = self.flatten(visuals)
imgs = []
for visual in visuals:
img = self.encode_image(visual)
imgs.append(img)
# visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
# visuals = contexts.get_visions()
# visuals = self.flatten(visuals)
# imgs = []
# for visual in visuals:
# img = self.encode_image(visual)
# imgs.append(img)

payload = {"model": "gpt-4-vision-preview", "messages": []}
response_json = {"role": "user", "content": []}
# When there is no image token in the context, append the image to the text
if self.image_token not in contexts:
image_token_in_context = contexts.already_have_image_token(self.image_token)
if image_token_in_context:
payload["messages"].append(deepcopy(response_json))
payload["messages"][0]["content"].append({"type": "text", "text": contexts})
for img in imgs:
Expand Down
6 changes: 6 additions & 0 deletions lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def __init__(
self.device_map = device_map

self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, use_flash_attention_2=use_flash_attention_2)
if self._image_processor is None:
vision_tower = self._model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
vision_tower.to(device=device, dtype=torch.float16)
self._image_processor = vision_tower.image_processor
self._config = self._model.config
self.model.eval()
self.model.tie_weights()
Expand Down

0 comments on commit 89513b7

Please sign in to comment.