Added prompting along with e5 instruct #888

Merged
merged 11 commits
Jun 15, 2024
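This PR threads the task name through each evaluator as a `prompt_name` argument to `model.encode`, so prompt-based models such as e5-instruct can pick a task-specific instruction. A rough usage sketch of the resulting interface; the model name, task selection, and prompt strings are illustrative assumptions rather than part of this PR:

```python
from sentence_transformers import SentenceTransformer

import mteb

# Hypothetical per-task prompts: keys are MTEB task names, values are instruction
# prefixes. sentence-transformers looks them up via the `prompt_name` passed to encode().
prompts = {
    "Banking77Classification": "Instruct: Classify the intent of the given bank query\nQuery: ",
    "STS12": "Instruct: Retrieve semantically similar text\nQuery: ",
}

model = SentenceTransformer("intfloat/e5-mistral-7b-instruct", prompts=prompts)

tasks = mteb.get_tasks(tasks=["Banking77Classification", "STS12"])
results = mteb.MTEB(tasks=tasks).run(model)  # evaluators now pass prompt_name=<task name>
```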
14 changes: 11 additions & 3 deletions mteb/abstasks/AbsTaskBitextMining.py
@@ -39,7 +39,10 @@ def evaluate(self, model, split, **kwargs) -> dict[HFSubset, ScoresDict]:
scores = {}
if self.parallel_subsets:
scores["default"] = self._evaluate_subset(
model, self.dataset[split], parallel=True, **kwargs
model,
self.dataset[split], # type: ignore
parallel=True,
**kwargs,
)
else:
for hf_subet in hf_subsets:
@@ -52,15 +55,20 @@ def evaluate(self, model, split, **kwargs) -> dict[HFSubset, ScoresDict]:
else:
data_split = self.dataset[hf_subet][split]
scores[hf_subet] = self._evaluate_subset(
model, data_split, subsets=["sentence1", "sentence2"], **kwargs
model,
data_split, # type: ignore
subsets=["sentence1", "sentence2"],
**kwargs,
)

return scores

def _evaluate_subset(
self, model, data_split: Dataset, parallel=False, **kwargs
) -> ScoresDict:
evaluator = BitextMiningEvaluator(data_split, **kwargs)
evaluator = BitextMiningEvaluator(
data_split, task_name=self.metadata.name, **kwargs
)
metrics = evaluator(model)
if parallel:
for v in metrics.values():
3 changes: 3 additions & 0 deletions mteb/abstasks/AbsTaskClassification.py
@@ -118,6 +118,7 @@ def _evaluate_subset(
y_sampled,
eval_split["text"],
eval_split["label"],
task_name=self.metadata.name,
**params,
)
elif self.method == "kNN-pytorch":
@@ -126,6 +127,7 @@ def _evaluate_subset(
y_sampled,
eval_split["text"],
eval_split["label"],
task_name=self.metadata.name,
**params,
)
elif self.method == "logReg":
@@ -134,6 +136,7 @@ def _evaluate_subset(
y_sampled,
eval_split["text"],
eval_split["label"],
task_name=self.metadata.name,
**params,
)
else:
1 change: 1 addition & 0 deletions mteb/abstasks/AbsTaskClustering.py
@@ -38,6 +38,7 @@ def _evaluate_subset(
evaluator = ClusteringEvaluator(
cluster_set["sentences"], # type: ignore
cluster_set["labels"], # type: ignore
task_name=self.metadata.name,
**kwargs,
)
metrics = evaluator(model)
10 changes: 7 additions & 3 deletions mteb/abstasks/AbsTaskClusteringFast.py
@@ -12,6 +12,7 @@
from datasets import Dataset, DatasetDict
from sklearn.metrics.cluster import v_measure_score

from ..evaluation.evaluators.model_encode import model_encode
from ..MTEBResults import HFSubset
from .AbsTask import AbsTask

Expand Down Expand Up @@ -116,13 +117,16 @@ def _evaluate_subset(
example_indices = rng_state.sample(
range(len(dataset)), k=self.max_documents_to_embed
)
downsampled_dataset = dataset.select(example_indices)
downsampled_dataset = dataset.select(example_indices) # type: ignore
else:
downsampled_dataset = dataset

logger.info(f"Encoding {len(downsampled_dataset)} sentences...")
embeddings = model_encode(
downsampled_dataset["sentences"], # type: ignore
model=model,
task_name=self.metadata.name,
)

embeddings = model.encode(downsampled_dataset["sentences"])
labels = []
for label in downsampled_dataset["labels"]:
if not isinstance(label, list):
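The abstasks above now route encoding through a shared `model_encode` helper rather than calling `model.encode` directly. Its implementation is not shown in this diff; a minimal sketch, assuming it forwards the task name as `prompt_name` only when the encoder's signature accepts it:

```python
from __future__ import annotations

import inspect
import logging

import numpy as np

logger = logging.getLogger(__name__)


def model_encode(sentences, *, model, task_name: str | None = None, **kwargs) -> np.ndarray:
    """Illustrative stand-in for mteb/evaluation/evaluators/model_encode.py."""
    if "prompt_name" in inspect.signature(model.encode).parameters:
        kwargs["prompt_name"] = task_name  # the task name doubles as the prompt name
    else:
        logger.info("Encoder does not accept `prompt_name`; encoding without a prompt.")
    return np.asarray(model.encode(sentences, **kwargs))
```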
14 changes: 9 additions & 5 deletions mteb/abstasks/AbsTaskInstructionRetrieval.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import logging
import os
@@ -22,10 +24,10 @@
class HFDataLoaderInstructions(HFDataLoader):
def __init__(
self,
hf_repo: str = None,
hf_repo_qrels: str = None,
data_folder: str = None,
prefix: str = None,
hf_repo: str | None = None,
hf_repo_qrels: str | None = None,
data_folder: str | None = None,
prefix: str | None = None,
corpus_file: str = "corpus.jsonl",
query_file: str = "queries.jsonl",
qrels_folder: str = "qrels",
@@ -323,7 +325,9 @@ def load_data(self, **kwargs):
self.data_loaded = True

def evaluate(self, model, split="test", **kwargs):
retriever = InstructionRetrievalEvaluator(model, **kwargs)
retriever = InstructionRetrievalEvaluator(
model=model, task_name=self.metadata.name, **kwargs
)

scores_og = {}
scores_changed = {}
10 changes: 8 additions & 2 deletions mteb/abstasks/AbsTaskMultilabelClassification.py
@@ -12,6 +12,7 @@
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import MultiLabelBinarizer

from ..evaluation.evaluators.model_encode import model_encode
from ..MTEBResults import HFSubset, ScoresDict
from .AbsTask import AbsTask

@@ -122,8 +123,12 @@ def _evaluate_subset(
# Encode all unique sentences at the indices
unique_train_indices = list(set(itertools.chain.from_iterable(train_samples)))
unique_train_sentences = train_split.select(unique_train_indices)["text"]

_unique_train_embeddings = model_encode(
unique_train_sentences, model=model, task_name=self.metadata.name
)
unique_train_embeddings = dict(
zip(unique_train_indices, model.encode(unique_train_sentences))
zip(unique_train_indices, _unique_train_embeddings)
)
test_text = eval_split["text"]
binarizer = MultiLabelBinarizer()
@@ -136,7 +141,8 @@
)
except ValueError:
logger.warning("Couldn't subsample, continuing with the entire test set.")
X_test = model.encode(test_text)

X_test = model_encode(test_text, model=model, task_name=self.metadata.name)
for i_experiment, sample_indices in enumerate(train_samples):
logger.info(
"=" * 10
6 changes: 5 additions & 1 deletion mteb/abstasks/AbsTaskPairClassification.py
@@ -41,7 +41,11 @@ def _evaluate_subset(
"sentence_transformers.evaluation.PairClassificationEvaluator"
).setLevel(logging.WARN)
evaluator = PairClassificationEvaluator(
data_split["sent1"], data_split["sent2"], data_split["labels"], **kwargs
data_split["sent1"],
data_split["sent2"],
data_split["labels"],
task_name=self.metadata.name,
**kwargs,
)
scores = evaluator.compute_metrics(model)

4 changes: 3 additions & 1 deletion mteb/abstasks/AbsTaskReranking.py
@@ -29,7 +29,9 @@ def _evaluate_subset(
data_split: Dataset,
**kwargs: Any,
) -> ScoresDict:
evaluator = RerankingEvaluator(data_split, **kwargs)
evaluator = RerankingEvaluator(
data_split, task_name=self.metadata.name, **kwargs
)
scores = evaluator(model)

self._add_main_score(scores)
6 changes: 4 additions & 2 deletions mteb/abstasks/AbsTaskRetrieval.py
@@ -245,8 +245,10 @@ def load_data(self, **kwargs):

self.data_loaded = True

def evaluate(self, model, split="test", **kwargs):
retriever = RetrievalEvaluator(model, **kwargs)
def evaluate(self, model, split: str = "test", **kwargs):
retriever = RetrievalEvaluator(
retriever=model, task_name=self.metadata.name, **kwargs
)

scores = {}
hf_subsets = (
1 change: 1 addition & 0 deletions mteb/abstasks/AbsTaskSTS.py
@@ -38,6 +38,7 @@ def normalize(x):
data_split["sentence1"],
data_split["sentence2"],
normalized_scores,
task_name=self.metadata.name,
**kwargs,
)
scores = evaluator(model)
1 change: 1 addition & 0 deletions mteb/abstasks/AbsTaskSummarization.py
@@ -46,6 +46,7 @@ def _evaluate_subset(self, model, data_split, **kwargs) -> ScoresDict:
human_summaries=data_split["human_summaries"],
texts=data_split["text"],
gold_scores=normalized_scores,
task_name=self.metadata.name,
**kwargs,
)
scores = evaluator(model)
23 changes: 14 additions & 9 deletions mteb/encoder_interface.py
@@ -1,23 +1,26 @@
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable
from typing import Any, Dict, List, Protocol, Sequence, Union, runtime_checkable

import numpy as np
import torch

Corpus = Union[List[Dict[str, str]], Dict[str, List[str]]]


@runtime_checkable
class Encoder(Protocol):
"""The interface for an encoder in MTEB."""
"""The interface for an encoder in MTEB. In general we try to keep this interface aligned with sentence-transformers."""

def encode(
self, sentences: list[str], prompt: str, **kwargs: Any
self, sentences: Sequence[str], *, prompt_name: str | None = None, **kwargs: Any
) -> torch.Tensor | np.ndarray:
"""Encodes the given sentences using the encoder.

Args:
sentences: The sentences to encode.
prompt: The prompt to use. Useful for prompt-based models.
prompt_name: The name of the prompt. This will just be the name of the task. Sentence-transformers uses this to
determine which prompt to use from a specified dictionary.
**kwargs: Additional arguments to pass to the encoder.

Returns:
@@ -28,16 +31,17 @@ def encode(

@runtime_checkable
class EncoderWithQueryCorpusEncode(Encoder, Protocol):
"""The interface for an encoder that supports encoding queries and a corpus."""
"""The optional interface for an encoder that supports encoding queries and a corpus."""

def encode_queries(
self, queries: list[str], prompt: str, **kwargs: Any
self, queries: Sequence[str], *, prompt_name: str | None = None, **kwargs: Any
) -> torch.Tensor | np.ndarray:
"""Encodes the given queries using the encoder.

Args:
queries: The queries to encode.
prompt: The prompt to use. Useful for prompt-based models.
prompt_name: The name of the prompt. This will just be the name of the task. Sentence-transformers uses this to
determine which prompt to use from a specified dictionary.
**kwargs: Additional arguments to pass to the encoder.

Returns:
@@ -46,13 +50,14 @@ def encode_queries(
...

def encode_corpus(
self, corpus: list[str], prompt: str, **kwargs: Any
self, corpus: Corpus, *, prompt_name: str | None = None, **kwargs: Any
) -> torch.Tensor | np.ndarray:
"""Encodes the given corpus using the encoder.

Args:
corpus: The corpus to encode.
prompt: The prompt to use. Useful for prompt-based models.
prompt_name: The name of the prompt. This will just be the name of the task. Sentence-transformers uses this to
determine which prompt to use from a specified dictionary.
**kwargs: Additional arguments to pass to the encoder.

Returns:
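Because `Encoder` is a `runtime_checkable` Protocol, any object exposing a compatible `encode` method satisfies it. A minimal custom encoder resolving `prompt_name` to an instruction prefix could look like this; the wrapper class and its prompt dictionary are assumptions for illustration:

```python
from __future__ import annotations

from typing import Any, Sequence

import numpy as np
from sentence_transformers import SentenceTransformer


class PromptedEncoder:
    """Toy encoder mapping `prompt_name` (the MTEB task name) to an instruction prefix."""

    def __init__(self, model_name: str, prompts: dict[str, str]):
        self.model = SentenceTransformer(model_name)
        self.prompts = prompts

    def encode(
        self, sentences: Sequence[str], *, prompt_name: str | None = None, **kwargs: Any
    ) -> np.ndarray:
        prefix = self.prompts.get(prompt_name, "") if prompt_name else ""
        return self.model.encode([prefix + s for s in sentences], **kwargs)
```

Note that `isinstance` against a runtime-checkable protocol only verifies that the method exists, not its signature or annotations.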
47 changes: 26 additions & 21 deletions mteb/evaluation/evaluators/BitextMiningEvaluator.py
@@ -2,19 +2,21 @@

import logging

import numpy as np
import torch
import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from mteb.encoder_interface import Encoder

from .Evaluator import Evaluator
from .model_encode import model_encode
from .utils import cos_sim

logger = logging.getLogger(__name__)


class BitextMiningEvaluator(Evaluator):
def __init__(self, sentences, batch_size=32, limit=None, subsets=None, **kwargs):
def __init__(self, sentences, task_name: str, subsets=None, **kwargs):
super().__init__(**kwargs)
# By default, all the columns in sentences will serve for evaluation
# Specifying a 'subsets' attribute will limit to certain columns
@@ -29,22 +31,21 @@ def __init__(self, sentences, batch_size=32, limit=None, subsets=None, **kwargs)
if "gold" not in sentences
else sentences["gold"]
)
self.task_name = task_name

self.batch_size = batch_size

def __call__(self, model):
def __call__(self, model: Encoder):
scores = self.compute_metrics(model)
return scores

def compute_metrics(self, model):
def compute_metrics(self, model: Encoder):
# Compute embeddings
logger.info(f"Encoding {self.n_subsets}x{self.n} sentences")
embeddings = {}
for sub in tqdm.tqdm(
self.subsets, desc=f"Encoding {self.n_subsets}x{self.n} sentences"
):
embeddings[sub] = np.asarray(
model.encode(self.sentences[sub], batch_size=self.batch_size)
embeddings[sub] = model_encode(
self.sentences[sub], model=model, task_name=self.task_name
)

if set(self.subsets) == {"sentence1", "sentence2"}: # Case of a single pair
@@ -84,12 +85,12 @@ def _compute_metrics(

scores = {
"precision": precision_score(
labels, predictions, zero_division=0.0, average="weighted"
labels, predictions, zero_division=0, average="weighted"
),
"recall": recall_score(
labels, predictions, zero_division=0.0, average="weighted"
labels, predictions, zero_division=0, average="weighted"
),
"f1": f1_score(labels, predictions, zero_division=0.0, average="weighted"),
"f1": f1_score(labels, predictions, zero_division=0, average="weighted"),
"accuracy": accuracy_score(labels, predictions),
}
return scores
@@ -98,20 +99,24 @@ def _similarity_search(
self,
query_embeddings,
corpus_embeddings,
query_chunk_size=100,
corpus_chunk_size=500000,
top_k=10,
query_chunk_size: int = 100,
corpus_chunk_size: int = 500000,
top_k: int = 10,
score_function=cos_sim,
):
"""This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings.
It can be used for Information Retrieval / Semantic Search for corpora up to about 1 Million entries.
:param query_embeddings: A 2 dimensional tensor with the query embeddings.
:param corpus_embeddings: A 2 dimensional tensor with the corpus embeddings.
:param query_chunk_size: Process 100 queries simultaneously. Increasing that value increases the speed, but requires more memory.
:param corpus_chunk_size: Scans the corpus 100k entries at a time. Increasing that value increases the speed, but requires more memory.
:param top_k: Retrieve top k matching entries.
:param score_function: Function for computing scores. By default, cosine similarity.
:return: Returns a list with one entry for each query. Each entry is a list of dictionaries with the keys 'corpus_id' and 'score', sorted by decreasing cosine similarity scores.

Args:
query_embeddings: A 2 dimensional tensor with the query embeddings.
corpus_embeddings: A 2 dimensional tensor with the corpus embeddings.
query_chunk_size: Process 100 queries simultaneously. Increasing that value increases the speed, but requires more memory.
corpus_chunk_size: Scans the corpus 100k entries at a time. Increasing that value increases the speed, but requires more memory.
top_k: Retrieve top k matching entries.
score_function: Function for computing scores. By default, cosine similarity.

Returns:
Returns a list with one entry for each query. Each entry is a list of dictionaries with the keys 'corpus_id' and 'score', sorted by decreasing cosine similarity scores.
"""
query_embeddings = torch.from_numpy(query_embeddings)
corpus_embeddings = torch.from_numpy(corpus_embeddings)
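The chunked cosine-similarity search described in the rewritten docstring can be reproduced in a few lines of torch. A simplified sketch that chunks only over queries (the evaluator additionally chunks over the corpus):

```python
import torch


def chunked_topk_cosine(
    query_emb: torch.Tensor,   # (n_queries, dim)
    corpus_emb: torch.Tensor,  # (n_corpus, dim)
    top_k: int = 10,
    query_chunk_size: int = 100,
) -> list[list[dict]]:
    """For each query, return the top_k corpus entries by decreasing cosine similarity."""
    q = torch.nn.functional.normalize(query_emb, dim=-1)
    c = torch.nn.functional.normalize(corpus_emb, dim=-1)
    results: list[list[dict]] = []
    for start in range(0, q.shape[0], query_chunk_size):
        sims = q[start : start + query_chunk_size] @ c.T  # cosine similarity for this chunk
        scores, ids = torch.topk(sims, k=min(top_k, c.shape[0]), dim=1)
        for row_scores, row_ids in zip(scores, ids):
            results.append(
                [{"corpus_id": int(i), "score": float(s)} for s, i in zip(row_scores, row_ids)]
            )
    return results
```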