diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py index e6775c9e2..9e0a697d0 100644 --- a/lmms_eval/evaluator.py +++ b/lmms_eval/evaluator.py @@ -315,7 +315,7 @@ def evaluate( # Don't use above one, this would crash if doc_iterator_for_counting contains too many objects and very slow doc_iterator_for_counting = itertools.islice(range(len(task.test_docs())), lm.rank, limit, lm.world_size) if task.has_test_docs() else itertools.islice(range(len(task.validation_docs())), lm.rank, limit, lm.world_size) total_docs = sum(1 for _ in doc_iterator_for_counting) - pbar = tqdm(total=total_docs, desc="Postprocessing") + pbar = tqdm(total=total_docs, desc="Postprocessing", position=lm.rank) for doc_id, doc in doc_iterator: # subset instances to only this document id ; sort by idx requests = list(filter(lambda x: x.doc_id == doc_id, task.instances)) diff --git a/lmms_eval/tasks/vizwizvqa/utils.py b/lmms_eval/tasks/vizwizvqa/utils.py index 2b3d69310..35a6e0fa1 100644 --- a/lmms_eval/tasks/vizwizvqa/utils.py +++ b/lmms_eval/tasks/vizwizvqa/utils.py @@ -252,14 +252,14 @@ def vizwizvqa_process_results(doc, result): return { "exact_match": accuracy, "submission": { - "question_id": doc["question_id"], + "image": f"{doc['question_id']}.jpg", "answer": resAns, }, } def vizwizvqa_doc_to_text(doc): - text = f"{doc['question'].capitalize()}\n When the provided information is insufficient, respond with 'unanswerable'. Answer the question using a single word or phrase." + text = f"{doc['question'].capitalize()}\nWhen the provided information is insufficient, respond with 'Unanswerable'.\nAnswer the question using a single word or phrase." return text diff --git a/lmms_eval/tasks/vizwizvqa_val/utils.py b/lmms_eval/tasks/vizwizvqa_val/utils.py new file mode 100644 index 000000000..35a6e0fa1 --- /dev/null +++ b/lmms_eval/tasks/vizwizvqa_val/utils.py @@ -0,0 +1,273 @@ +import re +import os +import json +import yaml +import pathlib +import logging +import datetime +import statistics + +eval_logger = logging.getLogger("lmms-eval") + +with open(pathlib.Path(__file__).parent / "vizwizvqa.yaml", "r") as f: + raw_data = f.readlines() + for i in range(len(raw_data)): + raw_data[i] = raw_data[i].replace("!function", "function") + + config = yaml.safe_load("".join(raw_data)) + + +class EvalAIAnswerProcessor: + CONTRACTIONS = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", + } + + NUMBER_MAP = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + ARTICLES = ["a", "an", "the"] + PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)") + COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)") + PUNCTUATIONS = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def __init__(self, *args, **kwargs): + pass + + def word_tokenize(self, word): + word = word.lower() + word = word.replace(",", "").replace("?", "").replace("'s", " 's") + word = word.replace("\n", " ").replace("\t", " ").strip() + return word.strip() + + def process_punctuation(self, in_text): + out_text = in_text + for p in self.PUNCTUATIONS: + if (p + " " in in_text or " " + p in in_text) or (re.search(self.COMMA_STRIP, in_text) is not None): + out_text = out_text.replace(p, "") + else: + out_text = out_text.replace(p, " ") + out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE) + return out_text + + def process_digit_article(self, in_text): + out_text = [] + temp_text = in_text.lower().split() + for word in temp_text: + word = self.NUMBER_MAP.setdefault(word, word) + if word not in self.ARTICLES: + out_text.append(word) + else: + pass + for word_id, word in enumerate(out_text): + if word in self.CONTRACTIONS: + out_text[word_id] = self.CONTRACTIONS[word] + out_text = " ".join(out_text) + return out_text + + def __call__(self, item): + item = self.word_tokenize(item) + item = self.process_punctuation(item) + item = self.process_digit_article(item) + return item + + +def vizwizvqa_doc_to_visual(doc): + return [doc["image"].convert("RGB")] + + +def vizwizvqa_process_results(doc, result): + eval_ai_processor = EvalAIAnswerProcessor() + assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}." + resAns = eval_ai_processor(result[0]) + accuracy = 0 + + if "answers" in doc and doc["answers"] is not None: + gtAcc = [] + + for i in range(len(doc["answers"])): + doc["answers"][i] = eval_ai_processor(doc["answers"][i]) + + for i in range(len(doc["answers"])): + otherGTAns = [doc["answers"][j] for j in range(len(doc["answers"])) if i != j] + matchingAns = [item for item in otherGTAns if item == resAns] + acc = min(1, float(len(matchingAns)) / 3) + gtAcc.append(acc) + if gtAcc: + accuracy = statistics.mean(gtAcc) + else: + accuracy = 0 + + return { + "exact_match": accuracy, + "submission": { + "image": f"{doc['question_id']}.jpg", + "answer": resAns, + }, + } + + +def vizwizvqa_doc_to_text(doc): + text = f"{doc['question'].capitalize()}\nWhen the provided information is insufficient, respond with 'Unanswerable'.\nAnswer the question using a single word or phrase." + return text + + +def vizwizvqa_aggreate_submissions(results): + now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + submission_file_name = f"vizwizvqa-submission-{now_date_time}.json" + path = os.path.abspath(submission_file_name) + with open(path, "w") as f: + json.dump(results, f) + print(f"Submission file saved to {path}") + return 0 diff --git a/lmms_eval/tasks/vizwizvqa_val/vizwizvqa.yaml b/lmms_eval/tasks/vizwizvqa_val/vizwizvqa.yaml new file mode 100644 index 000000000..290d928ef --- /dev/null +++ b/lmms_eval/tasks/vizwizvqa_val/vizwizvqa.yaml @@ -0,0 +1,24 @@ +task: vizwizvqa_val +dataset_path: lmms-lab/VizWiz-VQA + token: True +test_split: val +output_type: generate_until +doc_to_visual: !function utils.vizwizvqa_doc_to_visual +doc_to_text: !function utils.vizwizvqa_doc_to_text +doc_to_target: "answer" +generation_kwargs: + until: + - "ASSISTANT:" +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true + - metric: submission + aggregation: !function utils.vizwizvqa_aggreate_submissions + higher_is_better: true +metadata: + - version: 0.0 + - have_ocr_reference: false +process_results: !function utils.vizwizvqa_process_results