Merge branch 'embeddings-benchmark:main' into fix-custom-remote-code
henilp105 authored Jun 16, 2024
2 parents fc8525c + 3fbce3d commit a292e79
Showing 18 changed files with 1,005 additions and 91 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -139,4 +139,5 @@ tests/results
tmp.py

# sandbox
sb.ipynb
sb.ipynb
tests/create_meta/model_card.md
14 changes: 10 additions & 4 deletions mteb/MTEBResults.py
@@ -209,13 +209,19 @@ def _convert_from_before_v1_11_0(cls, data: dict) -> MTEBResults:
main_score = task.metadata.main_score
for split, split_score in scores.items():
for hf_subset, hf_subset_scores in split_score.items():
if task.metadata.type == "STS":
for name, prev_name in [
("cosine", "cos_sim"),
("manhattan", "manhattan"),
("euclidean", "euclidean"),
]:
prev_name_scores = hf_subset_scores.pop(prev_name)
for k, v in prev_name_scores.items():
hf_subset_scores[f"{name}_{k}"] = v

if "main_score" not in hf_subset_scores:
if main_score in hf_subset_scores:
hf_subset_scores["main_score"] = hf_subset_scores[main_score]
elif main_score == "cosine_spearman":
hf_subset_scores["main_score"] = hf_subset_scores["cos_sim"][
"spearman"
]
else:
logger.warning(f"Main score {main_score} not found in scores")
hf_subset_scores["main_score"] = None
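For context, here is a hedged sketch (not part of the commit) of what the conversion loop above does to a pre-v1.11.0 STS score dict; the numbers are hypothetical.

```python
# Illustration only: the loop above renames the nested pre-v1.11.0 STS score
# blocks ("cos_sim", "manhattan", "euclidean") into flat per-metric keys.
old_scores = {
    "cos_sim": {"pearson": 0.81, "spearman": 0.79},    # hypothetical values
    "manhattan": {"pearson": 0.80, "spearman": 0.78},
    "euclidean": {"pearson": 0.80, "spearman": 0.78},
}

converted = dict(old_scores)
for name, prev_name in [("cosine", "cos_sim"), ("manhattan", "manhattan"), ("euclidean", "euclidean")]:
    prev_name_scores = converted.pop(prev_name)
    for k, v in prev_name_scores.items():
        converted[f"{name}_{k}"] = v

# converted == {
#     "cosine_pearson": 0.81, "cosine_spearman": 0.79,
#     "manhattan_pearson": 0.80, "manhattan_spearman": 0.78,
#     "euclidean_pearson": 0.80, "euclidean_spearman": 0.78,
# }
```

With the flat keys in place, a main score such as "cosine_spearman" can then be looked up directly.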
5 changes: 1 addition & 4 deletions mteb/abstasks/AbsTaskSTS.py
@@ -47,7 +47,4 @@ def normalize(x):
return scores

def _add_main_score(self, scores: ScoresDict) -> None:
m_score = self.metadata.main_score
dist, metric = m_score.split("_")
dist_mapping = {"cosine": "cos_sim"}
scores["main_score"] = scores[dist_mapping.get(dist, dist)][metric]
scores["main_score"] = scores[self.metadata.main_score]
5 changes: 1 addition & 4 deletions mteb/abstasks/AbsTaskSummarization.py
@@ -54,7 +54,4 @@ def _evaluate_subset(self, model, data_split, **kwargs) -> ScoresDict:
return scores

def _add_main_score(self, scores: ScoresDict) -> None:
m_score = self.metadata.main_score
dist, metric = m_score.split("_")
dist_mapping = {"cosine": "cos_sim"}
scores["main_score"] = scores[dist_mapping.get(dist, dist)][metric]
scores["main_score"] = scores[self.metadata.main_score]
25 changes: 17 additions & 8 deletions mteb/cli.py
@@ -245,28 +245,37 @@ def create_meta(args: argparse.Namespace) -> None:
]

task_results = [MTEBResults.from_disk(path) for path in json_files]
task_results = sorted(task_results, key=lambda x: x.task_name)

yaml_results = []
for task_result in task_results:
task = mteb.get_task(task_result.task_name)

for split, hf_subset_scores in task_result.scores.items():
for hf_subset_score in hf_subset_scores:
metrics = [
{
"type": k,
"value": v,
}
for k, v in hf_subset_score.items()
if isinstance(v, (int, float))
]
if task.metadata.main_score not in hf_subset_score:
raise ValueError(
f"Main score {task.metadata.main_score} not found in metrics or is not a number."
)

yaml_result = {
"task": {"type": task.metadata.type},
"dataset": {
"type": task.metadata.dataset["path"],
"name": f"MTEB {task.metadata.name}",
"config": hf_subset_score["hf_subset"],
"name": f"MTEB {task.metadata.name} ({hf_subset_score['hf_subset']})",
"config": hf_subset_score["hf_subset"],
"split": split,
"revision": task_result.dataset_revision,
},
"metrics": [
{
"type": task.metadata.main_score,
"value": hf_subset_score["main_score"],
}
],
"metrics": metrics,
}
yaml_results.append(yaml_result)

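To make the create_meta change concrete, here is a hedged sketch (not part of the commit) of a single entry appended to yaml_results after this change; the dataset path, subset, and values are hypothetical placeholders.

```python
# Hypothetical example of one yaml_result entry: every numeric score from the
# task result is now emitted as a metric, not just the main score.
example_yaml_result = {
    "task": {"type": "STS"},
    "dataset": {
        "type": "mteb/stsbenchmark-sts",             # hypothetical dataset path
        "name": "MTEB STSBenchmark (default)",       # task name + hf_subset
        "config": "default",
        "split": "test",
        "revision": "<dataset_revision>",            # placeholder
    },
    "metrics": [
        {"type": "cosine_spearman", "value": 0.87},  # hypothetical values
        {"type": "cosine_pearson", "value": 0.86},
        {"type": "main_score", "value": 0.87},
    ],
}
```

The added check raises a ValueError early when the task's main score is missing from the scores, rather than silently producing an incomplete model card.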
53 changes: 52 additions & 1 deletion mteb/encoder_interface.py
@@ -10,7 +10,11 @@

@runtime_checkable
class Encoder(Protocol):
"""The interface for an encoder in MTEB. In general we try to keep this interface aligned with sentence-transformers."""
"""The interface for an encoder in MTEB.
Besides the required functions specified below, the encoder can additionally implement the optional signatures defined in the protocols below.
In general the interface is kept aligned with the sentence-transformers interface; where exceptions occur, they are handled within MTEB.
"""

def encode(
self, sentences: Sequence[str], *, prompt_name: str | None = None, **kwargs: Any
@@ -29,6 +33,53 @@ def encode(
...


class EncoderWithSimilarity(Encoder, Protocol):
"""Besides the required functions in the Encoder interface, the encoder can additionally specify its own similiarity functions.
MTEB will by default attempt to use similarity_pairwise function first before falling back to similarity function. If the encoder does not support
similarity_pairwise function, it should simply not implement it.
"""

def similarity(
self,
embeddings1: torch.Tensor | np.ndarray,
embeddings2: torch.Tensor | np.ndarray,
) -> torch.Tensor:
"""Compute the similarity between two collections of embeddings. The output will be a matrix with the similarity scores between all embeddings
from the first parameter and all embeddings from the second parameter. This differs from similarity_pairwise which computes the similarity
between each pair of embeddings.
Read more at: https://www.sbert.net/docs/package_reference/sentence_transformer/SentenceTransformer.html#sentence_transformers.SentenceTransformer.similarity
Args:
embeddings1: [num_embeddings_1, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
embeddings2: [num_embeddings_2, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
Returns:
A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores.
"""
...

def similarity_pairwise(
self,
embeddings1: torch.Tensor | np.ndarray,
embeddings2: torch.Tensor | np.ndarray,
) -> torch.Tensor:
"""Compute the similarity between two collections of embeddings. The output will be a vector with the similarity scores between each pair of
embeddings.
Read more at: https://www.sbert.net/docs/package_reference/sentence_transformer/SentenceTransformer.html#sentence_transformers.SentenceTransformer.similarity_pairwise
Args:
embeddings1: [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
embeddings2: [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
Returns:
A [num_embeddings]-shaped torch tensor with pairwise similarity scores.
"""
...


@runtime_checkable
class EncoderWithQueryCorpusEncode(Encoder, Protocol):
"""The optional interface for an encoder that supports encoding queries and a corpus."""
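To illustrate the optional protocol above, here is a minimal toy encoder (not part of the commit) that satisfies EncoderWithSimilarity using seeded pseudo-embeddings and cosine similarity; all names and choices below are illustrative only.

```python
# Toy sketch of an encoder implementing the EncoderWithSimilarity protocol.
from __future__ import annotations

from typing import Any, Sequence

import numpy as np
import torch


class ToyEncoder:
    """Deterministic pseudo-embeddings plus cosine similarity, for illustration."""

    def __init__(self, dim: int = 32) -> None:
        self.dim = dim

    def encode(
        self, sentences: Sequence[str], *, prompt_name: str | None = None, **kwargs: Any
    ) -> np.ndarray:
        # One seeded random vector per sentence (a stand-in for a real model).
        return np.stack(
            [
                np.random.default_rng(abs(hash(s)) % 2**32).normal(size=self.dim)
                for s in sentences
            ]
        )

    def similarity(
        self,
        embeddings1: torch.Tensor | np.ndarray,
        embeddings2: torch.Tensor | np.ndarray,
    ) -> torch.Tensor:
        # Full [num_embeddings_1, num_embeddings_2] cosine-similarity matrix.
        e1 = torch.nn.functional.normalize(
            torch.as_tensor(np.atleast_2d(np.asarray(embeddings1)), dtype=torch.float32), dim=-1
        )
        e2 = torch.nn.functional.normalize(
            torch.as_tensor(np.atleast_2d(np.asarray(embeddings2)), dtype=torch.float32), dim=-1
        )
        return e1 @ e2.T

    def similarity_pairwise(
        self,
        embeddings1: torch.Tensor | np.ndarray,
        embeddings2: torch.Tensor | np.ndarray,
    ) -> torch.Tensor:
        # [num_embeddings]-shaped cosine similarity of corresponding rows.
        e1 = torch.nn.functional.normalize(
            torch.as_tensor(np.asarray(embeddings1), dtype=torch.float32), dim=-1
        )
        e2 = torch.nn.functional.normalize(
            torch.as_tensor(np.asarray(embeddings2), dtype=torch.float32), dim=-1
        )
        return (e1 * e2).sum(dim=-1)
```

As the evaluator changes below show, MTEB detects these optional methods with hasattr and prefers similarity_pairwise over similarity, so an encoder like this would have its own similarity scores reported alongside the generic cosine, Manhattan, and Euclidean ones.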
46 changes: 30 additions & 16 deletions mteb/evaluation/evaluators/PairClassificationEvaluator.py
@@ -10,7 +10,7 @@
paired_manhattan_distances,
)

from mteb.encoder_interface import Encoder
from mteb.encoder_interface import Encoder, EncoderWithSimilarity
from mteb.evaluation.evaluators.model_encode import model_encode

from .Evaluator import Evaluator
@@ -61,15 +61,15 @@ def __init__(
for label in labels:
assert label == 0 or label == 1

def __call__(self, model: Encoder):
def __call__(self, model: Encoder | EncoderWithSimilarity):
scores = self.compute_metrics(model)

# Main score is the max of Average Precision (AP)
main_score = max(scores[short_name]["ap"] for short_name in scores)
scores["main_score"] = main_score
return scores

def compute_metrics(self, model: Encoder):
def compute_metrics(self, model: Encoder | EncoderWithSimilarity):
sentences = list(set(self.sentences1 + self.sentences2))

total_sents = len(self.sentences1) + len(self.sentences2)
@@ -90,6 +90,17 @@ def compute_metrics(self, model: Encoder):
manhattan_distances = paired_manhattan_distances(embeddings1, embeddings2)
euclidean_distances = paired_euclidean_distances(embeddings1, embeddings2)

if hasattr(model, "similarity_pairwise"):
similarity_scores = model.similarity_pairwise(embeddings1, embeddings2) # type: ignore
elif hasattr(model, "similarity"):
_similarity_scores = [
float(model.similarity(e1, e2)) # type: ignore
for e1, e2 in zip(embeddings1, embeddings2)
]
similarity_scores = np.array(_similarity_scores)
else:
similarity_scores = cosine_scores # Default to cosine similarity

embeddings1_np = np.asarray(embeddings1)
embeddings2_np = np.asarray(embeddings2)
dot_scores = [
@@ -101,7 +112,8 @@
labels = np.asarray(self.labels)
output_scores = {}
for short_name, name, scores, reverse in [
["cos_sim", "Cosine-Similarity", cosine_scores, True],
["similarity", "Model-Specified Similarity", similarity_scores, True],
["cosine", "Cosine-Similarity", cosine_scores, True],
["manhattan", "Manhattan-Distance", manhattan_distances, False],
["euclidean", "Euclidean-Distance", euclidean_distances, False],
["dot", "Dot-Product", dot_scores, True],
@@ -111,16 +123,18 @@
return output_scores

@staticmethod
def _compute_metrics(scores, labels, high_score_more_similar):
def _compute_metrics(
scores: np.ndarray, labels: np.ndarray, high_score_more_similar: bool
) -> dict[str, float]:
"""Compute the metrics for the given scores and labels.
Args:
scores (`np.ndarray` of shape (n_pairs, )): The similarity/dissimilarity scores for the pairs.
labels (`np.ndarray` of shape (n_pairs, )): The labels for the pairs.
high_score_more_similar (`bool`): If true, then the higher the score, the more similar the pairs are.
scores: The similarity/dissimilarity scores for the pairs, specified as an array of shape (n_pairs, ).
labels: The labels for the pairs, specified as an array of shape (n_pairs, ).
high_score_more_similar: If true, then the higher the score, the more similar the pairs are.
Returns:
`dict`: The metrics for the given scores and labels.
The metrics for the given scores and labels.
"""
acc, acc_threshold = PairClassificationEvaluator.find_best_acc_and_threshold(
scores, labels, high_score_more_similar
@@ -135,13 +149,13 @@ def _compute_metrics(scores, labels, high_score_more_similar):
)

return {
"accuracy": acc,
"accuracy_threshold": acc_threshold,
"f1": f1,
"f1_threshold": f1_threshold,
"precision": precision,
"recall": recall,
"ap": ap,
"accuracy": float(acc),
"accuracy_threshold": float(acc_threshold),
"f1": float(f1),
"f1_threshold": float(f1_threshold),
"precision": float(precision),
"recall": float(recall),
"ap": float(ap),
}

@staticmethod
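For reference, a hedged sketch (not from the commit) of the score layout this evaluator now returns; the values are hypothetical.

```python
# Illustration only: output of PairClassificationEvaluator after the change.
# "similarity" holds the model-specified similarity scores (falling back to
# cosine when the model has none), and the former "cos_sim" block is "cosine".
example_output_scores = {
    "similarity": {
        "accuracy": 0.91, "accuracy_threshold": 0.73,
        "f1": 0.88, "f1_threshold": 0.69,
        "precision": 0.86, "recall": 0.90, "ap": 0.93,
    },
    # "cosine", "manhattan", "euclidean" and "dot" carry the same metric keys.
    "main_score": 0.93,  # max average precision across the score blocks
}
```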
45 changes: 32 additions & 13 deletions mteb/evaluation/evaluators/STSEvaluator.py
@@ -2,13 +2,16 @@

import logging

import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics.pairwise import (
paired_cosine_distances,
paired_euclidean_distances,
paired_manhattan_distances,
)

from mteb.encoder_interface import Encoder, EncoderWithSimilarity

from .Evaluator import Evaluator
from .model_encode import model_encode

@@ -37,7 +40,7 @@ def __init__(
self.batch_size = batch_size
self.task_name = task_name

def __call__(self, model):
def __call__(self, model: Encoder | EncoderWithSimilarity):
embeddings1 = model_encode(
self.sentences1,
model=model,
@@ -65,17 +68,33 @@ def __call__(self, model):
euclidean_pearson, _ = pearsonr(self.gold_scores, euclidean_distances)
euclidean_spearman, _ = spearmanr(self.gold_scores, euclidean_distances)

similarity_scores = None
if hasattr(model, "similarity_pairwise"):
similarity_scores = model.similarity_pairwise(embeddings1, embeddings2) # type: ignore
elif hasattr(model, "similarity"):
_similarity_scores = [
float(model.similarity(e1, e2)) # type: ignore
for e1, e2 in zip(embeddings1, embeddings2)
]
similarity_scores = np.array(_similarity_scores)

if similarity_scores is not None:
pearson = pearsonr(self.gold_scores, similarity_scores)
spearman = spearmanr(self.gold_scores, similarity_scores)
else:
# if the model does not have a similarity function, fall back to cosine similarity
pearson = cosine_pearson
spearman = cosine_spearman

return {
"cos_sim": {
"pearson": cosine_pearson,
"spearman": cosine_spearman,
},
"manhattan": {
"pearson": manhatten_pearson,
"spearman": manhatten_spearman,
},
"euclidean": {
"pearson": euclidean_pearson,
"spearman": euclidean_spearman,
},
# using the model's own similarity score
"pearson": pearson,
"spearman": spearman,
# generic similarity scores
"cosine_pearson": cosine_pearson,
"cosine_spearman": cosine_spearman,
"manhattan_pearson": manhatten_pearson,
"manhattan_spearman": manhatten_spearman,
"euclidean_pearson": euclidean_pearson,
"euclidean_spearman": euclidean_spearman,
}
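Finally, a hedged sketch (not from the commit) of the flat layout STSEvaluator now returns, shown for the fallback case of a model without its own similarity function (where "pearson"/"spearman" mirror the cosine scores); the values are hypothetical.

```python
# Illustration only: flat STS score layout after the change, fallback case.
example_sts_scores = {
    # the model's own similarity scores (here the cosine fallback)
    "pearson": 0.86,
    "spearman": 0.87,
    # generic similarity scores
    "cosine_pearson": 0.86,
    "cosine_spearman": 0.87,
    "manhattan_pearson": 0.85,
    "manhattan_spearman": 0.86,
    "euclidean_pearson": 0.85,
    "euclidean_spearman": 0.86,
}
```

With this flat layout, the simplified AbsTaskSTS._add_main_score above reduces to a direct lookup of, e.g., "cosine_spearman".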
