Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Python pooling #442

Merged
merged 2 commits on
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/python/server/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ unit-tests:

gen-server:
# Compile protos
pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
mkdir text_embeddings_server/pb || true
python -m grpc_tools.protoc -I../../proto --python_out=text_embeddings_server/pb \
--grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb ../../proto/embed.proto
Expand All @@ -15,6 +15,7 @@ gen-server:

install: gen-server
pip install pip --upgrade
pip install torch==2.5.1
pip install -r requirements.txt
pip install -e .

Expand Down
2,219 changes: 1,641 additions & 578 deletions backends/python/server/poetry.lock

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions backends/python/server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@ python-text-embeddings-server = 'text_embeddings_server.cli:app'

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^4.21.7"
protobuf = ">=4.25.3,<6"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
safetensors = "^0.3.2"
safetensors = "^0.4"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
torch = { version = "^2.0.1" }
opentelemetry-api = "^1.25.0"
opentelemetry-exporter-otlp = "^1.25.0"
opentelemetry-instrumentation-grpc = "^0.46b0"
sentence-transformers = "^3.3.1"
torch = "^2.5.1"

[tool.poetry.extras]

Expand Down
99 changes: 62 additions & 37 deletions backends/python/server/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,43 +1,68 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
deprecated==1.2.15 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.66.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.68.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.26.2 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==2.0.2 ; python_version >= "3.9" and python_version < "3.13"
nvidia-cublas-cu12==12.4.5.8 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cuda-cupti-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cuda-runtime-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cudnn-cu12==9.1.0.70 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cufft-cu12==11.2.1.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-curand-cu12==10.3.5.147 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cusolver-cu12==11.6.1.9 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cusparse-cu12==12.3.1.170 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-nccl-cu12==2.21.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-nvjitlink-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-nvtx-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.6.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.3 ; python_version >= "3.9" and python_version < "3.13"
torch==2.5.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.46.3 ; python_version >= "3.9" and python_version < "3.13"
triton==3.1.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
wrapt==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"
3 changes: 2 additions & 1 deletion backends/python/server/text_embeddings_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def serve(
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-embeddings-inference.server",
pool: str = "cls",
):
# Remove default handler
logger.remove()
Expand All @@ -48,7 +49,7 @@ def serve(
# Downgrade enum into str for easier management later on
dtype = None if dtype is None else dtype.value

server.serve(model_path, dtype, uds_path)
server.serve(model_path, dtype, uds_path, pool)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
__all__.append(FlashBert)


def get_model(model_path: Path, dtype: Optional[str]):
def get_model(model_path: Path, dtype: Optional[str], pool: str):
if dtype == "float32":
dtype = torch.float32
elif dtype == "float16":
Expand All @@ -38,8 +38,6 @@ def get_model(model_path: Path, dtype: Optional[str]):
if torch.cuda.is_available():
device = torch.device("cuda")
else:
if dtype != torch.float32:
raise ValueError("CPU device only supports float32 dtype")
device = torch.device("cpu")

config = AutoConfig.from_pretrained(model_path)
Expand All @@ -52,8 +50,10 @@ def get_model(model_path: Path, dtype: Optional[str]):
and dtype in [torch.float16, torch.bfloat16]
and FLASH_ATTENTION
):
if pool != "cls":
raise ValueError("FlashBert only supports cls pooling")
return FlashBert(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype)
return DefaultModel(model_path, device, dtype, pool)

raise NotImplementedError
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Type, List
from transformers import AutoModel
from opentelemetry import trace
from sentence_transformers.models import Pooling

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding
Expand All @@ -13,9 +14,12 @@


class DefaultModel(Model):
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
def __init__(
self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
):
model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
self.hidden_size = model.config.hidden_size
self.pooling = Pooling(self.hidden_size, pooling_mode=pool)

self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
Expand All @@ -41,7 +45,13 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs)
embedding = output[0][:, 0]

pooling_features = {
"token_embeddings": output[0],
"attention_mask": batch.attention_mask,
}
embedding = self.pooling.forward(pooling_features)["sentence_embedding"]

cpu_results = embedding.view(-1).tolist()

return [
Expand Down
4 changes: 2 additions & 2 deletions backends/python/server/text_embeddings_server/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import torch

from grpc import aio
from loguru import logger

Expand Down Expand Up @@ -37,6 +36,7 @@ def serve(
model_path: Path,
dtype: Optional[str],
uds_path: Path,
pool: str,
):
async def serve_inner(
model_path: Path,
Expand All @@ -45,7 +45,7 @@ async def serve_inner(
unix_socket = f"unix://{uds_path}"

try:
model = get_model(model_path, dtype)
model = get_model(model_path, dtype, pool)
except Exception:
logger.exception("Error when initializing model")
raise
Expand Down
12 changes: 4 additions & 8 deletions backends/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use backend_grpc_client::Client;
use nohash_hasher::BuildNoHashHasher;
use std::collections::HashMap;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
};
use tokio::runtime::Runtime;

Expand All @@ -24,18 +24,13 @@ impl PythonBackend {
otlp_endpoint: Option<String>,
otlp_service_name: String,
) -> Result<Self, BackendError> {
match model_type {
let pool = match model_type {
ModelType::Classifier => {
return Err(BackendError::Start(
"`classifier` model type is not supported".to_string(),
))
}
ModelType::Embedding(pool) => {
if pool != Pool::Cls {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
pool
}
ModelType::Embedding(pool) => pool,
};

let backend_process = management::BackendProcess::new(
Expand All @@ -44,6 +39,7 @@ impl PythonBackend {
&uds_path,
otlp_endpoint,
otlp_service_name,
pool,
)?;
let tokio_runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
Expand Down
14 changes: 13 additions & 1 deletion backends/python/src/management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::sync::mpsc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{env, fs, io, thread};
use text_embeddings_backend_core::BackendError;
use text_embeddings_backend_core::{BackendError, Pool};

#[derive(Debug)]
pub(crate) struct BackendProcess {
Expand All @@ -22,6 +22,7 @@ impl BackendProcess {
uds_path: &str,
otlp_endpoint: Option<String>,
otlp_service_name: String,
pool: Pool,
) -> Result<Self, BackendError> {
// Get UDS path
let uds = Path::new(uds_path);
Expand All @@ -31,6 +32,15 @@ impl BackendProcess {
fs::remove_file(uds).expect("could not remove UDS file");
}

let pool = match pool {
Pool::Cls => "cls",
Pool::Mean => "mean",
Pool::LastToken => "lasttoken",
Pool::Splade => {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
};

// Process args
let mut python_server_args = vec![
model_path,
Expand All @@ -41,6 +51,8 @@ impl BackendProcess {
"--logger-level".to_owned(),
"INFO".to_owned(),
"--json-output".to_owned(),
"--pool".to_owned(),
pool.to_owned(),
];

// OpenTelemetry
Expand Down
Loading