Skip to content

Commit

Permalink
Merge pull request #2 from Qiulin-W/main
Browse files Browse the repository at this point in the history
add UNIAA training set and UNIAA-Bench
  • Loading branch information
Qiulin-W authored Sep 27, 2024
2 parents c2053f6 + 9b89033 commit 13f2294
Show file tree
Hide file tree
Showing 5 changed files with 415 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ The UNIAA-Bench overview. (a) UNIAA-QA contains 5354 Image-Question-Answer sampl


## Release

- [9/25] 🔥 Our [UNIAA](https://huggingface.co/datasets/zkzhou/UNIAA) data is released! The corresponding fine-tuning and evaluation code can be found in the GitHub repository folder.
- [4/15] 🔥 We build the page of UNIAA!


Expand Down
112 changes: 112 additions & 0 deletions UNIAA-Bench/aesthetic_assessment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import argparse
import torch

import os

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image
from tqdm import tqdm

import requests
from PIL import Image
from io import BytesIO


def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image

def eval_model(args):
disable_torch_init()
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
qs = args.query
if model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

if 'llama-2' in model_name.lower():
conv_mode = "llava_llama_2"
elif "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt"
else:
conv_mode = "llava_v0"
if args.conv_mode is not None and conv_mode != args.conv_mode:
print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
else:
args.conv_mode = conv_mode

conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
prompt += " The aesthetic quality is"

import json

with open("ava_scores.json") as f:
data = json.load(f)

out_lines = []
for i, llddata in enumerate(tqdm(data)):
image_path = llddata["image"]
iaa_score = llddata["iaa_score"]

image = load_image(image_path)
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

with torch.inference_mode():
output_logits = model(input_ids,
images=image_tensor)["logits"][:,-1]

probs, inds = output_logits.sort(dim=-1, descending=True)
#print(probs[0, 0:100], inds[0, 0:100], tokenizer.convert_ids_to_tokens(inds[0, 0:100]))
# 1781: good, 6460: poor
# 1880: high, 4482: low
# 15129: excellent, 1781: good, 6534: fair, 6460: poor, 4319: bad
'''
lgood, lpoor = output_logits[0,1781].item(), output_logits[0,6460].item()
lhigh, llow = output_logits[0,1880].item(), output_logits[0,4482].item()
llddata["logit_good"] = lgood
llddata["logit_poor"] = lpoor
llddata["logit_high"] = lhigh
llddata["logit_low"] = llow
out_lines.append(" ".join([image_path, str(iaa_score),
str(float(llddata["logit_good"])), str(float(llddata["logit_poor"])),
str(float(llddata["logit_high"])), str(float(llddata["logit_low"])),
]) + "\n")
'''
lexcel, lgood, lfair, lpoor, lbad = output_logits[0,15129].item(), output_logits[0,1781].item(), output_logits[0,6534].item(), output_logits[0,6460].item(), output_logits[0,4319].item()
out_lines.append(" ".join([image_path, str(iaa_score),
str(float(lexcel)), str(float(lgood)), str(float(lfair)), str(float(lpoor)), str(float(lbad))])
+ "\n")
f = open(os.path.join(args.model_path, "zero-shot-iaa-score-ava-5gear.txt"), "w")
f.writelines(out_lines)
f.close()

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="")
parser.add_argument("--model-base", type=str, default="liuhaotian/llava-v1.5-7b")
parser.add_argument("--query", type=str, default="Rate this image from an aesthetic perspective.")
parser.add_argument("--conv-mode", type=str, default="llava_v1")
args = parser.parse_args()
eval_model(args)
109 changes: 109 additions & 0 deletions UNIAA-Bench/aesthetic_describe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import argparse
import torch
from tqdm import tqdm
import json
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle, Conversation
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import requests
from PIL import Image
from io import BytesIO


device = "cuda" if torch.cuda.is_available() else "cpu"
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image


disable_torch_init()
model_name = "llava-v1.5"
tokenizer, model, image_processor, context_len = load_pretrained_model(checkpoint, None, model_name)
question_file = "UNIAA_Describe.json"
output_file = "UNIAA_Describe_output.json"
output_dir = 'UNIAA-Bench/'
output_file = output_dir + '/' + output_file


llava_model_path = [
checkpoint + '/pytorch_model-00001-of-00003.bin',
checkpoint + '/pytorch_model-00002-of-00003.bin',
checkpoint + '/pytorch_model-00003-of-00003.bin'
]
state_dict = {}
for weight_file in llava_model_path:
weight_file_path = weight_file
state_dict_part = torch.load(weight_file_path, map_location=torch.device('cpu'))
new_state_dict_part = {}
for k, v in state_dict_part.items():
new_state_dict_part[k] = v
state_dict.update(new_state_dict_part)

encoder_dict = {k:v for k, v in state_dict.items() if k.startswith('model.vision_tower.vision_tower.vision_model')}
for old_key in list(encoder_dict.keys()):
new_key = old_key.replace('model.vision_tower.vision_tower.vision_model.', 'vision_model.')
encoder_dict[new_key] = encoder_dict.pop(old_key)
vision_tower = model.get_vision_tower()
vision_tower.vision_tower.load_state_dict(encoder_dict, strict=True)



with open(question_file) as f:
description_data = json.load(f)

for i, data in enumerate(tqdm(description_data)):
filename = data["img_path"]
qs = data["question"]
if model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

if 'llama-2' in model_name.lower():
conv_mode = "llava_llama_2"
elif "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt"
else:
conv_mode = "llava_v0"

image = load_image(filename)
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
# num_beams=1,
do_sample=True,
temperature=0.1,
top_p=0.7,
max_new_tokens=1024,
use_cache=True,
stopping_criteria=[stopping_criteria])

input_token_len = input_ids.shape[1]
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
data["response"] = outputs

with open(output_file, "a") as wf:
json.dump(data, wf)
156 changes: 156 additions & 0 deletions UNIAA-Bench/aesthetic_perception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import argparse
import torch
from tqdm import tqdm
import json
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle, Conversation
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from PIL import Image
import requests
from PIL import Image
from io import BytesIO


def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image


def eval_model(args):
# Model
disable_torch_init()
model_version = args.model_path.split("/")[-1]
answers_file = args.answers_file
model_name = get_model_name_from_path(args.model_name)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
# Please add the bin file of the download UNIAA-LLaVA model bellows
llava_model_path = [
args.absolute_model_path + '/pytorch_model-00001-of-00003.bin',
args.absolute_model_path + '/pytorch_model-00002-of-00003.bin',
args.absolute_model_path + '/pytorch_model-00003-of-00003.bin',
]
state_dict = {}
for weight_file in llava_model_path:
weight_file_path = weight_file
state_dict_part = torch.load(weight_file_path, map_location=torch.device('cpu'))
new_state_dict_part = {}
for k, v in state_dict_part.items():
new_state_dict_part[k] = v
state_dict.update(new_state_dict_part)

encoder_dict = {k:v for k, v in state_dict.items() if k.startswith('model.vision_tower.vision_tower.vision_model')}
for old_key in list(encoder_dict.keys()):
new_key = old_key.replace('model.vision_tower.vision_tower.vision_model.', 'vision_model.')
encoder_dict[new_key] = encoder_dict.pop(old_key)
vision_tower = model.get_vision_tower()
vision_tower.vision_tower.load_state_dict(encoder_dict, strict=True)


with open(args.questions_file) as f:
llvqa_data = json.load(f)

for i, llddata in enumerate(tqdm(llvqa_data)):
filename = llddata["img_path"]
if args.lang == "en":
message = llddata["question"] + "\nChoose between one of the options as follows:\n"
elif args.lang == "zh":
message = llddata["question"] + "\在下列选项中选择一个:\n"
else:
raise NotImplementedError("IAA-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.")
if len(llddata["candidates"]) == 4:
for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
message += f"{choice} {ans}\n"
elif len(llddata["candidates"]) == 2:
for choice, ans in zip(["A.", "B."], llddata["candidates"]):
message += f"{choice} {ans}\n"
elif len(llddata["candidates"]) == 3:
for choice, ans in zip(["A.", "B.", "C."], llddata["candidates"]):
message += f"{choice} {ans}\n"
elif len(llddata["candidates"]) == 5:
for choice, ans in zip(["A.", "B.", "C.", "D.", "E."], llddata["candidates"]):
message += f"{choice} {ans}\n"
qs = message

if model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

if 'llama-2' in model_name.lower():
conv_mode = "llava_llama_2"
elif "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt"
else:
conv_mode = "llava_v0"

if args.conv_mode is not None and conv_mode != args.conv_mode:
print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
else:
args.conv_mode = conv_mode

conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

image = load_image(filename)
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)


with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
num_beams=1,
do_sample=False,
temperature=0,
max_new_tokens=1024,
use_cache=True,
stopping_criteria=[stopping_criteria])

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
llddata["response"] = outputs

with open(answers_file, "a") as wf:
json.dump(llddata, wf)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="llava-v1.5")
parser.add_argument("--model-path", type=str, default="zhouzhaokun/UNIAA-LLaVA")
parser.add_argument("--absolute_model-path", type=str, default="path-to-UNIAA-LLaVA")

parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--questions-file", type=str,
default="UNIAA_QA.json")
parser.add_argument("--conv-mode", type=str, default="llava_v1")
parser.add_argument("--lang", type=str, default="en")
args = parser.parse_args(--answers-file, type=str, default="UNIAA_QA_answers.json")
eval_model(args)




Loading

0 comments on commit 13f2294

Please sign in to comment.