Skip to content

Commit

Permalink
Refactor image token handling in LMMS evaluation code
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Mar 31, 2024
1 parent d8dcd0a commit 3e2b24f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
7 changes: 7 additions & 0 deletions lmms_eval/api/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def add_question(self, doc, data_frame=None, index=None):
def get_text(self, *, image_tokens="<image>", lazy=True):
texts = []
vision = []
already_have_images = False
for context in self.contexts:
if isinstance(context, str) and image_tokens in context:
already_have_images = True
break
if already_have_images:
image_tokens = ""
for context in self.contexts:
if isinstance(context, LazyLoadedImages):
if isinstance(image_tokens, str):
Expand Down
31 changes: 16 additions & 15 deletions lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _collate(x):
task = task[0]
split = split[0]
# batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N]
contexts_texts, batched_visuals = zip(*[context.get_text(lazy=False) for context in contexts]) # [B, N]
contexts_texts, batched_visuals = zip(*[context.get_text(image_tokens=DEFAULT_IMAGE_TOKEN ,lazy=False) for context in contexts]) # [B, N]
flattened_visuals = self.flatten(batched_visuals) # [B*N]
# batched_visuals = context.get_visions() # [B, N]
# flattened_visuals = contexts[0].get_visions() # [B*N]
Expand Down Expand Up @@ -317,20 +317,21 @@ def _collate(x):
question_input = []

for visual, context in zip(batched_visuals, contexts_texts):
if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
"""
Three senarios:
1. No image, and there for, no image token should be added.
2. image token is already specified in the context, so we don't need to add it.
3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line.
"""
image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN]
image_tokens = " ".join(image_tokens)
if isinstance(context, list):
context = "".join(context)
question = image_tokens + "\n" + context
else:
question = context
question = context
# if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
# """
# Three scenarios:
# 1. No image, and therefore no image token should be added.
# 2. The image token is already specified in the context, so we don't need to add it.
# 3. The image token is not specified in the context and there are image inputs, so we need to add it. In this case, we add the image token at the beginning of the context, followed by a new line.
# """
# image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN]
# image_tokens = " ".join(image_tokens)
# if isinstance(context, list):
# context = "".join(context)
# question = image_tokens + "\n" + context
# else:
# question = context

conv = conv_templates[self.conv_template].copy()
conv.append_message(conv.roles[0], question)
Expand Down

0 comments on commit 3e2b24f

Please sign in to comment.