Merge pull request #2 from Qiulin-W/main
add UNIAA training set and UNIAA-Bench
Showing 5 changed files with 415 additions and 1 deletion.
New file (112 lines added):
import argparse
import json
import os
from io import BytesIO

import requests
import torch
from PIL import Image
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria


def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def eval_model(args):
    disable_torch_init()
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)

    # Prepend the image placeholder token(s) to the query.
    qs = args.query
    if model.config.mm_use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    # Infer the conversation template from the model name.
    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"
    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    # Steer the next token towards a rating word ("excellent", "good", ...).
    prompt += " The aesthetic quality is"

    with open("ava_scores.json") as f:
        data = json.load(f)

    out_lines = []
    for i, llddata in enumerate(tqdm(data)):
        image_path = llddata["image"]
        iaa_score = llddata["iaa_score"]

        image = load_image(image_path)
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        # Only the logits of the next token are needed; no decoding is performed.
        with torch.inference_mode():
            output_logits = model(input_ids,
                                  images=image_tensor)["logits"][:, -1]

        probs, inds = output_logits.sort(dim=-1, descending=True)
        # print(probs[0, 0:100], inds[0, 0:100], tokenizer.convert_ids_to_tokens(inds[0, 0:100]))
        # Token ids of the candidate rating words:
        # 1781: good, 6460: poor
        # 1880: high, 4482: low
        # 15129: excellent, 1781: good, 6534: fair, 6460: poor, 4319: bad
        '''
        lgood, lpoor = output_logits[0, 1781].item(), output_logits[0, 6460].item()
        lhigh, llow = output_logits[0, 1880].item(), output_logits[0, 4482].item()
        llddata["logit_good"] = lgood
        llddata["logit_poor"] = lpoor
        llddata["logit_high"] = lhigh
        llddata["logit_low"] = llow
        out_lines.append(" ".join([image_path, str(iaa_score),
                                   str(float(llddata["logit_good"])), str(float(llddata["logit_poor"])),
                                   str(float(llddata["logit_high"])), str(float(llddata["logit_low"])),
                                   ]) + "\n")
        '''
        lexcel, lgood, lfair, lpoor, lbad = output_logits[0, 15129].item(), output_logits[0, 1781].item(), output_logits[0, 6534].item(), output_logits[0, 6460].item(), output_logits[0, 4319].item()
        out_lines.append(" ".join([image_path, str(iaa_score),
                                   str(float(lexcel)), str(float(lgood)), str(float(lfair)), str(float(lpoor)), str(float(lbad))])
                         + "\n")

    # One line per image: path, ground-truth score, and the five rating-word logits.
    with open(os.path.join(args.model_path, "zero-shot-iaa-score-ava-5gear.txt"), "w") as f:
        f.writelines(out_lines)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="")
    parser.add_argument("--model-base", type=str, default="liuhaotian/llava-v1.5-7b")
    parser.add_argument("--query", type=str, default="Rate this image from an aesthetic perspective.")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    args = parser.parse_args()
    eval_model(args)
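The script above only dumps the raw next-token logits of the five rating words together with the ground-truth iaa_score. A minimal post-processing sketch (not part of this commit) could softmax the five logits, map excellent/good/fair/poor/bad onto a 5-to-1 scale, and correlate the resulting score with iaa_score. The column order follows the writelines call above; the 5-to-1 weighting and the assumption that image paths contain no spaces are mine.

# Hypothetical post-processing for zero-shot-iaa-score-ava-5gear.txt.
import numpy as np
from scipy.stats import pearsonr, spearmanr

weights = np.array([5.0, 4.0, 3.0, 2.0, 1.0])  # excellent, good, fair, poor, bad (assumed weighting)

gt_scores, pred_scores = [], []
with open("zero-shot-iaa-score-ava-5gear.txt") as f:
    for line in f:
        fields = line.split()                       # assumes image paths contain no spaces
        gt_scores.append(float(fields[1]))          # ground-truth iaa_score
        logits = np.array([float(x) for x in fields[2:7]])
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()                        # softmax over the five rating words
        pred_scores.append(float(np.dot(probs, weights)))

print("SRCC:", spearmanr(gt_scores, pred_scores)[0])
print("PLCC:", pearsonr(gt_scores, pred_scores)[0])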
New file (109 lines added):
import argparse
import json
from io import BytesIO

import requests
import torch
from PIL import Image
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle, Conversation
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

device = "cuda" if torch.cuda.is_available() else "cpu"


def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


disable_torch_init()
# Local directory holding the downloaded UNIAA-LLaVA checkpoint (set this before running).
checkpoint = "path-to-UNIAA-LLaVA"
model_name = "llava-v1.5"
tokenizer, model, image_processor, context_len = load_pretrained_model(checkpoint, None, model_name)
question_file = "UNIAA_Describe.json"
output_file = "UNIAA_Describe_output.json"
output_dir = 'UNIAA-Bench/'
output_file = output_dir + '/' + output_file

# Merge the sharded checkpoint into one state dict and reload the fine-tuned
# CLIP vision-tower weights into the model.
llava_model_path = [
    checkpoint + '/pytorch_model-00001-of-00003.bin',
    checkpoint + '/pytorch_model-00002-of-00003.bin',
    checkpoint + '/pytorch_model-00003-of-00003.bin'
]
state_dict = {}
for weight_file in llava_model_path:
    weight_file_path = weight_file
    state_dict_part = torch.load(weight_file_path, map_location=torch.device('cpu'))
    new_state_dict_part = {}
    for k, v in state_dict_part.items():
        new_state_dict_part[k] = v
    state_dict.update(new_state_dict_part)

encoder_dict = {k: v for k, v in state_dict.items() if k.startswith('model.vision_tower.vision_tower.vision_model')}
for old_key in list(encoder_dict.keys()):
    new_key = old_key.replace('model.vision_tower.vision_tower.vision_model.', 'vision_model.')
    encoder_dict[new_key] = encoder_dict.pop(old_key)
vision_tower = model.get_vision_tower()
vision_tower.vision_tower.load_state_dict(encoder_dict, strict=True)

with open(question_file) as f:
    description_data = json.load(f)

for i, data in enumerate(tqdm(description_data)):
    filename = data["img_path"]
    qs = data["question"]
    if model.config.mm_use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    # Infer the conversation template from the model name.
    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    image = load_image(filename)
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            # num_beams=1,
            do_sample=True,
            temperature=0.1,
            top_p=0.7,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

    input_token_len = input_ids.shape[1]
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()
    data["response"] = outputs

    # Append one JSON record per line (JSON Lines).
    with open(output_file, "a") as wf:
        json.dump(data, wf)
        wf.write("\n")
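Because each record is appended as its own line, the description output can be read back as JSON Lines. A small reader sketch (not part of the commit), assuming the output path built above and one record per line:

# Hypothetical reader for the description output written by the loop above.
import json

records = []
with open("UNIAA-Bench/UNIAA_Describe_output.json") as f:
    for line in f:
        line = line.strip()
        if line:
            records.append(json.loads(line))

print(len(records), "descriptions")
print(records[0]["img_path"], "->", records[0]["response"][:80])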
New file (156 lines added):
import argparse
import json
from io import BytesIO

import requests
import torch
from PIL import Image
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle, Conversation
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig


def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def eval_model(args):
    # Model
    disable_torch_init()
    model_version = args.model_path.split("/")[-1]
    answers_file = args.answers_file
    model_name = get_model_name_from_path(args.model_name)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)

    # Point --absolute_model-path at the downloaded UNIAA-LLaVA checkpoint; its .bin
    # shards are merged below to restore the fine-tuned vision-tower weights.
    llava_model_path = [
        args.absolute_model_path + '/pytorch_model-00001-of-00003.bin',
        args.absolute_model_path + '/pytorch_model-00002-of-00003.bin',
        args.absolute_model_path + '/pytorch_model-00003-of-00003.bin',
    ]
    state_dict = {}
    for weight_file in llava_model_path:
        weight_file_path = weight_file
        state_dict_part = torch.load(weight_file_path, map_location=torch.device('cpu'))
        new_state_dict_part = {}
        for k, v in state_dict_part.items():
            new_state_dict_part[k] = v
        state_dict.update(new_state_dict_part)

    encoder_dict = {k: v for k, v in state_dict.items() if k.startswith('model.vision_tower.vision_tower.vision_model')}
    for old_key in list(encoder_dict.keys()):
        new_key = old_key.replace('model.vision_tower.vision_tower.vision_model.', 'vision_model.')
        encoder_dict[new_key] = encoder_dict.pop(old_key)
    vision_tower = model.get_vision_tower()
    vision_tower.vision_tower.load_state_dict(encoder_dict, strict=True)

    with open(args.questions_file) as f:
        llvqa_data = json.load(f)

    for i, llddata in enumerate(tqdm(llvqa_data)):
        filename = llddata["img_path"]
        if args.lang == "en":
            message = llddata["question"] + "\nChoose between one of the options as follows:\n"
        elif args.lang == "zh":
            message = llddata["question"] + "\n在下列选项中选择一个:\n"  # "Choose one of the options below:"
        else:
            raise NotImplementedError("IAA-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.")

        # Label the candidate answers A./B./C./... according to how many there are.
        if len(llddata["candidates"]) == 4:
            for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
                message += f"{choice} {ans}\n"
        elif len(llddata["candidates"]) == 2:
            for choice, ans in zip(["A.", "B."], llddata["candidates"]):
                message += f"{choice} {ans}\n"
        elif len(llddata["candidates"]) == 3:
            for choice, ans in zip(["A.", "B.", "C."], llddata["candidates"]):
                message += f"{choice} {ans}\n"
        elif len(llddata["candidates"]) == 5:
            for choice, ans in zip(["A.", "B.", "C.", "D.", "E."], llddata["candidates"]):
                message += f"{choice} {ans}\n"
        qs = message

        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        # Infer the conversation template from the model name.
        if 'llama-2' in model_name.lower():
            conv_mode = "llava_llama_2"
        elif "v1" in model_name.lower():
            conv_mode = "llava_v1"
        elif "mpt" in model_name.lower():
            conv_mode = "mpt"
        else:
            conv_mode = "llava_v0"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
        else:
            args.conv_mode = conv_mode

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        image = load_image(filename)
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        # Greedy decoding for the multiple-choice answer.
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                num_beams=1,
                do_sample=False,
                temperature=0,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        llddata["response"] = outputs

        # Append one JSON record per line (JSON Lines).
        with open(answers_file, "a") as wf:
            json.dump(llddata, wf)
            wf.write("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="llava-v1.5")
    parser.add_argument("--model-path", type=str, default="zhouzhaokun/UNIAA-LLaVA")
    parser.add_argument("--absolute_model-path", type=str, default="path-to-UNIAA-LLaVA")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--questions-file", type=str,
                        default="UNIAA_QA.json")
    parser.add_argument("--answers-file", type=str, default="UNIAA_QA_answers.json")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--lang", type=str, default="en")
    args = parser.parse_args()
    eval_model(args)
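The answers file produced above can then be scored for multiple-choice accuracy. A hedged sketch (not part of this commit), assuming one JSON record per line and a ground-truth field named "correct_ans"; that field name is an assumption about the UNIAA_QA.json schema, not something shown in the diff:

# Hypothetical accuracy computation over UNIAA_QA_answers.json.
import json

letters = ["A.", "B.", "C.", "D.", "E."]
total, correct = 0, 0
with open("UNIAA_QA_answers.json") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        rec = json.loads(line)
        total += 1
        gt = rec["correct_ans"]                       # assumed field name
        gt_letter = letters[rec["candidates"].index(gt)]
        resp = rec["response"].strip()
        # Count a hit if the response starts with the correct option letter
        # or contains the correct answer text.
        if resp.startswith(gt_letter[0]) or gt in resp:
            correct += 1

print(f"accuracy: {correct / max(total, 1):.4f} ({correct}/{total})")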