Merge miracl evaluator #906

Merged
301 changes: 245 additions & 56 deletions in mteb/evaluation/evaluators/RerankingEvaluator.py
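
For context before the diff: this PR routes RerankingEvaluator through a new evaluator_type switch, keeping the existing mAP/MRR path as "standard" and adding a "miracl" path that reranks a per-query candidate list and scores it with retrieval metrics. Below is a minimal usage sketch, not code from this PR: the sample contents, task name, and toy encoder are hypothetical, and the constructor arguments mirror the __init__ signature shown in the diff.

import numpy as np

from mteb.evaluation.evaluators.RerankingEvaluator import RerankingEvaluator


class DummyEncoder:
    """Toy stand-in for a model: any object with an .encode method works."""

    def encode(self, texts, **kwargs):
        rng = np.random.default_rng(0)
        return rng.random((len(texts), 8))  # (num_texts, hidden_size)


# The "miracl" path reads "query", "positive", and "candidates" from each sample.
samples = [
    {
        "query": "What is the capital of France?",
        "positive": ["Paris is the capital of France."],
        "candidates": [
            "Paris is the capital of France.",
            "Lyon is a city in France.",
        ],
    }
]

evaluator = RerankingEvaluator(
    samples,
    task_name="toy-reranking",  # hypothetical task name
    evaluator_type="miracl",    # new in this PR; "standard" keeps the old behavior
    k_values=[1, 2],            # cutoffs forwarded to RetrievalEvaluator.evaluate
)
scores = evaluator.compute_metrics_individual(DummyEncoder())
print(scores)  # metric keys are suffixed with "(MIRACL)"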
@@ -9,6 +9,8 @@
import tqdm
from sklearn.metrics import average_precision_score

from mteb.evaluation.evaluators.RetrievalEvaluator import RetrievalEvaluator

from ...encoder_interface import Encoder, EncoderWithQueryCorpusEncode
from .Evaluator import Evaluator
from .model_encode import model_encode
@@ -38,6 +40,8 @@ def __init__(
batch_size: int = 512,
use_batched_encoding: bool = True,
limit: int | None = None,
k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000],
evaluator_type: str = "standard",
**kwargs,
):
super().__init__(**kwargs)
@@ -50,6 +54,8 @@ def __init__(
self.batch_size = batch_size
self.use_batched_encoding = use_batched_encoding
self.task_name = task_name
self.k_values = k_values
self.evaluator_type = evaluator_type
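        # The two assignments above are new in this PR: `k_values` feeds
        # RetrievalEvaluator.evaluate in the miracl path, and `evaluator_type`
        # selects between the original "standard" mAP/MRR metrics and the
        # "miracl" candidate-list reranking scored with retrieval metrics.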

if isinstance(self.samples, dict):
self.samples = list(self.samples.values())
@@ -76,10 +82,6 @@ def compute_metrics_batched(self, model: Encoder | EncoderWithQueryCorpusEncode)
"""Computes the metrices in a batched way, by batching all queries and
all documents together
"""
        # Use the encode_queries and encode_corpus functions if they exist;
        # users can define them to add different instructions for queries and passages.
encode_queries_func = (
@@ -118,7 +120,84 @@ def compute_metrics_batched(self, model: Encoder | EncoderWithQueryCorpusEncode)
f"Query must be a string or a list of strings but is {type(self.samples[0]['query'])}"
)

if self.evaluator_type == "standard":
results = self._encode_candidates(
encode_queries_func=encode_queries_func,
encode_corpus_func=encode_corpus_func,
batched=True,
all_query_embs=all_query_embs,
)
elif self.evaluator_type == "miracl":
results = self._encode_candidates_miracl(
encode_queries_func=encode_queries_func,
encode_corpus_func=encode_corpus_func,
batched=True,
all_query_embs=all_query_embs,
)
return results

def compute_metrics_individual(self, model):
"""Embeds every (query, positive, negative) tuple individually.
Is slower than the batched version, but saves memory as only the
embeddings for one tuple are needed. Useful when you have
a really large test set
"""
        # Use the encode_queries and encode_corpus functions if they exist;
        # users can define them to add different instructions for queries and passages.
encode_queries_func = (
model.encode_queries if hasattr(model, "encode_queries") else model.encode
)
encode_corpus_func = (
model.encode_corpus if hasattr(model, "encode_corpus") else model.encode
)
if self.evaluator_type == "standard":
results = self._encode_candidates(
encode_queries_func=encode_queries_func,
encode_corpus_func=encode_corpus_func,
batched=False,
)
elif self.evaluator_type == "miracl":
results = self._encode_candidates_miracl(
encode_queries_func=encode_queries_func,
encode_corpus_func=encode_corpus_func,
batched=False,
)
return results

def _encode_candidates(
self, encode_corpus_func, batched, all_query_embs=None, encode_queries_func=None
):
all_mrr_scores = []
all_ap_scores = []
all_conf_scores = []
logger.info("Encoding candidates...")
if batched:
self._encode_candidates_batched(
encode_corpus_func=encode_corpus_func,
all_query_embs=all_query_embs,
all_mrr_scores=all_mrr_scores,
all_ap_scores=all_ap_scores,
all_conf_scores=all_conf_scores,
)
else:
self._encode_candidates_individual(
encode_queries_func=encode_queries_func,
encode_corpus_func=encode_corpus_func,
all_mrr_scores=all_mrr_scores,
all_ap_scores=all_ap_scores,
all_conf_scores=all_conf_scores,
)
scores = self._collect_results(all_mrr_scores, all_ap_scores, all_conf_scores)
return scores

def _encode_candidates_batched(
self,
all_query_embs,
encode_corpus_func,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
):
all_docs = []
for sample in self.samples:
all_docs.extend(sample["positive"])
@@ -148,45 +227,24 @@ def compute_metrics_batched(self, model: Encoder | EncoderWithQueryCorpusEncode)

if num_pos == 0 or num_neg == 0:
continue

is_relevant = [True] * num_pos + [False] * num_neg
self._apply_sim_scores(
query_emb,
docs_emb,
is_relevant,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
)

def _encode_candidates_individual(
self,
encode_queries_func,
encode_corpus_func,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
):
for instance in tqdm.tqdm(self.samples, desc="Samples"):
query = instance["query"]
positive = list(instance["positive"])
@@ -202,24 +260,19 @@ def compute_metrics_individual(self, model):
                # the .encode interface requires a list of strings as input
query = [query]
            query_emb = np.asarray(
                encode_queries_func(query, batch_size=self.batch_size)
            )
            docs_emb = np.asarray(encode_corpus_func(docs, batch_size=self.batch_size))
self._apply_sim_scores(
query_emb,
docs_emb,
is_relevant,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
)

def _collect_results(self, all_mrr_scores, all_ap_scores, all_conf_scores):
mean_ap = np.mean(all_ap_scores)
mean_mrr = np.mean(all_mrr_scores)

@@ -229,6 +282,142 @@ def compute_metrics_individual(self, model):

return {**{"map": mean_ap, "mrr": mean_mrr}, **naucs_map, **naucs_mrr}

def _encode_candidates_miracl(
self,
encode_corpus_func,
encode_queries_func,
batched,
all_query_embs=None,
):
if batched:
return self._encode_candidates_miracl_batched(
all_query_embs=all_query_embs, encode_corpus_func=encode_corpus_func
)
else:
return self._encode_candidates_miracl_individual(
encode_queries_func=encode_queries_func,
encode_corpus_func=encode_corpus_func,
)

def _encode_candidates_miracl_batched(self, all_query_embs, encode_corpus_func):
all_docs = []
for sample in self.samples:
all_docs.extend(sample["candidates"])

all_docs_embs = np.asarray(
encode_corpus_func(
all_docs, task_name=self.task_name, batch_size=self.batch_size
)
)

# Compute scores
logger.info("Evaluating...")
query_idx, docs_idx = 0, 0
results, qrels = {}, {}
for instance in self.samples:
num_subqueries = (
len(instance["query"]) if isinstance(instance["query"], list) else 1
)
query_emb = all_query_embs[query_idx : query_idx + num_subqueries]
query_idx += num_subqueries

positive = instance["positive"]
docs = instance["candidates"]
num_doc = len(docs)
docs_emb = all_docs_embs[docs_idx : docs_idx + num_doc]
docs_idx += num_doc

fake_qid = str(query_idx)
results[fake_qid] = self.rerank(query_emb, docs_emb)
qrels[fake_qid] = {
str(i): 1 if doc in positive else 0 for i, doc in enumerate(docs)
}

scores_miracl = self._collect_miracl_results(results, qrels)
return scores_miracl

def _encode_candidates_miracl_individual(
self, encode_queries_func, encode_corpus_func
):
results, qrels = {}, {}
for i, instance in enumerate(tqdm.tqdm(self.samples, desc="Samples")):
query = instance["query"]
positive = set(instance["positive"])
docs = list(instance["candidates"])

            if isinstance(query, str):
                # the .encode interface requires a list of strings as input
                query = [query]
            query_emb = np.asarray(
                encode_queries_func(query, batch_size=self.batch_size)
            )
            docs_emb = np.asarray(encode_corpus_func(docs, batch_size=self.batch_size))

fake_qid = str(i)
results[fake_qid] = self.rerank(query_emb, docs_emb)
qrels[fake_qid] = {
str(i): 1 if doc in positive else 0 for i, doc in enumerate(docs)
}

scores_miracl = self._collect_miracl_results(results, qrels)
return scores_miracl

def _collect_miracl_results(self, results, qrels):
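        # For reference, `results` and `qrels` arrive in the shapes that
        # RetrievalEvaluator.evaluate expects (values illustrative; the ids
        # are the stringified positions built by the miracl methods above):
        #   results = {"1": {"0": 0.92, "1": 0.31}}  # similarity per candidate
        #   qrels   = {"1": {"0": 1, "1": 0}}        # 1 iff candidate is a positive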
ndcg, _map, recall, precision, naucs = RetrievalEvaluator.evaluate(
qrels=qrels,
results=results,
k_values=self.k_values,
ignore_identical_ids=False,
)
scores = {**ndcg, **_map, **recall, **precision, **naucs}
scores_miracl = {f"{k}(MIRACL)": v for k, v in scores.items()}
return scores_miracl

def rerank(
self, query_emb: torch.Tensor, docs_emb: torch.Tensor
) -> dict[str, float]:
"""Rerank documents (docs_emb) given the query (query_emb)

Args:
query_emb: Query embedding of shape `(num_queries, hidden_size)`)
if `num_queries` > 0: we take the closest document to any of the queries
docs_emb: Candidates documents embeddings of shape `(num_pos+num_neg, hidden_size)`)

Returns:
similarity_scores:
"""
if not query_emb.shape[0]:
raise ValueError("Empty query embedding")

if not docs_emb.shape[0]:
return {"empty-docid": 0}

pred_scores = self.similarity_fct(query_emb, docs_emb)
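        # With multiple (sub)queries, similarity_fct returns a
        # (num_queries, num_docs) matrix; taking the max over queries scores
        # each document by its closest query.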
if len(pred_scores.shape) > 1:
pred_scores = torch.amax(pred_scores, dim=0)

return {
str(i): score.detach().numpy().item() for i, score in enumerate(pred_scores)
}

def _apply_sim_scores(
self,
query_emb,
docs_emb,
is_relevant,
all_mrr_scores,
all_ap_scores,
all_conf_scores,
):
sim_scores = self._compute_sim_scores_instance(query_emb, docs_emb)
scores = self._compute_metrics_instance(sim_scores, is_relevant)
conf_scores = self.conf_scores(sim_scores.tolist())

all_mrr_scores.append(scores["mrr"])
all_ap_scores.append(scores["ap"])
all_conf_scores.append(conf_scores)

@staticmethod
def _encode_unique_texts(
all_texts: list[str],
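
For comparison, a sketch of the pre-existing "standard" path using the same toy encoder (DummyEncoder from the sketch at the top; the sample contents are again hypothetical, with "negative" in place of "candidates"):

samples_standard = [
    {
        "query": "What is the capital of France?",
        "positive": ["Paris is the capital of France."],
        "negative": ["Lyon is a city in France."],
    }
]

evaluator = RerankingEvaluator(
    samples_standard,
    task_name="toy-reranking",  # hypothetical
    evaluator_type="standard",  # the default path
)
scores = evaluator.compute_metrics_individual(DummyEncoder())
print(scores)  # {"map": ..., "mrr": ...} merged with nAUC entries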