diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py
index 1e05c084e..bbe728d36 100644
--- a/lmms_eval/api/task.py
+++ b/lmms_eval/api/task.py
@@ -348,7 +348,7 @@ def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
         doc_id_iterator = utils.create_iterator([i for i in range(len(docs))], rank, world_size, limit)
         doc_id_iterator, doc_id_iterator_counting = itertools.tee(doc_id_iterator)
         total_docs = sum(1 for _ in doc_id_iterator_counting)
-        pbar = tqdm(total=total_docs, desc="Building context")
+        pbar = tqdm(total=total_docs, desc=f"Building context {rank}", position=rank)
         for doc_id in doc_id_iterator:
             # sample fewshot context #TODO: need to offset doc_id by rank now!
             fewshot_ctx = self.fewshot_context(doc_id, 0 if self.config.num_fewshot is None else self.config.num_fewshot, self.config.training_split if self.has_training_docs() else split)
diff --git a/lmms_eval/tasks/docvqa/docvqa.yaml b/lmms_eval/tasks/docvqa/docvqa.yaml
new file mode 100644
index 000000000..f441a82ea
--- /dev/null
+++ b/lmms_eval/tasks/docvqa/docvqa.yaml
@@ -0,0 +1,22 @@
+task: docvqa
+dataset_path: lmms-lab/DocVQA
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.vqav2_doc_to_visual
+doc_to_text: !function utils.vqav2_doc_to_text
+doc_to_target: "answer"
+generation_kwargs:
+  until:
+    - "ASSISTANT:"
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+  - metric: submission
+    aggregation: !function utils.vqav2_aggreate_submissions
+    higher_is_better: true
+metadata:
+  - version: 0.0
+process_results: !function utils.vqav2_process_results
diff --git a/lmms_eval/tasks/vqav2/utils.py b/lmms_eval/tasks/vqav2_test/utils.py
similarity index 98%
rename from lmms_eval/tasks/vqav2/utils.py
rename to lmms_eval/tasks/vqav2_test/utils.py
index 7e10c7103..695e4cb23 100644
--- a/lmms_eval/tasks/vqav2/utils.py
+++ b/lmms_eval/tasks/vqav2_test/utils.py
@@ -265,7 +265,7 @@ def vqav2_doc_to_text(doc):
 def vqav2_aggreate_submissions(results):
     now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
     os.makedirs("./submissions", exist_ok=True)
-    submission_file_name = f"./submissions/vqav2-submission-{now_date_time}.json"
+    submission_file_name = f"./submissions/vqav2-test-submission-{now_date_time}.json"
     path = os.path.abspath(submission_file_name)
     with open(path, "w") as f:
         json.dump(results, f)
diff --git a/lmms_eval/tasks/vqav2/vqav2.yaml b/lmms_eval/tasks/vqav2_test/vqav2_test.yaml
similarity index 94%
rename from lmms_eval/tasks/vqav2/vqav2.yaml
rename to lmms_eval/tasks/vqav2_test/vqav2_test.yaml
index 3f82c84fc..e29fb4303 100644
--- a/lmms_eval/tasks/vqav2/vqav2.yaml
+++ b/lmms_eval/tasks/vqav2_test/vqav2_test.yaml
@@ -1,8 +1,8 @@
-task: "vqav2"
+task: "vqav2_test"
 dataset_path: lmms-lab/VQAv2
 dataset_kwargs:
   token: True
-test_split: testdev
+test_split: test
 output_type: generate_until
 doc_to_visual: !function utils.vqav2_doc_to_visual
 doc_to_text: !function utils.vqav2_doc_to_text
diff --git a/lmms_eval/tasks/vqav2_val/utils.py b/lmms_eval/tasks/vqav2_val/utils.py
new file mode 100644
index 000000000..ec28bf004
--- /dev/null
+++ b/lmms_eval/tasks/vqav2_val/utils.py
@@ -0,0 +1,272 @@
+import re
+import os
+import json
+import datetime
+import statistics
+
+
+def vqav2_doc_to_visual(doc):
+    return [doc["image"].convert("RGB")]
+
+
+class EvalAIAnswerProcessor:
+    """
+    Processes an answer similar to Eval AI
+        copied from
+        https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
+    """
+
+    CONTRACTIONS = {
+        "aint": "ain't",
+        "arent": "aren't",
+        "cant": "can't",
+        "couldve": "could've",
+        "couldnt": "couldn't",
+        "couldn'tve": "couldn't've",
+        "couldnt've": "couldn't've",
+        "didnt": "didn't",
+        "doesnt": "doesn't",
+        "dont": "don't",
+        "hadnt": "hadn't",
+        "hadnt've": "hadn't've",
+        "hadn'tve": "hadn't've",
+        "hasnt": "hasn't",
+        "havent": "haven't",
+        "hed": "he'd",
+        "hed've": "he'd've",
+        "he'dve": "he'd've",
+        "hes": "he's",
+        "howd": "how'd",
+        "howll": "how'll",
+        "hows": "how's",
+        "Id've": "I'd've",
+        "I'dve": "I'd've",
+        "Im": "I'm",
+        "Ive": "I've",
+        "isnt": "isn't",
+        "itd": "it'd",
+        "itd've": "it'd've",
+        "it'dve": "it'd've",
+        "itll": "it'll",
+        "let's": "let's",
+        "maam": "ma'am",
+        "mightnt": "mightn't",
+        "mightnt've": "mightn't've",
+        "mightn'tve": "mightn't've",
+        "mightve": "might've",
+        "mustnt": "mustn't",
+        "mustve": "must've",
+        "neednt": "needn't",
+        "notve": "not've",
+        "oclock": "o'clock",
+        "oughtnt": "oughtn't",
+        "ow's'at": "'ow's'at",
+        "'ows'at": "'ow's'at",
+        "'ow'sat": "'ow's'at",
+        "shant": "shan't",
+        "shed've": "she'd've",
+        "she'dve": "she'd've",
+        "she's": "she's",
+        "shouldve": "should've",
+        "shouldnt": "shouldn't",
+        "shouldnt've": "shouldn't've",
+        "shouldn'tve": "shouldn't've",
+        "somebody'd": "somebodyd",
+        "somebodyd've": "somebody'd've",
+        "somebody'dve": "somebody'd've",
+        "somebodyll": "somebody'll",
+        "somebodys": "somebody's",
+        "someoned": "someone'd",
+        "someoned've": "someone'd've",
+        "someone'dve": "someone'd've",
+        "someonell": "someone'll",
+        "someones": "someone's",
+        "somethingd": "something'd",
+        "somethingd've": "something'd've",
+        "something'dve": "something'd've",
+        "somethingll": "something'll",
+        "thats": "that's",
+        "thered": "there'd",
+        "thered've": "there'd've",
+        "there'dve": "there'd've",
+        "therere": "there're",
+        "theres": "there's",
+        "theyd": "they'd",
+        "theyd've": "they'd've",
+        "they'dve": "they'd've",
+        "theyll": "they'll",
+        "theyre": "they're",
+        "theyve": "they've",
+        "twas": "'twas",
+        "wasnt": "wasn't",
+        "wed've": "we'd've",
+        "we'dve": "we'd've",
+        "weve": "we've",
+        "werent": "weren't",
+        "whatll": "what'll",
+        "whatre": "what're",
+        "whats": "what's",
+        "whatve": "what've",
+        "whens": "when's",
+        "whered": "where'd",
+        "wheres": "where's",
+        "whereve": "where've",
+        "whod": "who'd",
+        "whod've": "who'd've",
+        "who'dve": "who'd've",
+        "wholl": "who'll",
+        "whos": "who's",
+        "whove": "who've",
+        "whyll": "why'll",
+        "whyre": "why're",
+        "whys": "why's",
+        "wont": "won't",
+        "wouldve": "would've",
+        "wouldnt": "wouldn't",
+        "wouldnt've": "wouldn't've",
+        "wouldn'tve": "wouldn't've",
+        "yall": "y'all",
+        "yall'll": "y'all'll",
+        "y'allll": "y'all'll",
+        "yall'd've": "y'all'd've",
+        "y'alld've": "y'all'd've",
+        "y'all'dve": "y'all'd've",
+        "youd": "you'd",
+        "youd've": "you'd've",
+        "you'dve": "you'd've",
+        "youll": "you'll",
+        "youre": "you're",
+        "youve": "you've",
+    }
+
+    NUMBER_MAP = {
+        "none": "0",
+        "zero": "0",
+        "one": "1",
+        "two": "2",
+        "three": "3",
+        "four": "4",
+        "five": "5",
+        "six": "6",
+        "seven": "7",
+        "eight": "8",
+        "nine": "9",
+        "ten": "10",
+    }
+    ARTICLES = ["a", "an", "the"]
+    PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
+    COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
+    PUNCTUATIONS = [
+        ";",
+        r"/",
+        "[",
+        "]",
+        '"',
+        "{",
+        "}",
+        "(",
+        ")",
+        "=",
+        "+",
+        "\\",
+        "_",
+        "-",
+        ">",
+        "<",
+        "@",
+        "`",
+        ",",
+        "?",
+        "!",
+    ]
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def word_tokenize(self, word):
+        word = word.lower()
+        word = word.replace(",", "").replace("?", "").replace("'s", " 's")
+        return word.strip()
+
+    def process_punctuation(self, in_text):
+        out_text = in_text
+        for p in self.PUNCTUATIONS:
+            if (p + " " in in_text or " " + p in in_text) or (re.search(self.COMMA_STRIP, in_text) is not None):
+                out_text = out_text.replace(p, "")
+            else:
+                out_text = out_text.replace(p, " ")
+        out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
+        return out_text
+
+    def process_digit_article(self, in_text):
+        out_text = []
+        temp_text = in_text.lower().split()
+        for word in temp_text:
+            word = self.NUMBER_MAP.setdefault(word, word)
+            if word not in self.ARTICLES:
+                out_text.append(word)
+            else:
+                pass
+        for word_id, word in enumerate(out_text):
+            if word in self.CONTRACTIONS:
+                out_text[word_id] = self.CONTRACTIONS[word]
+        out_text = " ".join(out_text)
+        return out_text
+
+    def __call__(self, item):
+        item = self.word_tokenize(item)
+        item = item.replace("\n", " ").replace("\t", " ").strip()
+        item = self.process_punctuation(item)
+        item = self.process_digit_article(item)
+        return item
+
+
+def vqav2_process_results(doc, result):
+    eval_ai_processor = EvalAIAnswerProcessor()
+    assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
+    resAns = eval_ai_processor(result[0])
+    accuracy = 0
+
+    if "answers" in doc and doc["answers"] is not None:
+        for ansDic in doc["answers"]:
+            ansDic["answer"] = ansDic["answer"].replace("\n", " ")
+            ansDic["answer"] = ansDic["answer"].replace("\t", " ")
+            ansDic["answer"] = ansDic["answer"].strip()
+        gtAcc = []
+        gtAnswers = [ans["answer"] for ans in doc["answers"]]
+
+        if len(set(gtAnswers)) > 1:
+            for ansDic in doc["answers"]:
+                ansDic["answer"] = eval_ai_processor.process_punctuation(ansDic["answer"])
+                ansDic["answer"] = eval_ai_processor.process_digit_article(ansDic["answer"])
+            resAns = eval_ai_processor.process_punctuation(resAns)
+            resAns = eval_ai_processor.process_digit_article(resAns)
+
+        for gtAnsDatum in doc["answers"]:
+            otherGTAns = [item for item in doc["answers"] if item != gtAnsDatum]
+            matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
+            acc = min(1, float(len(matchingAns)) / 3)
+            gtAcc.append(acc)
+        accuracy = statistics.mean(gtAcc)
+
+    return {
+        "exact_match": accuracy,
+        "submission": {
+            "question_id": doc["question_id"],
+            "answer": resAns,
+        },
+    }
+
+
+def vqav2_doc_to_text(doc):
+    return f"{doc['question']}\nAnswer the question using a single word or phrase."
+
+
+def vqav2_aggreate_submissions(results):
+    now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    submission_file_name = f"vqav2-val-submission-{now_date_time}.json"
+    path = os.path.abspath(submission_file_name)
+    with open(path, "w") as f:
+        json.dump(results, f)
+    print(f"Submission file saved to {path}")
+    return 0
diff --git a/lmms_eval/tasks/vqav2_val/vqav2_val.yaml b/lmms_eval/tasks/vqav2_val/vqav2_val.yaml
new file mode 100644
index 000000000..6e8ba0c32
--- /dev/null
+++ b/lmms_eval/tasks/vqav2_val/vqav2_val.yaml
@@ -0,0 +1,24 @@
+task: "vqav2_val"
+dataset_path: lmms-lab/VQAv2
+dataset_kwargs:
+  token: True
+test_split: validation
+output_type: generate_until
+doc_to_visual: !function utils.vqav2_doc_to_visual
+doc_to_text: !function utils.vqav2_doc_to_text
+doc_to_target: "answer"
+generation_kwargs:
+  until:
+    - "ASSISTANT:"
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+  - metric: submission
+    aggregation: !function utils.vqav2_aggreate_submissions
+    higher_is_better: true
+metadata:
+  - version: 0.0
+process_results: !function utils.vqav2_process_results
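Reviewer note: the scoring rule in vqav2_process_results above is the standard VQAv2 consensus metric: each ground-truth annotator is left out in turn, the processed prediction earns min(1, matches / 3) against the remaining answers, and the per-annotator scores are averaged. A minimal standalone sketch of that rule follows, with hypothetical sample answers and an index-based leave-one-out (the diff instead drops entries by dict inequality, which is equivalent when annotation dicts are distinct):

# Sketch of the VQA consensus accuracy implemented in vqav2_process_results.
# The answers below are hypothetical sample data; real VQAv2 docs carry 10
# annotations, and both prediction and ground truth are first normalized by
# EvalAIAnswerProcessor before comparison.
def vqa_accuracy(prediction, gt_answers):
    scores = []
    for i in range(len(gt_answers)):
        others = gt_answers[:i] + gt_answers[i + 1:]  # leave one annotator out
        matches = sum(1 for a in others if a == prediction)
        scores.append(min(1.0, matches / 3))  # full credit once 3 annotators agree
    return sum(scores) / len(scores)

print(vqa_accuracy("2", ["2", "2", "2", "2", "3"]))  # -> 1.0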