From 986a7a54ce24650d6d92934c772f816c13435797 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 21 Mar 2024 17:06:09 +0100
Subject: [PATCH] feat: add `/decode` route

---
 core/src/infer.rs         |  16 ++++
 core/src/tokenization.rs  |  57 ++++++++++++++++++++
 docs/openapi.json         |  96 ++++++++++++++++++++++++++++++++-
 proto/tei.proto           |  11 ++++
 router/src/grpc/server.rs | 108 +++++++++++++++++++++++++++++++++++++-
 router/src/http/server.rs |  86 +++++++++++++++++++++++++++---
 router/src/http/types.rs  |  23 ++++++++
 7 files changed, 388 insertions(+), 9 deletions(-)

diff --git a/core/src/infer.rs b/core/src/infer.rs
index e2cef5d5..54f755d9 100644
--- a/core/src/infer.rs
+++ b/core/src/infer.rs
@@ -74,6 +74,22 @@ impl Infer {
         })
     }
 
+    #[instrument(skip(self))]
+    pub async fn decode(
+        &self,
+        ids: Vec<u32>,
+        skip_special_tokens: bool,
+    ) -> Result<String, TextEmbeddingsError> {
+        self.tokenization
+            .decode(ids, skip_special_tokens)
+            .await
+            .map_err(|err| {
+                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
+                tracing::error!("{err}");
+                err
+            })
+    }
+
     #[instrument(skip(self))]
     pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs
index 787a9e81..42073b32 100644
--- a/core/src/tokenization.rs
+++ b/core/src/tokenization.rs
@@ -120,6 +120,37 @@ impl Tokenization {
         // Unwrap is safe here
         response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")
     }
+
+    #[instrument(skip_all)]
+    pub async fn decode(
+        &self,
+        ids: Vec<u32>,
+        skip_special_tokens: bool,
+    ) -> Result<String, TextEmbeddingsError> {
+        // Check if inputs is empty
+        if ids.is_empty() {
+            return Err(TextEmbeddingsError::Validation(
+                "`input_ids` cannot be empty".to_string(),
+            ));
+        }
+
+        // Create response channel
+        let (response_sender, response_receiver) = oneshot::channel();
+        // Send request to the background validation task
+        // Unwrap is safe here
+        self.sender
+            .send(TokenizerRequest::Decode(
+                ids,
+                skip_special_tokens,
+                response_sender,
+                Span::current(),
+            ))
+            .expect("Tokenization background task dropped the receiver. This is a bug.");
+
+        // Await on response channel
+        // Unwrap is safe here
+        response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")
+    }
 }
 
 /// Start tokenization workers
@@ -161,10 +192,30 @@ fn tokenizer_worker(
                     }
                 })
             }
+            TokenizerRequest::Decode(ids, skip_special_tokens, response_tx, parent_span) => {
+                parent_span.in_scope(|| {
+                    if !response_tx.is_closed() {
+                        // It's possible that the user dropped its request resulting in a send error.
+                        // We just discard the error
+                        let _ =
+                            response_tx.send(decode_ids(ids, skip_special_tokens, &mut tokenizer));
+                    }
+                })
+            }
         }
     }
 }
 
+fn decode_ids(
+    ids: Vec<u32>,
+    skip_special_tokens: bool,
+    tokenizer: &mut Tokenizer,
+) -> Result<String, TextEmbeddingsError> {
+    Ok(tokenizer
+        .with_truncation(None)?
+        .decode(&ids, skip_special_tokens)?)
+}
+
 fn tokenize_input(
     inputs: EncodingInput,
     add_special_tokens: bool,
@@ -263,4 +314,10 @@ enum TokenizerRequest {
         oneshot::Sender<Result<RawEncoding, TextEmbeddingsError>>,
         Span,
     ),
+    Decode(
+        Vec<u32>,
+        bool,
+        oneshot::Sender<Result<String, TextEmbeddingsError>>,
+        Span,
+    ),
 }
diff --git a/docs/openapi.json b/docs/openapi.json
index c9377087..35bfc4c1 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -12,6 +12,52 @@
     "version": "1.1.0"
   },
   "paths": {
+    "/decode": {
+      "post": {
+        "tags": [
+          "Text Embeddings Inference"
+        ],
+        "summary": "Decode input ids",
+        "description": "Decode input ids",
+        "operationId": "decode",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/DecodeRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Decoded ids",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/DecodeResponse"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Tokenization error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "message": "Tokenization error",
+                  "type": "tokenizer"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/embed": {
       "post": {
         "tags": [
@@ -647,7 +693,7 @@ "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+                "$ref": "#/components/schemas/ErrorResponse"
               },
               "example": {
                 "message": "Tokenization error",
                 "type": "tokenizer"
               }
@@ -690,6 +736,31 @@
           }
         }
       },
+      "DecodeRequest": {
+        "type": "object",
+        "required": [
+          "ids"
+        ],
+        "properties": {
+          "ids": {
+            "$ref": "#/components/schemas/InputIds"
+          },
+          "skip_special_tokens": {
+            "type": "boolean",
+            "default": "true",
+            "example": "true"
+          }
+        }
+      },
+      "DecodeResponse": {
+        "type": "array",
+        "items": {
+          "type": "string"
+        },
+        "example": [
+          "test"
+        ]
+      },
       "EmbedAllRequest": {
         "type": "object",
         "required": [
@@ -922,6 +993,29 @@
           }
         ]
       },
+      "InputIds": {
+        "oneOf": [
+          {
+            "type": "array",
+            "items": {
+              "type": "integer",
+              "format": "int32",
+              "minimum": 0
+            }
+          },
+          {
+            "type": "array",
+            "items": {
+              "type": "array",
+              "items": {
+                "type": "integer",
+                "format": "int32",
+                "minimum": 0
+              }
+            }
+          }
+        ]
+      },
       "ModelType": {
         "oneOf": [
           {
diff --git a/proto/tei.proto b/proto/tei.proto
index 87205129..51ad0002 100644
--- a/proto/tei.proto
+++ b/proto/tei.proto
@@ -30,6 +30,8 @@ service Rerank {
 service Tokenize {
     rpc Tokenize (EncodeRequest) returns (EncodeResponse);
     rpc TokenizeStream (stream EncodeRequest) returns (stream EncodeResponse);
+    rpc Decode (DecodeRequest) returns (DecodeResponse);
+    rpc DecodeStream (stream DecodeRequest) returns (stream DecodeResponse);
 }
 
 message InfoRequest {}
@@ -166,3 +168,12 @@ message SimpleToken {
 message EncodeResponse {
     repeated SimpleToken tokens = 1;
 }
+
+message DecodeRequest {
+    repeated uint32 ids = 1;
+    bool skip_special_tokens = 2;
+}
+
+message DecodeResponse {
+    string text = 1;
+}
diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs
index eefe58ce..1fa16bce 100644
--- a/router/src/grpc/server.rs
+++ b/router/src/grpc/server.rs
@@ -3,8 +3,8 @@ use crate::grpc::pb::tei::v1::{
     EncodeResponse, RerankStreamRequest, SimpleToken, SparseValue, TokenEmbedding,
 };
 use crate::grpc::{
-    EmbedRequest, EmbedResponse, InfoRequest, InfoResponse, PredictRequest, PredictResponse,
-    Prediction, Rank, RerankRequest, RerankResponse,
+    DecodeRequest, DecodeResponse, EmbedRequest, EmbedResponse, InfoRequest, InfoResponse,
+    PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
 };
 use crate::ResponseMetadata;
 use crate::{grpc, shutdown, ErrorResponse, ErrorType, Info, ModelType};
@@ -331,6 +331,17 @@ impl TextEmbeddingsService {
             .collect();
         Ok(EncodeResponse { tokens })
     }
+
+    #[instrument(skip_all)]
+    async fn decode_inner(&self, request: DecodeRequest) -> Result<DecodeResponse, Status> {
+        let ids = request.ids;
+        let text = self
+            .infer
+            .decode(ids, request.skip_special_tokens)
+            .await
+            .map_err(ErrorResponse::from)?;
+        Ok(DecodeResponse { text })
+    }
 }
 
 #[tonic::async_trait]
@@ -1327,6 +1338,99 @@ impl grpc::tokenize_server::Tokenize for TextEmbeddingsService {
             response_receiver,
         )))
     }
+
+    async fn decode(
+        &self,
+        request: Request<DecodeRequest>,
+    ) -> Result<Response<DecodeResponse>, Status> {
+        let request = request.into_inner();
+        let tokens = self.decode_inner(request).await?;
+        Ok(Response::new(tokens))
+    }
+
+    type DecodeStreamStream = UnboundedReceiverStream<Result<DecodeResponse, Status>>;
+
+    async fn decode_stream(
+        &self,
+        request: Request<Streaming<DecodeRequest>>,
+    ) -> Result<Response<Self::DecodeStreamStream>, Status> {
+        let mut request_stream = request.into_inner();
+
+        // Create bounded channel to have an upper bound of spawned tasks
+        // We will have at most `max_parallel_stream_requests` messages from this stream in the queue
+        let (encode_sender, mut encode_receiver) = mpsc::channel::<(
+            DecodeRequest,
+            oneshot::Sender<Result<DecodeResponse, Status>>,
+        )>(self.max_parallel_stream_requests);
+
+        // Required for the async move below
+        let local = self.clone();
+
+        // Background task that uses the bounded channel
+        tokio::spawn(async move {
+            while let Some((request, mut sender)) = encode_receiver.recv().await {
+                // Required for the async move below
+                let task_local = local.clone();
+
+                // Create async task for this specific input
+                tokio::spawn(async move {
+                    // Select on closed to cancel work if the stream was closed
+                    tokio::select! {
+                        response = task_local.decode_inner(request) => {
+                            let _ = sender.send(response);
+                        }
+                        _ = sender.closed() => {}
+                    }
+                });
+            }
+        });
+
+        // Intermediate channels
+        // Required to keep the order of the requests
+        let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel();
+
+        tokio::spawn(async move {
+            // Iterate on input
+            while let Some(request) = request_stream.next().await {
+                // Create return channel
+                let (result_sender, result_receiver) = oneshot::channel();
+                // Push to intermediate channel and preserve ordering
+                intermediate_sender
+                    .send(result_receiver)
+                    .expect("`intermediate_receiver` was dropped. This is a bug.");
+
+                match request {
+                    Ok(request) => encode_sender
+                        .send((request, result_sender))
+                        .await
+                        .expect("`encode_receiver` was dropped. This is a bug."),
+                    Err(status) => {
+                        // Request is malformed
+                        let _ = result_sender.send(Err(status));
+                    }
+                };
+            }
+        });
+
+        // Final channel for the outputs
+        let (response_sender, response_receiver) = mpsc::unbounded_channel();
+
+        tokio::spawn(async move {
+            while let Some(result_receiver) = intermediate_receiver.recv().await {
+                // Select on closed to cancel work if the stream was closed
+                tokio::select! {
+                    response = result_receiver => {
+                        let _ = response_sender.send(response.expect("`result_sender` was dropped. This is a bug."));
+                    }
+                    _ = response_sender.closed() => {}
+                }
+            }
+        });
+
+        Ok(Response::new(UnboundedReceiverStream::new(
+            response_receiver,
+        )))
+    }
 }
 
 pub async fn run(
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index 7da0e6a6..68512153 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -1,10 +1,10 @@
 /// HTTP Server logic
 use crate::http::types::{
-    EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse, EmbedSparseRequest,
-    EmbedSparseResponse, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse,
-    OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest,
-    PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, SimpleToken,
-    SparseValue, TokenizeRequest, TokenizeResponse, VertexRequest,
+    DecodeRequest, DecodeResponse, EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse,
+    EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, OpenAICompatEmbedding,
+    OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
+    PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
+    Sequence, SimpleToken, SparseValue, TokenizeRequest, TokenizeResponse, VertexRequest,
 };
 use crate::{
     shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -1056,7 +1056,7 @@ path = "/tokenize",
 request_body = TokenizeRequest,
 responses(
 (status = 200, description = "Tokenized ids", body = TokenizeResponse),
-(status = 422, description = "Tokenization error", body = OpenAICompatErrorResponse,
+(status = 422, description = "Tokenization error", body = ErrorResponse,
 example = json ! ({"message": "Tokenization error", "type": "tokenizer"})),
 )
 )]
@@ -1150,6 +1150,75 @@ async fn tokenize(
     Ok(Json(TokenizeResponse(tokens)))
 }
 
+/// Decode input ids
+#[utoipa::path(
+post,
+tag = "Text Embeddings Inference",
+path = "/decode",
+request_body = DecodeRequest,
+responses(
+(status = 200, description = "Decoded ids", body = DecodeResponse),
+(status = 422, description = "Tokenization error", body = ErrorResponse,
+example = json ! ({"message": "Tokenization error", "type": "tokenizer"})),
({"message": "Tokenization error", "type": "tokenizer"})), +) +)] +#[instrument(skip_all)] +async fn decode( + infer: Extension, + info: Extension, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let decode_inner = move |ids: Vec, skip_special_tokens: bool, infer: Infer| async move { + let text = infer + .decode(ids, skip_special_tokens) + .await + .map_err(ErrorResponse::from)?; + Ok::(text) + }; + + let texts = match req.ids { + InputIds::Single(ids) => vec![decode_inner(ids, req.skip_special_tokens, infer.0).await?], + InputIds::Batch(ids) => { + if ids.is_empty() { + let message = "`ids` cannot be empty".to_string(); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "validation"); + Err(err)?; + } + + let batch_size = ids.len(); + if batch_size > info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; + } + + let mut futures = Vec::with_capacity(batch_size); + for ids in ids { + futures.push(decode_inner(ids, req.skip_special_tokens, infer.0.clone())); + } + + join_all(futures) + .await + .into_iter() + .collect::, ErrorResponse>>()? + } + }; + Ok(Json(DecodeResponse(texts))) +} + /// Generate embeddings from a Vertex request #[utoipa::path( post, @@ -1278,6 +1347,7 @@ pub async fn run( embed_sparse, openai_embed, tokenize, + decode, metrics, ), components( @@ -1310,6 +1380,9 @@ pub async fn run( TokenizeRequest, TokenizeResponse, SimpleToken, + InputIds, + DecodeRequest, + DecodeResponse, ErrorType, ) ), @@ -1382,6 +1455,7 @@ pub async fn run( .route("/predict", post(predict)) .route("/rerank", post(rerank)) .route("/tokenize", post(tokenize)) + .route("/decode", post(decode)) // OpenAI compat route .route("/embeddings", post(openai_embed)) // Vertex compat route diff --git a/router/src/http/types.rs b/router/src/http/types.rs index a638fb29..aa3f3751 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -382,6 +382,29 @@ pub(crate) struct SimpleToken { #[schema(example = json!([[{"id": 0, "text": "test", "special": false, "start": 0, "stop": 2}]]))] pub(crate) struct TokenizeResponse(pub Vec>); +#[derive(Deserialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum InputIds { + Single(Vec), + Batch(Vec>), +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct DecodeRequest { + pub ids: InputIds, + #[serde(default = "default_skip_special_tokens")] + #[schema(default = "true", example = "true")] + pub skip_special_tokens: bool, +} + +fn default_skip_special_tokens() -> bool { + true +} + +#[derive(Serialize, ToSchema)] +#[schema(example = json!(["test"]))] +pub(crate) struct DecodeResponse(pub Vec); + #[derive(Clone, Deserialize, ToSchema)] pub(crate) struct VertexInstance { #[schema(example = "What is Deep Learning?")]