feat: add instruct wrapper (#1768)
* add instruct wrapper

* use get_task_instruction

* add logging messages

* apply based on PromptType

* update description

* change example model

* move nvembed

* Update mteb/models/instruct_wrapper.py

Co-authored-by: Isaac Chung <chungisaac1217@gmail.com>

* update docstrings

* add instruction to docs

* Apply suggestions from code review

Co-authored-by: Isaac Chung <chungisaac1217@gmail.com>

* lint

---------

Co-authored-by: Isaac Chung <chungisaac1217@gmail.com>
Samoed and isaac-chung authored Jan 25, 2025
1 parent dfba463 commit ee0f15a
Showing 3 changed files with 116 additions and 59 deletions.
15 changes: 15 additions & 0 deletions docs/adding_a_model.md
@@ -71,3 +71,18 @@ The leaderboard [automatically refreshes daily](https://github.com/embeddings-be
###### Instantiating the Model with Prompts

If you are unable to directly add the prompts in the model configuration, you can instantiate the model using the `sentence_transformers_loader` and pass `prompts` as an argument. For more details, see the `mteb/models/bge_models.py` file.
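A minimal sketch of that pattern (the checkpoint, revision placeholder, and prompt string are illustrative; check `mteb/models/bge_models.py` for the exact keyword arguments, such as `model_prompts`):

```python
from functools import partial

from mteb.model_meta import ModelMeta
from mteb.models.sentence_transformer_wrapper import sentence_transformers_loader

model = ModelMeta(
    loader=partial(
        sentence_transformers_loader,
        model_name="BAAI/bge-base-en-v1.5",
        revision="...",  # pin a real revision here
        # map prompt names (task names or "query"/"passage") to prompt strings
        model_prompts={
            "query": "Represent this sentence for searching relevant passages: ",
        },
    ),
    ...
)
```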

##### Adding instruction models

Models that use instructions can use the [`InstructSentenceTransformerWrapper`](../mteb/models/instruct_wrapper.py). For example:
```python
model = ModelMeta(
    loader=partial(
        InstructSentenceTransformerWrapper,
        model_name="nvidia/NV-Embed-v1",
        revision="7604d305b621f14095a1aa23d351674c2859553a",
        instruction_template="Instruct: {instruction}\nQuery: ",
    ),
    ...
)
```
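`instruction_template` can also be a callable that receives the task instruction and returns the formatted prompt; the NV-Embed entries in `mteb/models/nvidia_models.py` use this form. A sketch of that variant (the function body mirrors the string template above; the real signature in `nvidia_models.py` may differ slightly):

```python
def instruction_template(instruction: str) -> str:
    # return an empty prompt when no instruction is provided
    return f"Instruct: {instruction}\nQuery: " if instruction else ""

model = ModelMeta(
    loader=partial(
        InstructSentenceTransformerWrapper,
        model_name="nvidia/NV-Embed-v1",
        revision="7604d305b621f14095a1aa23d351674c2859553a",
        instruction_template=instruction_template,
    ),
    ...
)
```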
85 changes: 85 additions & 0 deletions mteb/models/instruct_wrapper.py
@@ -6,6 +6,7 @@

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from mteb.encoder_interface import PromptType

@@ -78,3 +79,87 @@ def encode(
            return embeddings

    return InstructWrapper(model_name_or_path, mode, instruction_template, **kwargs)


class InstructSentenceTransformerWrapper(Wrapper):
    def __init__(
        self,
        model_name: str,
        revision: str,
        instruction_template: str | Callable[[str], str] | None = None,
        max_seq_length: int | None = None,
        apply_instruction_to_passages: bool = True,
        padding_side: str | None = None,
        add_eos_token: bool = False,
        **kwargs: Any,
    ):
        """Instruct Sentence Transformer Wrapper. Wrapper that passes instructions to the Sentence Transformer model.
        Applied for models like NV-Embed, gte-Qwen, e5-mistral, etc.

        Arguments:
            model_name: Model name of the sentence transformers model.
            revision: Revision of the sentence transformers model.
            instruction_template: Model template. Should contain the string '{instruction}'.
            max_seq_length: Maximum sequence length. If None, the maximum sequence length will be read from the model config.
            apply_instruction_to_passages: Whether to apply the instruction template to the passages.
            padding_side: Padding side. If None, the padding side will be read from the model config.
            add_eos_token: Whether to add the eos token to each input example.
            **kwargs: Kwargs for Sentence Transformer model.
        """
        if (
            isinstance(instruction_template, str)
            and "{instruction}" not in instruction_template
        ):
            raise ValueError(
                "Instruction template must contain the string '{instruction}'."
            )
        if instruction_template is None:
            logger.warning(
                "No instruction template provided. Instructions will be used as-is."
            )

        self.model_name = model_name
        self.model = SentenceTransformer(model_name, revision=revision, **kwargs)
        self.instruction_template = instruction_template
        self.apply_instruction_to_passages = apply_instruction_to_passages
        self.add_eos_token = add_eos_token
        if max_seq_length is not None:
            self.model.max_seq_length = max_seq_length
        if padding_side is not None:
            self.model.tokenizer.padding_side = padding_side

    def encode(
        self,
        sentences: Sequence[str],
        *,
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        if self.add_eos_token:
            sentences = [
                example + self.model.tokenizer.eos_token for example in sentences
            ]

        instruction = self.get_task_instruction(task_name, prompt_type)
        # when disabled, the instruction template is not applied to passages
        if (
            not self.apply_instruction_to_passages
            and prompt_type == PromptType.passage
        ):
            instruction = None
            logger.info(
                f"No instruction used, because prompt type = {prompt_type.passage}"
            )

        if instruction:
            logger.info(f"Using instruction: '{instruction}' for task: '{task_name}'")

        embeddings = self.model.encode(
            sentences,
            prompt=instruction,
            **kwargs,
        )

        if isinstance(embeddings, torch.Tensor):
            # kwargs may request tensor output (e.g. convert_to_tensor=True)
            embeddings = embeddings.cpu().detach().float().numpy()
        return embeddings
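For reference, a minimal usage sketch of the new wrapper outside a `ModelMeta` (the checkpoint, revision, and task name are illustrative and not part of this change):

```python
from mteb.encoder_interface import PromptType
from mteb.models.instruct_wrapper import InstructSentenceTransformerWrapper

# any instruction-tuned SentenceTransformer checkpoint can be used here
model = InstructSentenceTransformerWrapper(
    model_name="intfloat/e5-mistral-7b-instruct",
    revision="main",  # pin a real revision in practice
    instruction_template="Instruct: {instruction}\nQuery: ",
)

# the task instruction is resolved via get_task_instruction; set
# apply_instruction_to_passages=False to skip it for passages
query_embeddings = model.encode(
    ["what is the capital of france?"],
    task_name="NFCorpus",
    prompt_type=PromptType.query,
)
```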
75 changes: 16 additions & 59 deletions mteb/models/nvidia_models.py
@@ -1,17 +1,11 @@
from __future__ import annotations

import logging
from collections.abc import Sequence
from functools import partial
from typing import Any

import numpy as np
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.models.sentence_transformer_wrapper import SentenceTransformerWrapper
from mteb.models.instruct_wrapper import InstructSentenceTransformerWrapper

logger = logging.getLogger(__name__)

@@ -22,56 +16,6 @@ def instruction_template(
return f"Instruct: {instruction}\nQuery: " if instruction else ""


class NvEmbedWrapper(SentenceTransformerWrapper):
    def __init__(
        self,
        model: str | SentenceTransformer | CrossEncoder,
        revision: str | None = None,
        model_prompts: dict[str, str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(model, revision, model_prompts, **kwargs)
        self.model.max_seq_length = 32768
        self.model.tokenizer.padding_side = "right"
        logger.warning(
            "Instructions are used in both query and docs, which may cause performance discrepancies from the original implementation."
        )

    def encode(
        self,
        sentences: Sequence[str],
        *,
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        # Add eos token to each input example
        sentences = [example + self.model.tokenizer.eos_token for example in sentences]

        instruction = ""
        if prompt_type == PromptType.query:
            instruction = self.get_instruction(task_name, prompt_type)

        prompt = instruction_template(instruction)

        if prompt:
            logger.info(f"Using {prompt=} for task={task_name} {prompt_type=}")
        else:
            logger.info(f"No model prompts found for task={task_name} {prompt_type=}")

        logger.info(f"Encoding {len(sentences)} sentences.")

        embeddings = self.model.encode(
            sentences,
            prompt=prompt,
            normalize_embeddings=True,
            **kwargs,
        )
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().detach().float().numpy()
        return embeddings


nvidia_training_datasets = {
    # source: https://arxiv.org/pdf/2405.17428
    "ArguAna": ["train"],
@@ -120,11 +64,18 @@ def encode(
"STSBenchmark": ["train"],
"STSBenchmarkMultilingualSTS": ["train"], # translated, not trained on
}

NV_embed_v2 = ModelMeta(
    loader=partial(  # type: ignore
        NvEmbedWrapper,
        InstructSentenceTransformerWrapper,
        model_name="nvidia/NV-Embed-v2",
        revision="7604d305b621f14095a1aa23d351674c2859553a",
        instruction_template=instruction_template,
        trust_remote_code=True,
        max_seq_length=32768,
        padding_side="right",
        # for nv-embed, we add eos token to each input example
        add_eos_token=True,
    ),
    name="nvidia/NV-Embed-v2",
    languages=["eng_Latn"],
@@ -146,9 +97,15 @@ def encode(

NV_embed_v1 = ModelMeta(
    loader=partial(  # type: ignore
        NvEmbedWrapper,
        InstructSentenceTransformerWrapper,
        model_name="nvidia/NV-Embed-v1",
        revision="7604d305b621f14095a1aa23d351674c2859553a",
        instruction_template=instruction_template,
        trust_remote_code=True,
        max_seq_length=32768,
        padding_side="right",
        # for nv-embed, we add eos token to each input example
        add_eos_token=True,
    ),
    name="nvidia/NV-Embed-v1",
    languages=["eng_Latn"],
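With the metadata above, the model resolves by name through the registry. A sketch of running it end to end (this downloads the full NV-Embed checkpoint, so treat it as illustrative):

```python
import mteb

model = mteb.get_model("nvidia/NV-Embed-v2")  # uses the loader defined above
tasks = mteb.get_tasks(tasks=["NFCorpus"])
evaluation = mteb.MTEB(tasks=tasks)
evaluation.run(model)
```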
