From c2c984cfa92084cb3e0d3c172b19ec20938bbce2 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 21 Mar 2024 11:08:26 +0100 Subject: [PATCH] use the model type to match safely --- router/src/http/server.rs | 137 +++++++++++++++++++------------------- router/src/http/types.rs | 21 ++---- router/src/lib.rs | 5 +- 3 files changed, 76 insertions(+), 87 deletions(-) diff --git a/router/src/http/server.rs b/router/src/http/server.rs index cc8943fa..d09fce37 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -4,7 +4,7 @@ use crate::http::types::{ EmbedSparseResponse, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, SimpleToken, - SparseValue, TokenizeRequest, TokenizeResponse, VertexInstance, VertexRequest, VertexResponse, + SparseValue, TokenizeRequest, TokenizeResponse, VertexRequest, VertexResponse, VertexResponseInstance, }; use crate::{ @@ -1180,11 +1180,6 @@ async fn vertex_compatibility( let result = embed(infer, info, Json(req)).await?; Ok(VertexResponseInstance::Embed(result.1 .0)) }; - let embed_all_future = - move |infer: Extension, info: Extension, req: EmbedAllRequest| async move { - let result = embed_all(infer, info, Json(req)).await?; - Ok(VertexResponseInstance::EmbedAll(result.1 .0)) - }; let embed_sparse_future = move |infer: Extension, info: Extension, req: EmbedSparseRequest| async move { let result = embed_sparse(infer, info, Json(req)).await?; @@ -1200,45 +1195,44 @@ async fn vertex_compatibility( let result = rerank(infer, info, Json(req)).await?; Ok(VertexResponseInstance::Rerank(result.1 .0)) }; - let tokenize_future = - move |infer: Extension, info: Extension, req: TokenizeRequest| async move { - let result = tokenize(infer, info, Json(req)).await?; - 
Ok(VertexResponseInstance::Tokenize(result.0))
-        };
 
     let mut futures = Vec::with_capacity(req.instances.len());
     for instance in req.instances {
         let local_infer = infer.clone();
         let local_info = info.clone();
-        match instance {
-            VertexInstance::Embed(req) => {
-                futures.push(embed_future(local_infer, local_info, req).boxed());
-            }
-            VertexInstance::EmbedAll(req) => {
-                futures.push(embed_all_future(local_infer, local_info, req).boxed());
-            }
-            VertexInstance::EmbedSparse(req) => {
-                futures.push(embed_sparse_future(local_infer, local_info, req).boxed());
-            }
-            VertexInstance::Predict(req) => {
-                futures.push(predict_future(local_infer, local_info, req).boxed());
-            }
-            VertexInstance::Rerank(req) => {
-                futures.push(rerank_future(local_infer, local_info, req).boxed());
+        // Rerank is the only payload that can be matched safely
+        if let Ok(instance) = serde_json::from_value::<RerankRequest>(instance.clone()) {
+            futures.push(rerank_future(local_infer, local_info, instance).boxed());
+            continue;
+        }
+
+        match info.model_type {
+            ModelType::Classifier(_) | ModelType::Reranker(_) => {
+                let instance = serde_json::from_value::<PredictRequest>(instance)
+                    .map_err(ErrorResponse::from)?;
+                futures.push(predict_future(local_infer, local_info, instance).boxed());
             }
-            VertexInstance::Tokenize(req) => {
-                futures.push(tokenize_future(local_infer, local_info, req).boxed());
+            ModelType::Embedding(_) => {
+                if infer.is_splade() {
+                    let instance = serde_json::from_value::<EmbedSparseRequest>(instance)
+                        .map_err(ErrorResponse::from)?;
+                    futures.push(embed_sparse_future(local_infer, local_info, instance).boxed());
+                } else {
+                    let instance = serde_json::from_value::<EmbedRequest>(instance)
+                        .map_err(ErrorResponse::from)?;
+                    futures.push(embed_future(local_infer, local_info, instance).boxed());
+                }
             }
         }
     }
 
-    let results = join_all(futures)
+    let predictions = join_all(futures)
         .await
         .into_iter()
         .collect::<Result<Vec<VertexResponseInstance>, (StatusCode, Json<ErrorResponse>)>>()?;
-    Ok(Json(VertexResponse(results)))
+    Ok(Json(VertexResponse { predictions }))
 }
 
 /// Prometheus metrics scrape
endpoint @@ -1350,12 +1344,7 @@ pub async fn run( #[derive(OpenApi)] #[openapi( paths(vertex_compatibility), - components(schemas( - VertexInstance, - VertexRequest, - VertexResponse, - VertexResponseInstance - )) + components(schemas(VertexRequest, VertexResponse, VertexResponseInstance)) )] struct VertextApiDoc; @@ -1394,43 +1383,42 @@ pub async fn run( let mut app = Router::new().merge(base_routes); - // Set default routes - app = match &info.model_type { - ModelType::Classifier(_) => { - app.route("/", post(predict)) - // AWS Sagemaker route - .route("/invocations", post(predict)) - } - ModelType::Reranker(_) => { - app.route("/", post(rerank)) - // AWS Sagemaker route - .route("/invocations", post(rerank)) - } - ModelType::Embedding(model) => { - if model.pooling == "splade" { - app.route("/", post(embed_sparse)) - // AWS Sagemaker route - .route("/invocations", post(embed_sparse)) - } else { - app.route("/", post(embed)) - // AWS Sagemaker route - .route("/invocations", post(embed)) - } - } - }; - #[cfg(feature = "google")] { tracing::info!("Built with `google` feature"); - tracing::info!( - "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected." 
- ); - if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") { - app = app.route(&env_predict_route, post(vertex_compatibility)); - } - if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") { - app = app.route(&env_health_route, get(health)); - } + let env_predict_route = std::env::var("AIP_PREDICT_ROUTE") + .context("`AIP_PREDICT_ROUTE` env var must be set for Google Vertex deployments")?; + app = app.route(&env_predict_route, post(vertex_compatibility)); + let env_health_route = std::env::var("AIP_HEALTH_ROUTE") + .context("`AIP_HEALTH_ROUTE` env var must be set for Google Vertex deployments")?; + app = app.route(&env_health_route, get(health)); + } + #[cfg(not(feature = "google"))] + { + // Set default routes + app = match &info.model_type { + ModelType::Classifier(_) => { + app.route("/", post(predict)) + // AWS Sagemaker route + .route("/invocations", post(predict)) + } + ModelType::Reranker(_) => { + app.route("/", post(rerank)) + // AWS Sagemaker route + .route("/invocations", post(rerank)) + } + ModelType::Embedding(model) => { + if model.pooling == "splade" { + app.route("/", post(embed_sparse)) + // AWS Sagemaker route + .route("/invocations", post(embed_sparse)) + } else { + app.route("/", post(embed)) + // AWS Sagemaker route + .route("/invocations", post(embed)) + } + } + }; } let app = app @@ -1485,3 +1473,12 @@ impl From for (StatusCode, Json) { (StatusCode::from(&err.error_type), Json(err.into())) } } + +impl From for ErrorResponse { + fn from(err: serde_json::Error) -> Self { + ErrorResponse { + error: err.to_string(), + error_type: ErrorType::Validation, + } + } +} diff --git a/router/src/http/types.rs b/router/src/http/types.rs index fbaa7be4..30c9e728 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -382,32 +382,21 @@ pub(crate) struct SimpleToken { #[schema(example = json!([[{"id": 0, "text": "test", "special": false, "start": 0, "stop": 2}]]))] pub(crate) struct TokenizeResponse(pub Vec>); 
-#[derive(Deserialize, ToSchema)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub(crate) enum VertexInstance {
-    Embed(EmbedRequest),
-    EmbedAll(EmbedAllRequest),
-    EmbedSparse(EmbedSparseRequest),
-    Predict(PredictRequest),
-    Rerank(RerankRequest),
-    Tokenize(TokenizeRequest),
-}
-
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct VertexRequest {
-    pub instances: Vec<VertexInstance>,
+    pub instances: Vec<serde_json::Value>,
 }
 
 #[derive(Serialize, ToSchema)]
-#[serde(tag = "type", content = "result", rename_all = "snake_case")]
+#[serde(untagged)]
 pub(crate) enum VertexResponseInstance {
     Embed(EmbedResponse),
-    EmbedAll(EmbedAllResponse),
     EmbedSparse(EmbedSparseResponse),
     Predict(PredictResponse),
     Rerank(RerankResponse),
-    Tokenize(TokenizeResponse),
 }
 
 #[derive(Serialize, ToSchema)]
-pub(crate) struct VertexResponse(pub Vec<VertexResponseInstance>);
+pub(crate) struct VertexResponse {
+    pub predictions: Vec<VertexResponseInstance>,
+}
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 8e369184..9b69559a 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -244,7 +244,7 @@ pub async fn run(
             std::env::var("AIP_HTTP_PORT")
                 .ok()
                 .and_then(|p| p.parse().ok())
-                .context("Invalid or unset AIP_HTTP_PORT")?
+                .context("`AIP_HTTP_PORT` env var must be set for Google Vertex deployments")?
         } else {
             port
         };
@@ -262,6 +262,9 @@ pub async fn run(
     #[cfg(all(feature = "grpc", feature = "http"))]
     compile_error!("Features `http` and `grpc` cannot be enabled at the same time.");
 
+    #[cfg(all(feature = "grpc", feature = "google"))]
+    compile_error!("Features `grpc` and `google` cannot be enabled at the same time.");
+
     #[cfg(not(any(feature = "http", feature = "grpc")))]
     compile_error!("Either feature `http` or `grpc` must be enabled.");