use the model type to match safely
OlivierDehaene committed Mar 21, 2024
1 parent 7441c45 commit c2c984c
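
Vertex payloads no longer carry a tagged `type` field. Each instance now arrives as a raw `serde_json::Value` and is routed from the server's `ModelType`; `RerankRequest` stays the only shape that can be matched structurally. The response envelope becomes `{"predictions": [...]}` via an untagged enum, and Vertex deployments now fail fast when `AIP_PREDICT_ROUTE` or `AIP_HEALTH_ROUTE` is unset.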
Showing 3 changed files with 76 additions and 87 deletions.
137 changes: 67 additions & 70 deletions router/src/http/server.rs
@@ -4,7 +4,7 @@ use crate::http::types::{
EmbedSparseResponse, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse,
OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest,
PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, SimpleToken,
SparseValue, TokenizeRequest, TokenizeResponse, VertexInstance, VertexRequest, VertexResponse,
SparseValue, TokenizeRequest, TokenizeResponse, VertexRequest, VertexResponse,
VertexResponseInstance,
};
use crate::{
@@ -1180,11 +1180,6 @@ async fn vertex_compatibility(
let result = embed(infer, info, Json(req)).await?;
Ok(VertexResponseInstance::Embed(result.1 .0))
};
let embed_all_future =
move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedAllRequest| async move {
let result = embed_all(infer, info, Json(req)).await?;
Ok(VertexResponseInstance::EmbedAll(result.1 .0))
};
let embed_sparse_future =
move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedSparseRequest| async move {
let result = embed_sparse(infer, info, Json(req)).await?;
@@ -1200,45 +1195,44 @@
let result = rerank(infer, info, Json(req)).await?;
Ok(VertexResponseInstance::Rerank(result.1 .0))
};
let tokenize_future =
move |infer: Extension<Infer>, info: Extension<Info>, req: TokenizeRequest| async move {
let result = tokenize(infer, info, Json(req)).await?;
Ok(VertexResponseInstance::Tokenize(result.0))
};

let mut futures = Vec::with_capacity(req.instances.len());
for instance in req.instances {
let local_infer = infer.clone();
let local_info = info.clone();

match instance {
VertexInstance::Embed(req) => {
futures.push(embed_future(local_infer, local_info, req).boxed());
}
VertexInstance::EmbedAll(req) => {
futures.push(embed_all_future(local_infer, local_info, req).boxed());
}
VertexInstance::EmbedSparse(req) => {
futures.push(embed_sparse_future(local_infer, local_info, req).boxed());
}
VertexInstance::Predict(req) => {
futures.push(predict_future(local_infer, local_info, req).boxed());
}
VertexInstance::Rerank(req) => {
futures.push(rerank_future(local_infer, local_info, req).boxed());
// Rerank is the only payload that can be matched safely
if let Ok(instance) = serde_json::from_value::<RerankRequest>(instance.clone()) {
futures.push(rerank_future(local_infer, local_info, instance).boxed());
continue;
}

match info.model_type {
ModelType::Classifier(_) | ModelType::Reranker(_) => {
let instance = serde_json::from_value::<PredictRequest>(instance)
.map_err(ErrorResponse::from)?;
futures.push(predict_future(local_infer, local_info, instance).boxed());
}
VertexInstance::Tokenize(req) => {
futures.push(tokenize_future(local_infer, local_info, req).boxed());
ModelType::Embedding(_) => {
if infer.is_splade() {
let instance = serde_json::from_value::<EmbedSparseRequest>(instance)
.map_err(ErrorResponse::from)?;
futures.push(embed_sparse_future(local_infer, local_info, instance).boxed());
} else {
let instance = serde_json::from_value::<EmbedRequest>(instance)
.map_err(ErrorResponse::from)?;
futures.push(embed_future(local_infer, local_info, instance).boxed());
}
}
}
}

let results = join_all(futures)
let predictions = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<VertexResponseInstance>, (StatusCode, Json<ErrorResponse>)>>()?;

Ok(Json(VertexResponse(results)))
Ok(Json(VertexResponse { predictions }))
}

/// Prometheus metrics scrape endpoint
@@ -1350,12 +1344,7 @@ pub async fn run(
#[derive(OpenApi)]
#[openapi(
paths(vertex_compatibility),
components(schemas(
VertexInstance,
VertexRequest,
VertexResponse,
VertexResponseInstance
))
components(schemas(VertexRequest, VertexResponse, VertexResponseInstance))
)]
struct VertextApiDoc;

@@ -1394,43 +1383,42 @@ pub async fn run(

let mut app = Router::new().merge(base_routes);

// Set default routes
app = match &info.model_type {
ModelType::Classifier(_) => {
app.route("/", post(predict))
// AWS Sagemaker route
.route("/invocations", post(predict))
}
ModelType::Reranker(_) => {
app.route("/", post(rerank))
// AWS Sagemaker route
.route("/invocations", post(rerank))
}
ModelType::Embedding(model) => {
if model.pooling == "splade" {
app.route("/", post(embed_sparse))
// AWS Sagemaker route
.route("/invocations", post(embed_sparse))
} else {
app.route("/", post(embed))
// AWS Sagemaker route
.route("/invocations", post(embed))
}
}
};

#[cfg(feature = "google")]
{
tracing::info!("Built with `google` feature");
tracing::info!(
"Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
);
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
app = app.route(&env_predict_route, post(vertex_compatibility));
}
if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
app = app.route(&env_health_route, get(health));
}
let env_predict_route = std::env::var("AIP_PREDICT_ROUTE")
.context("`AIP_PREDICT_ROUTE` env var must be set for Google Vertex deployments")?;
app = app.route(&env_predict_route, post(vertex_compatibility));
let env_health_route = std::env::var("AIP_HEALTH_ROUTE")
.context("`AIP_HEALTH_ROUTE` env var must be set for Google Vertex deployments")?;
app = app.route(&env_health_route, get(health));
}
#[cfg(not(feature = "google"))]
{
// Set default routes
app = match &info.model_type {
ModelType::Classifier(_) => {
app.route("/", post(predict))
// AWS Sagemaker route
.route("/invocations", post(predict))
}
ModelType::Reranker(_) => {
app.route("/", post(rerank))
// AWS Sagemaker route
.route("/invocations", post(rerank))
}
ModelType::Embedding(model) => {
if model.pooling == "splade" {
app.route("/", post(embed_sparse))
// AWS Sagemaker route
.route("/invocations", post(embed_sparse))
} else {
app.route("/", post(embed))
// AWS Sagemaker route
.route("/invocations", post(embed))
}
}
};
}

let app = app
@@ -1485,3 +1473,12 @@ impl From<ErrorResponse> for (StatusCode, Json<OpenAICompatErrorResponse>) {
(StatusCode::from(&err.error_type), Json(err.into()))
}
}

impl From<serde_json::Error> for ErrorResponse {
fn from(err: serde_json::Error) -> Self {
ErrorResponse {
error: err.to_string(),
error_type: ErrorType::Validation,
}
}
}
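
Taken together, the new `vertex_compatibility` body boils down to a try-rerank-then-match-model-type dispatch. A minimal, self-contained sketch of that flow — `RerankRequest`, `ModelType`, and `Route` below are simplified stand-ins, not the router's real types:

```rust
use serde::Deserialize;
use serde_json::Value;

// Simplified stand-in for the router's rerank payload; only the shape matters here.
#[derive(Deserialize)]
struct RerankRequest {
    query: String,
    texts: Vec<String>,
}

enum ModelType {
    Classifier,
    Reranker,
    Embedding,
}

enum Route {
    Rerank(RerankRequest),
    Predict(Value),
    EmbedSparse(Value),
    Embed(Value),
}

fn route(instance: Value, model_type: &ModelType, is_splade: bool) -> Route {
    // Rerank is the only payload whose shape is unambiguous, so try it first.
    if let Ok(req) = serde_json::from_value::<RerankRequest>(instance.clone()) {
        return Route::Rerank(req);
    }
    // Everything else is routed from the model type the server was started with.
    match model_type {
        ModelType::Classifier | ModelType::Reranker => Route::Predict(instance),
        ModelType::Embedding if is_splade => Route::EmbedSparse(instance),
        ModelType::Embedding => Route::Embed(instance),
    }
}

fn main() {
    // A rerank-shaped payload is matched structurally even on an embedding server.
    let payload = serde_json::json!({"query": "what is rust", "texts": ["a", "b"]});
    match route(payload, &ModelType::Embedding, false) {
        Route::Rerank(_) => println!("matched rerank structurally"),
        _ => println!("fell back to model type"),
    }
}
```

In the real handler, the fallback deserializations surface failures as `ErrorResponse` values through the new `From<serde_json::Error>` impl above, rather than being dropped.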
21 changes: 5 additions & 16 deletions router/src/http/types.rs
@@ -382,32 +382,21 @@ pub(crate) struct SimpleToken {
#[schema(example = json!([[{"id": 0, "text": "test", "special": false, "start": 0, "stop": 2}]]))]
pub(crate) struct TokenizeResponse(pub Vec<Vec<SimpleToken>>);

#[derive(Deserialize, ToSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum VertexInstance {
Embed(EmbedRequest),
EmbedAll(EmbedAllRequest),
EmbedSparse(EmbedSparseRequest),
Predict(PredictRequest),
Rerank(RerankRequest),
Tokenize(TokenizeRequest),
}

#[derive(Deserialize, ToSchema)]
pub(crate) struct VertexRequest {
pub instances: Vec<VertexInstance>,
pub instances: Vec<serde_json::Value>,
}

#[derive(Serialize, ToSchema)]
#[serde(tag = "type", content = "result", rename_all = "snake_case")]
#[serde(untagged)]
pub(crate) enum VertexResponseInstance {
Embed(EmbedResponse),
EmbedAll(EmbedAllResponse),
EmbedSparse(EmbedSparseResponse),
Predict(PredictResponse),
Rerank(RerankResponse),
Tokenize(TokenizeResponse),
}

#[derive(Serialize, ToSchema)]
pub(crate) struct VertexResponse(pub Vec<VertexResponseInstance>);
pub(crate) struct VertexResponse {
pub predictions: Vec<VertexResponseInstance>,
}
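
With `instances` now plain `serde_json::Value`s on the way in, the output side switches `VertexResponseInstance` to `#[serde(untagged)]`: each result serializes as its bare payload instead of a `{"type": ..., "result": ...}` wrapper, inside the `predictions` envelope Vertex expects. A toy sketch of the serialization effect (illustrative types, not the router's):

```rust
use serde::Serialize;

#[derive(Serialize)]
struct Rank {
    index: usize,
    score: f32,
}

// Untagged: each variant serializes as its inner value, with no enum wrapper.
#[derive(Serialize)]
#[serde(untagged)]
enum ResponseInstance {
    Embed(Vec<f32>),
    Rerank(Vec<Rank>),
}

#[derive(Serialize)]
struct VertexResponse {
    predictions: Vec<ResponseInstance>,
}

fn main() {
    let response = VertexResponse {
        predictions: vec![
            ResponseInstance::Embed(vec![0.5, 0.25]),
            ResponseInstance::Rerank(vec![Rank { index: 0, score: 0.5 }]),
        ],
    };
    // Prints: {"predictions":[[0.5,0.25],[{"index":0,"score":0.5}]]}
    println!("{}", serde_json::to_string(&response).unwrap());
}
```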
5 changes: 4 additions & 1 deletion router/src/lib.rs
@@ -244,7 +244,7 @@ pub async fn run(
std::env::var("AIP_HTTP_PORT")
.ok()
.and_then(|p| p.parse().ok())
.context("Invalid or unset AIP_HTTP_PORT")?
.context("`AIP_HTTP_PORT` env var must be set for Google Vertex deployments")?
} else {
port
};
@@ -262,6 +262,9 @@
#[cfg(all(feature = "grpc", feature = "http"))]
compile_error!("Features `http` and `grpc` cannot be enabled at the same time.");

#[cfg(all(feature = "grpc", feature = "google"))]
compile_error!("Features `http` and `google` cannot be enabled at the same time.");

#[cfg(not(any(feature = "http", feature = "grpc")))]
compile_error!("Either feature `http` or `grpc` must be enabled.");

