Commit 0c77456

Merge branch 'main' into fix/num_cpus
OlivierDehaene authored Sep 17, 2024
2 parents 14385eb + fef77b0 commit 0c77456
Showing 7 changed files with 70 additions and 33 deletions.
21 changes: 17 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Dockerfile
@@ -62,6 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ca-certificates \
libssl-dev \
curl \
&& rm -rf /var/lib/apt/lists/*


6 changes: 6 additions & 0 deletions Dockerfile-cuda
@@ -113,6 +113,12 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80 \
USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ca-certificates \
libssl-dev \
curl \
&& rm -rf /var/lib/apt/lists/*

FROM base AS grpc

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
6 changes: 6 additions & 0 deletions Dockerfile-cuda-all
@@ -118,6 +118,12 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80 \
USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ca-certificates \
libssl-dev \
curl \
&& rm -rf /var/lib/apt/lists/*

COPY --from=builder /usr/src/target/release/text-embeddings-router-75 /usr/local/bin/text-embeddings-router-75
COPY --from=builder /usr/src/target/release/text-embeddings-router-80 /usr/local/bin/text-embeddings-router-80
COPY --from=builder /usr/src/target/release/text-embeddings-router-90 /usr/local/bin/text-embeddings-router-90
2 changes: 1 addition & 1 deletion backends/ort/Cargo.toml
@@ -10,7 +10,7 @@ anyhow = { workspace = true }
nohash-hasher = { workspace = true }
ndarray = "0.15.6"
num_cpus = { workspace = true }
ort = { version = "2.0.0-rc.2", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] }
ort = { version = "2.0.0-rc.4", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] }
text-embeddings-backend-core = { path = "../core" }
tracing = { workspace = true }
thiserror = { workspace = true }
66 changes: 38 additions & 28 deletions router/src/http/server.rs
@@ -1634,9 +1634,12 @@ pub async fn run(
}
});

let prom_handle = prom_builder
.install_recorder()
.context("failed to install metrics recorder")?;
// See: https://github.com/metrics-rs/metrics/issues/467#issuecomment-2022755151
let (recorder, _) = prom_builder
.build()
.context("failed to build prometheus recorder")?;
let prom_handle = recorder.handle();
metrics::set_global_recorder(recorder).context("Failed to set global recorder")?;

// CORS layer
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
@@ -1666,9 +1669,7 @@
ApiDoc::openapi()
};

// Create router
let mut app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
let mut routes = Router::new()
// Base routes
.route("/info", get(get_model_info))
.route("/embed", post(embed))
@@ -1683,74 +1684,72 @@
.route("/embeddings", post(openai_embed))
.route("/v1/embeddings", post(openai_embed))
// Vertex compat route
.route("/vertex", post(vertex_compatibility))
.route("/vertex", post(vertex_compatibility));

#[allow(unused_mut)]
let mut public_routes = Router::new()
// Base Health route
.route("/health", get(health))
// Inference API health route
.route("/", get(health))
// AWS Sagemaker health route
.route("/ping", get(health))
// Prometheus metrics route
.route("/metrics", get(metrics))
// Update payload limit
.layer(DefaultBodyLimit::max(payload_limit));
.route("/metrics", get(metrics));

#[cfg(feature = "google")]
{
tracing::info!("Built with `google` feature");

if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
tracing::info!("Serving Vertex compatible route on {env_predict_route}");
app = app.route(&env_predict_route, post(vertex_compatibility));
routes = routes.route(&env_predict_route, post(vertex_compatibility));
}

if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
tracing::info!("Serving Vertex compatible health route on {env_health_route}");
app = app.route(&env_health_route, get(health));
public_routes = public_routes.route(&env_health_route, get(health));
}
}
#[cfg(not(feature = "google"))]
{
// Set default routes
app = match &info.model_type {
routes = match &info.model_type {
ModelType::Classifier(_) => {
app.route("/", post(predict))
routes
.route("/", post(predict))
// AWS Sagemaker route
.route("/invocations", post(predict))
}
ModelType::Reranker(_) => {
app.route("/", post(rerank))
routes
.route("/", post(rerank))
// AWS Sagemaker route
.route("/invocations", post(rerank))
}
ModelType::Embedding(model) => {
if std::env::var("TASK").ok() == Some("sentence-similarity".to_string()) {
app.route("/", post(similarity))
routes
.route("/", post(similarity))
// AWS Sagemaker route
.route("/invocations", post(similarity))
} else if model.pooling == "splade" {
app.route("/", post(embed_sparse))
routes
.route("/", post(embed_sparse))
// AWS Sagemaker route
.route("/invocations", post(embed_sparse))
} else {
app.route("/", post(embed))
routes
.route("/", post(embed))
// AWS Sagemaker route
.route("/invocations", post(embed))
}
}
};
}

app = app
.layer(Extension(infer))
.layer(Extension(info))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(cors_layer);

if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();
prefix.push_str(&api_key);
let prefix = format!("Bearer {}", api_key);

// Leak to allow FnMut
let api_key: &'static str = prefix.leak();
@@ -1767,9 +1766,20 @@
}
};

app = app.layer(axum::middleware::from_fn(auth));
routes = routes.layer(axum::middleware::from_fn(auth));
}

let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
.merge(routes)
.merge(public_routes)
.layer(Extension(infer))
.layer(Extension(info))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(DefaultBodyLimit::max(payload_limit))
.layer(cors_layer);

// Run server
let listener = tokio::net::TcpListener::bind(&addr)
.await
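
Note on the server.rs changes above: `install_recorder()` is replaced by building the Prometheus recorder explicitly and registering it via `metrics::set_global_recorder` (the workaround discussed in metrics-rs issue #467), and the single `app` router is split into `routes` (API endpoints, wrapped by the optional bearer-token middleware) and `public_routes` (`/health`, `/`, `/ping`, `/metrics`), which are merged at the end so the health and metrics endpoints stay reachable without an API key. Below is a minimal, self-contained sketch of that pattern, assuming axum, metrics, metrics-exporter-prometheus, and anyhow as dependencies; the handler names and the `build_app` helper are illustrative, not the repository's actual code.

// Minimal sketch of the recorder + split-router pattern (not the repository's exact code).
use anyhow::Context;
use axum::{extract::DefaultBodyLimit, routing::get, Extension, Router};
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};

async fn health() -> &'static str {
    "ok"
}

// Renders the accumulated Prometheus metrics from the shared handle.
async fn metrics_handler(Extension(handle): Extension<PrometheusHandle>) -> String {
    handle.render()
}

fn build_app(payload_limit: usize) -> anyhow::Result<Router> {
    // Build the recorder by hand, keep its handle for the /metrics route,
    // then install it as the global recorder (the exporter future returned
    // by `build()` is not needed here).
    let (recorder, _exporter) = PrometheusBuilder::new()
        .build()
        .context("failed to build prometheus recorder")?;
    let prom_handle = recorder.handle();
    metrics::set_global_recorder(recorder).context("failed to set global recorder")?;

    // API routes live on their own router so auth middleware can wrap them
    // without also guarding the health and metrics endpoints.
    let routes = Router::new(); // .route("/embed", post(embed)) etc. in the real server

    // Health and metrics stay public.
    let public_routes = Router::new()
        .route("/health", get(health))
        .route("/metrics", get(metrics_handler));

    Ok(Router::new()
        .merge(routes)
        .merge(public_routes)
        .layer(Extension(prom_handle))
        .layer(DefaultBodyLimit::max(payload_limit)))
}
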
1 change: 1 addition & 0 deletions router/src/http/types.rs
@@ -375,6 +375,7 @@ pub(crate) struct SimilarityInput {
pub(crate) struct SimilarityParameters {
#[schema(default = "false", example = "false", nullable = true)]
pub truncate: Option<bool>,
#[serde(default)]
#[schema(default = "right", example = "right")]
pub truncation_direction: TruncationDirection,
/// The name of the prompt that should be used by for encoding. If not set, no prompt
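
Note on the types.rs change above: adding `#[serde(default)]` lets request bodies omit `truncation_direction` entirely; without it, serde rejects such payloads with a `missing field` error even though the schema already advertises `"right"` as the default. A small sketch of the behavior, using a stand-in enum rather than the crate's actual `TruncationDirection`:

use serde::Deserialize;

// Stand-in for the real TruncationDirection used by the router.
#[derive(Debug, Default, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum TruncationDirection {
    Left,
    #[default]
    Right,
}

#[derive(Debug, Deserialize)]
struct SimilarityParameters {
    truncate: Option<bool>,
    // Without #[serde(default)], deserializing a body that omits this field
    // fails with `missing field 'truncation_direction'`.
    #[serde(default)]
    truncation_direction: TruncationDirection,
}

fn main() {
    let params: SimilarityParameters =
        serde_json::from_str(r#"{ "truncate": false }"#).unwrap();
    assert_eq!(params.truncation_direction, TruncationDirection::Right);
}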
