Commit

add linting

jordiclive committed Jun 12, 2024
1 parent 0df3809 commit 0f35c8d
Showing 2 changed files with 82 additions and 31 deletions.
104 changes: 81 additions & 23 deletions mteb/evaluation/evaluators/RerankingEvaluator.py
@@ -38,7 +38,7 @@ def __init__(
use_batched_encoding: bool = True,
limit: int | None = None,
k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000],
evaluator_type: str = "standard"
evaluator_type: str = "standard",
**kwargs,
):
super().__init__(**kwargs)
@@ -116,21 +116,19 @@ def compute_metrics_batched(self, model: Encoder | EncoderWithQueryCorpusEncode)
raise ValueError(
f"Query must be a string or a list of strings but is {type(self.samples[0]['query'])}"
)

if self.evaluator_type == "standard":
results = self.encode_candidates(all_query_embs,encode_corpus_func,True)
results = self.encode_candidates(all_query_embs, encode_corpus_func, True)
elif self.evaluator_type == "miracl":
results = self.encode_candidates_miracl(all_query_embs, encode_corpus_func)
return results

def compute_metrics_individual(self, model):
"""Embeds every (query, positive, negative) tuple individually.
Is slower than the batched version, but saves memory as only the
embeddings for one tuple are needed. Useful when you have
a really large test set
"""


# use the encode_queries and encode_corpus functions if they exist,
# which users can define to conveniently add different instructions for queries and passages
encode_queries_func = (
@@ -140,20 +138,41 @@ def compute_metrics_individual(self)
model.encode_corpus if hasattr(model, "encode_corpus") else model.encode
)
if self.evaluator_type == "standard":
results = self.encode_candidates(encode_queries_func, encode_corpus_func,False,encode_corpus_func=encode_corpus_func)
# all_query_embs is unused on the individual path; pass the query
# encoder by keyword to avoid a duplicate encode_corpus_func argument
results = self.encode_candidates(
None,
encode_corpus_func,
False,
encode_queries_func=encode_queries_func,
)
elif self.evaluator_type == "miracl":
results = self.encode_candidates_miracl_individual(encode_queries_func,encode_corpus_func)
results = self.encode_candidates_miracl_individual(
encode_queries_func, encode_corpus_func
)
return results

def encode_candidates(self, all_query_embs,encode_corpus_func,batched,encode_queries_func=None):
def encode_candidates(
self, all_query_embs, encode_corpus_func, batched, encode_queries_func=None
):
all_mrr_scores = []
all_ap_scores = []
all_conf_scores = []
logger.info("Encoding candidates...")
if batched:
self.encode_candidates_batched(all_query_embs, encode_corpus_func,all_mrr_scores, all_ap_scores, all_conf_scores)
self.encode_candidates_batched(
all_query_embs,
encode_corpus_func,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
)
else:
self.encode_candidates_individual(encode_queries_func, encode_corpus_func,all_mrr_scores, all_ap_scores, all_conf_scores)
self.encode_candidates_individual(
encode_queries_func,
encode_corpus_func,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
)
mean_ap = np.mean(all_ap_scores)
mean_mrr = np.mean(all_mrr_scores)

@@ -163,14 +182,23 @@ def encode_candidates(self, all_query_embs,encode_corpus_func,batched,encode_que

return {**{"map": mean_ap, "mrr": mean_mrr}, **naucs_map, **naucs_mrr}


def encode_candidates_batched(self, all_query_embs, encode_corpus_func,all_mrr_scores, all_ap_scores, all_conf_scores):
def encode_candidates_batched(
self,
all_query_embs,
encode_corpus_func,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
):
all_docs = []
for sample in self.samples:
all_docs.extend(sample["positive"])
all_docs.extend(sample["negative"])

all_docs_embs = self._encode_unique_texts(all_docs, encode_corpus_func,)
all_docs_embs = self._encode_unique_texts(
all_docs,
encode_corpus_func,
)

# Compute scores and confidence scores
logger.info("Evaluating...")
@@ -190,10 +218,23 @@ def encode_candidates_batched(self, all_query_embs, encode_corpus_func,all_mrr_s
if num_pos == 0 or num_neg == 0:
continue
is_relevant = [True] * num_pos + [False] * num_neg
self.apply_sim_scores(query_emb, docs_emb, is_relevant, all_mrr_scores, all_ap_scores, all_conf_scores)


def encode_candidates_individual(self, encode_queries_func, encode_corpus_func,all_mrr_scores, all_ap_scores, all_conf_scores):
self.apply_sim_scores(
query_emb,
docs_emb,
is_relevant,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
)

def encode_candidates_individual(
self,
encode_queries_func,
encode_corpus_func,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
):
for instance in tqdm.tqdm(self.samples, desc="Samples"):
query = instance["query"]
positive = list(instance["positive"])
@@ -212,17 +253,32 @@ def encode_candidates_individual(self, encode_queries_func, encode_corpus_func,a
encode_queries_func(query, batch_size=self.batch_size)
)
docs_emb = np.asarray(encode_corpus_func(docs, batch_size=self.batch_size))
self.apply_sim_scores(query_emb, docs_emb, is_relevant, all_mrr_scores, all_ap_scores, all_conf_scores)
self.apply_sim_scores(
query_emb,
docs_emb,
is_relevant,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
)

def apply_sim_scores(self,query_emb, docs_emb, is_relevant, all_mrr_scores, all_ap_scores, all_conf_scores):
def apply_sim_scores(
self,
query_emb,
docs_emb,
is_relevant,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
):
sim_scores = self._compute_sim_scores_instance(query_emb, docs_emb)
scores = self._compute_metrics_instance(sim_scores, is_relevant)
conf_scores = self.conf_scores(sim_scores.tolist())

all_mrr_scores.append(scores["mrr"])
all_ap_scores.append(scores["ap"])
all_conf_scores.append(conf_scores)

def encode_candidates_miracl(self, all_query_embs, encode_corpus_func):
all_docs = []
for sample in self.samples:
@@ -258,7 +314,9 @@ def encode_candidates_miracl(self, all_query_embs, encode_corpus_func)
scores_miracl = self.collect_miracl_results(results, qrels)
return scores_miracl

def encode_candidates_miracl_individual(self, encode_queries_func, encode_corpus_func):
def encode_candidates_miracl_individual(
self, encode_queries_func, encode_corpus_func
):
results, qrels = {}, {}
for i, instance in enumerate(tqdm.tqdm(self.samples, desc="Samples")):
query = instance["query"]
@@ -282,7 +340,7 @@ def encode_candidates_miracl_individual(self, encode_queries_func, encode_corpus

scores_miracl = self.collect_miracl_results(results, qrels)
return scores_miracl

def collect_miracl_results(self, results, qrels):
ndcg, _map, recall, precision, naucs = RetrievalEvaluator.evaluate(
qrels=qrels,
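For orientation, a minimal usage sketch of the refactored evaluator, assuming the default constructor arguments suffice. The sample layout mirrors the query/positive/negative fields the evaluator iterates over above; the SentenceTransformer model and its name are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch (assumptions, not part of this commit):
# any model exposing encode / encode_queries / encode_corpus should work.
from mteb.evaluation.evaluators import RerankingEvaluator
from sentence_transformers import SentenceTransformer

samples = [
    {
        "query": "what does mteb benchmark?",
        "positive": ["MTEB benchmarks text embedding models."],
        "negative": ["The weather is nice today."],
    }
]

model = SentenceTransformer("all-MiniLM-L6-v2")
evaluator = RerankingEvaluator(samples, evaluator_type="standard")
scores = evaluator(model)  # {"map": ..., "mrr": ..., plus nAUC entries}

# Passing use_batched_encoding=False would exercise the
# compute_metrics_individual path instead, trading speed for memory
# as described in its docstring above.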
9 changes: 1 addition & 8 deletions mteb/tasks/Reranking/multilingual/MIRACLReranking.py
@@ -3,16 +3,11 @@
import logging
from typing import Any

import numpy as np
import torch
import tqdm
from datasets import Dataset

from mteb.abstasks.TaskMetadata import TaskMetadata
from mteb.encoder_interface import Encoder, EncoderWithQueryCorpusEncode
from mteb.evaluation.evaluators import RerankingEvaluator
from mteb.evaluation.evaluators.RetrievalEvaluator import RetrievalEvaluator
from mteb.evaluation.evaluators.utils import cos_sim
from mteb.MTEBResults import ScoresDict

from ....abstasks import MultilingualTask
@@ -89,10 +84,8 @@ def _evaluate_subset(
data_split: Dataset,
**kwargs: Any,
) -> ScoresDict:
evaluator = RerankingEvaluator(data_split, evaluator_type='miracl',**kwargs)
evaluator = RerankingEvaluator(data_split, evaluator_type="miracl", **kwargs)
scores = evaluator(model)

self._add_main_score(scores)
return scores
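
As a companion, a hedged end-to-end sketch of this task. The mteb.get_tasks / MTEB.run entry points are assumptions about the surrounding library at this commit, not shown in this diff; only the evaluator_type="miracl" wiring comes from the change above.

# Hypothetical end-to-end sketch (assumed framework API, not in this diff):
# when the task runs, _evaluate_subset above builds a RerankingEvaluator
# with evaluator_type="miracl", which scores via RetrievalEvaluator.evaluate.
import mteb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
tasks = mteb.get_tasks(tasks=["MIRACLReranking"])
results = mteb.MTEB(tasks=tasks).run(model)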

