From 626e8a91a4af2dd5dd774fc130cc2f4d74b2bc37 Mon Sep 17 00:00:00 2001 From: Hunter Heidenreich Date: Thu, 2 May 2024 09:31:03 -0400 Subject: [PATCH] Bugfix: WebSRC should be token-level F1 NOT character-level --- lmms_eval/tasks/websrc/utils.py | 42 +++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/lmms_eval/tasks/websrc/utils.py b/lmms_eval/tasks/websrc/utils.py index bd5c9bc2a..e76037c2b 100644 --- a/lmms_eval/tasks/websrc/utils.py +++ b/lmms_eval/tasks/websrc/utils.py @@ -50,7 +50,7 @@ def websrc_process_results(doc, results): "websrc_squad_f1": websrc_ans, "submission": { websrc_ans['question_id']: pred, - }, + } if 'question_id' in websrc_ans else None } @@ -122,27 +122,39 @@ def _normalize_str(string): # lower it string = string.lower() - # strip non-alphanumeric characters - string = re.sub(r"[^a-zA-Z0-9]", "", string) - # strip leading and trailing whitespaces string = string.strip() return string + def _tokenize(text): + # Regex pattern to match words and isolate punctuation + pattern = r'\w+|[^\w\s]' + tokens = re.findall(pattern, text) + return tokens + + def _compute_f1(sa, sb): + sa = _normalize_str(sa) + sb = _normalize_str(sb) + + sa = _tokenize(sa) + sb = _tokenize(sb) + + sa = set(sa) + sb = set(sb) + + if len(sa) == 0 or len(sb) == 0: + return 0.0 + + comm = sa.intersection(sb) + prec = len(comm) / len(sb) + rec = len(comm) / len(sa) + f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0 + return f1 + judge_list = [] for sample in samples: - gold_i = set(_normalize_str(sample["answer"])) - pred_i = set(_normalize_str( sample["parsed_pred"])) - if len(pred_i) == 0: - judge_list.append(0.0) - continue - - comm_i = gold_i.intersection(pred_i) - prec_i = len(comm_i) / len(pred_i) - rec_i = len(comm_i) / len(gold_i) - f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0 - judge_list.append(f1_i) + judge_list.append(_compute_f1(sample["answer"], sample["parsed_pred"])) f1 = np.mean(judge_list) return judge_list, {"f1": f1}