fix: evaluator challenge use correct answer
lkaesberg committed Nov 27, 2024
1 parent db3fb68 commit 25e38b5
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions mallm/evaluation/evaluator.py
@@ -90,7 +90,11 @@ def _initialize_metrics(metrics: Optional[list[str]]) -> list[Any]:
         return selected_metrics
 
     def calculate_scores(
-        self, answer: str, references: list[str], metric_alteration: str = "", dataset_id: Optional[str] = None
+        self,
+        answer: str,
+        references: list[str],
+        metric_alteration: str = "",
+        dataset_id: Optional[str] = None,
     ) -> dict[str, Any]:
         metrics = []
         if references:
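This hunk only rewraps the parameter list, so the call interface stays the same. A minimal, self-contained sketch of that interface (the stand-in function body, the answer/reference values, and the dataset identifier are hypothetical; only the parameter list comes from the diff):

from typing import Any, Optional

# Hypothetical stand-in mirroring the parameter list shown above, so the
# keyword-argument call style can be illustrated in isolation.
def calculate_scores(
    answer: str,
    references: list[str],
    metric_alteration: str = "",
    dataset_id: Optional[str] = None,
) -> dict[str, Any]:
    # The real method computes metric scores; here we just echo the inputs.
    return {"answer": answer, "references": references, "dataset_id": dataset_id}

scores = calculate_scores(
    answer="42",
    references=["42"],
    dataset_id="gsm8k",  # hypothetical dataset identifier
)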
@@ -113,12 +117,16 @@ def calculate_scores
             }
 
     def add_scores(self) -> None:
-        for item in tqdm(self.data, desc=f"Calculating scores of {self.input_file_path}: "):
+        for item in tqdm(
+            self.data, desc=f"Calculating scores of {self.input_file_path}: "
+        ):
             main_answer = item.get("finalAnswer", "")
             references = item.get("references", [])
             dataset_id = item.get("datasetId", None)
             if main_answer:
-                item["scores"] = self.calculate_scores(main_answer, references, "", dataset_id)
+                item["scores"] = self.calculate_scores(
+                    main_answer, references, "", dataset_id
+                )
 
             votes_each_turn = item.get("votesEachTurn", None)
             if votes_each_turn:
@@ -203,9 +211,12 @@ def analyze_challenged_answers(
                 f"{name}_challenge_lower": False,
                 f"{name}_challenge_same": False,
             }
-            previous_score = previous_score.get("f1", None) or previous_score.get(
-                "correct", None
+            previous_score = (
+                previous_score.get("f1")
+                if previous_score.get("f1", None) is not None
+                else previous_score.get("correct", None)
             )
+
             answer = next(iter(challenged_answers.values()))
             if answer:
                 score = self.calculate_scores(answer, references)
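This hunk is the substance of the commit: the old `or` expression falls back to the "correct" field whenever the "f1" value is falsy, which silently discards a legitimate f1 score of 0.0, while the new conditional only falls back when "f1" is actually missing. A minimal, self-contained sketch of the difference (the score dictionary below is made up for illustration):

# Hypothetical score entry for a challenged answer whose f1 happens to be 0.0.
previous_score = {"f1": 0.0, "correct": True}

# Old behaviour: `or` treats 0.0 as falsy and falls through to "correct".
old_value = previous_score.get("f1", None) or previous_score.get("correct", None)

# New behaviour: only fall back when "f1" is genuinely absent.
new_value = (
    previous_score.get("f1")
    if previous_score.get("f1", None) is not None
    else previous_score.get("correct", None)
)

print(old_value)  # True -- the f1 score of 0.0 is lost
print(new_value)  # 0.0  -- the f1 score is kept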
@@ -300,9 +311,7 @@ def calculate_statistics(self) -> dict[str, Any]:
                     turn = mem.get("turn", 0)
                     if turn not in avg_scores_per_turn:
                         avg_scores_per_turn[turn] = float(0)
-                    turn_score = mem.get("scores", {}).get(
-                        metric, 0
-                    )
+                    turn_score = mem.get("scores", {}).get(metric, 0)
                     if turn_score:
                         avg_scores_per_turn[turn] += turn_score