removing MIRACLRerankingEvaluator
jordiclive committed Jun 12, 2024
1 parent 615dbbb commit 0df3809
Showing 1 changed file with 1 addition and 187 deletions.
188 changes: 1 addition & 187 deletions mteb/tasks/Reranking/multilingual/MIRACLReranking.py
@@ -89,196 +89,10 @@ def _evaluate_subset(
data_split: Dataset,
**kwargs: Any,
) -> ScoresDict:
evaluator = MIRACLRerankingEvaluator(data_split, **kwargs)
evaluator = RerankingEvaluator(data_split, evaluator_type='miracl',**kwargs)
scores = evaluator(model)

self._add_main_score(scores)
return scores


class MIRACLRerankingEvaluator(RerankingEvaluator):
"""This class evaluates a SentenceTransformer model for the task of re-ranking.
MIRACLRerankingEvaluator differs from RerankingEvaluator in two ways:
1. it uses the pytrec_eval via RetrievalEvaluator instead of the metrics provided by sklearn;
2. it reranks the top-k `candidates` from previous-stage retrieval which may not include all ground-truth `positive` documents
"""

def __init__(
self,
samples: list[dict],
mrr_at_k: int = 10,
name: str = "",
similarity_fct=cos_sim,
batch_size: int = 512,
use_batched_encoding: bool = True,
limit: int | None = None,
k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000],
**kwargs,
):
"""Args:
k_values: ranking cutoff thresholds where applicable
"""
super().__init__(
samples,
mrr_at_k,
name,
similarity_fct,
batch_size,
use_batched_encoding,
limit,
**kwargs,
)
self.k_values = k_values

def rerank(
self, query_emb: torch.Tensor, docs_emb: torch.Tensor
) -> dict[str, float]:
"""Rerank documents (docs_emb) given the query (query_emb)
Args:
query_emb: Query embedding of shape `(num_queries, hidden_size)`)
if `num_queries` > 0: we take the closest document to any of the queries
docs_emb: Candidates documents embeddings of shape `(num_pos+num_neg, hidden_size)`)
Returns:
similarity_scores:
"""
if not query_emb.shape[0]:
raise ValueError("Empty query embedding")

if not docs_emb.shape[0]:
return {"empty-docid": 0}

pred_scores = self.similarity_fct(query_emb, docs_emb)
if len(pred_scores.shape) > 1:
pred_scores = torch.amax(pred_scores, dim=0)

return {
str(i): score.detach().numpy().item() for i, score in enumerate(pred_scores)
}

def compute_metrics_batched(self, model: Encoder | EncoderWithQueryCorpusEncode):
"""Computes the metrices in a batched way, by batching all queries and
all documents together
"""
# using encode_queries and encode_corpus functions if they exists,
# which can be defined by users to add different instructions for query and passage conveniently
encode_queries_func = (
model.encode_queries
if isinstance(model, EncoderWithQueryCorpusEncode)
else model.encode
)
encode_corpus_func = (
model.encode_corpus
if isinstance(model, EncoderWithQueryCorpusEncode)
else model.encode
)

logger.info("Encoding queries...")
if isinstance(self.samples[0]["query"], str):
all_query_embs = np.asarray(
encode_queries_func(
[sample["query"] for sample in self.samples],
batch_size=self.batch_size,
)
)
elif isinstance(self.samples[0]["query"], list):
# In case the query is a list of strings, we get the most similar embedding to any of the queries
all_query_flattened = [
q for sample in self.samples for q in sample["query"]
]
all_query_embs = np.asarray(
encode_queries_func(all_query_flattened, batch_size=self.batch_size)
)
else:
raise ValueError(
f"Query must be a string or a list of strings but is {type(self.samples[0]['query'])}"
)

logger.info("Encoding candidates...")
all_docs = []
for sample in self.samples:
all_docs.extend(sample["candidates"])

all_docs_embs = np.asarray(
encode_corpus_func(all_docs, batch_size=self.batch_size)
)

# Compute scores
logger.info("Evaluating...")
query_idx, docs_idx = 0, 0
results, qrels = {}, {}
for instance in self.samples:
num_subqueries = (
len(instance["query"]) if isinstance(instance["query"], list) else 1
)
query_emb = all_query_embs[query_idx : query_idx + num_subqueries]
query_idx += num_subqueries

positive = instance["positive"]
docs = instance["candidates"]
num_doc = len(docs)
docs_emb = all_docs_embs[docs_idx : docs_idx + num_doc]
docs_idx += num_doc

fake_qid = str(query_idx)
results[fake_qid] = self.rerank(query_emb, docs_emb)
qrels[fake_qid] = {
str(i): 1 if doc in positive else 0 for i, doc in enumerate(docs)
}

ndcg, _map, recall, precision, naucs = RetrievalEvaluator.evaluate(
qrels=qrels,
results=results,
k_values=self.k_values,
ignore_identical_ids=False,
)
scores = {**ndcg, **_map, **recall, **precision, **naucs}
scores_miracl = {f"{k}(MIRACL)": v for k, v in scores.items()}
return scores_miracl

def compute_metrics_individual(self, model):
"""Embeds every (query, positive, negative) tuple individually.
Is slower than the batched version, but saves memory as only the
embeddings for one tuple are needed. Useful when you have
a really large test set
"""
# using encode_queries and encode_corpus functions if they exists,
# which can be defined by users to add different instructions for query and passage conveniently
encode_queries_func = (
model.encode_queries if hasattr(model, "encode_queries") else model.encode
)
encode_corpus_func = (
model.encode_corpus if hasattr(model, "encode_corpus") else model.encode
)

results, qrels = {}, {}
for i, instance in enumerate(tqdm.tqdm(self.samples, desc="Samples")):
query = instance["query"]
positive = set(instance["positive"])
docs = list(instance["candidates"])

if isinstance(query, str):
# the .encode interface requires List[str] as input
query_emb = np.asarray(
encode_queries_func([query], batch_size=self.batch_size)
)
docs_emb = np.asarray(
encode_corpus_func(docs, batch_size=self.batch_size)
)

fake_qid = str(i)
results[fake_qid] = self.rerank(query_emb, docs_emb)
qrels[fake_qid] = {
str(i): 1 if doc in positive else 0 for i, doc in enumerate(docs)
}

ndcg, _map, recall, precision, naucs = RetrievalEvaluator.evaluate(
qrels=qrels,
results=results,
k_values=self.k_values,
ignore_identical_ids=False,
)
scores = {**ndcg, **_map, **recall, **precision, **naucs}
scores_miracl = {f"{k}(MIRACL)": v for k, v in scores.items()}
return scores_miracl
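
After this commit, the MIRACL-specific logic is handled by the shared RerankingEvaluator, selected via the evaluator_type argument on the added line above. A minimal usage sketch under that assumption; the import path, function name, and arguments below are illustrative and not part of the commit:

# Minimal sketch (not from the commit): the consolidated call replacing MIRACLRerankingEvaluator.
# Assumes RerankingEvaluator is importable from mteb.evaluation.evaluators and accepts
# evaluator_type='miracl', as on the added line above.
from mteb.evaluation.evaluators import RerankingEvaluator

def evaluate_miracl_split(model, data_split, **kwargs):
    # data_split: one MIRACL language split with "query", "positive" and "candidates" fields
    evaluator = RerankingEvaluator(data_split, evaluator_type='miracl', **kwargs)
    return evaluator(model)  # ScoresDict with pytrec_eval-style nDCG/MAP/recall/precision metrics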
