feat: add /decode route

OlivierDehaene committed Mar 21, 2024
1 parent 1d6f288 commit 986a7a5
Showing 7 changed files with 388 additions and 9 deletions.
16 changes: 16 additions & 0 deletions core/src/infer.rs
@@ -74,6 +74,22 @@ impl Infer {
})
}

#[instrument(skip(self))]
pub async fn decode(
&self,
ids: Vec<u32>,
skip_special_tokens: bool,
) -> Result<String, TextEmbeddingsError> {
self.tokenization
.decode(ids, skip_special_tokens)
.await
.map_err(|err| {
metrics::increment_counter!("te_request_failure", "err" => "tokenization");
tracing::error!("{err}");
err
})
}

#[instrument(skip(self))]
pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> {
// Limit concurrent requests by acquiring a permit from the semaphore
57 changes: 57 additions & 0 deletions core/src/tokenization.rs
@@ -120,6 +120,37 @@ impl Tokenization {
// Unwrap is safe here
response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")
}

#[instrument(skip_all)]
pub async fn decode(
&self,
ids: Vec<u32>,
skip_special_tokens: bool,
) -> Result<String, TextEmbeddingsError> {
        // Check that `ids` is not empty
if ids.is_empty() {
return Err(TextEmbeddingsError::Validation(
"`input_ids` cannot be empty".to_string(),
));
}

// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
        // Send request to the background tokenization task
// Unwrap is safe here
self.sender
.send(TokenizerRequest::Decode(
ids,
skip_special_tokens,
response_sender,
Span::current(),
))
.expect("Tokenization background task dropped the receiver. This is a bug.");

// Await on response channel
// Unwrap is safe here
response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")
}
}

/// Start tokenization workers
@@ -161,10 +192,30 @@ fn tokenizer_worker(
}
})
}
TokenizerRequest::Decode(ids, skip_special_tokens, response_tx, parent_span) => {
parent_span.in_scope(|| {
if !response_tx.is_closed() {
                        // It's possible that the user dropped their request, resulting in a send error.
                        // We just discard the error
let _ =
response_tx.send(decode_ids(ids, skip_special_tokens, &mut tokenizer));
}
})
}
}
}
}

fn decode_ids(
ids: Vec<u32>,
skip_special_tokens: bool,
tokenizer: &mut Tokenizer,
) -> Result<String, TextEmbeddingsError> {
Ok(tokenizer
.with_truncation(None)?
.decode(&ids, skip_special_tokens)?)
}

fn tokenize_input(
inputs: EncodingInput,
add_special_tokens: bool,
@@ -263,4 +314,10 @@ enum TokenizerRequest {
oneshot::Sender<Result<RawEncoding, TextEmbeddingsError>>,
Span,
),
Decode(
Vec<u32>,
bool,
oneshot::Sender<Result<String, TextEmbeddingsError>>,
Span,
),
}
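
The decode path reuses the channel plumbing that encoding already goes through: the caller packs a oneshot sender into a `TokenizerRequest::Decode`, pushes it to the worker over the mpsc channel, and awaits the paired receiver. The following minimal, self-contained sketch shows that mpsc-plus-oneshot round trip; the names and the mock decode body are illustrative, not code from this repository.

use tokio::sync::{mpsc, oneshot};

// Illustrative stand-in for `TokenizerRequest`.
enum Request {
    Decode(Vec<u32>, oneshot::Sender<String>),
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel::<Request>();

    // Background worker owning the (mock) tokenizer state,
    // like `tokenizer_worker` above.
    tokio::spawn(async move {
        while let Some(Request::Decode(ids, reply)) = rx.recv().await {
            // Stand-in for `tokenizer.decode(&ids, skip_special_tokens)`.
            let _ = reply.send(format!("decoded {} ids", ids.len()));
        }
    });

    // Caller side: send the request together with a oneshot sender, then
    // await the paired receiver, as `Tokenization::decode` does.
    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send(Request::Decode(vec![101, 102], reply_tx)).unwrap();
    println!("{}", reply_rx.await.unwrap());
}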
96 changes: 95 additions & 1 deletion docs/openapi.json
@@ -12,6 +12,52 @@
"version": "1.1.0"
},
"paths": {
"/decode": {
"post": {
"tags": [
"Text Embeddings Inference"
],
"summary": "Decode input ids",
"description": "Decode input ids",
"operationId": "decode",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/DecodeRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Decoded ids",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/DecodeResponse"
}
}
}
},
"422": {
"description": "Tokenization error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"message": "Tokenization error",
"type": "tokenizer"
}
}
}
}
}
}
},
"/embed": {
"post": {
"tags": [
@@ -647,7 +693,7 @@
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/OpenAICompatErrorResponse"
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"message": "Tokenization error",
@@ -690,6 +736,31 @@
}
}
},
"DecodeRequest": {
"type": "object",
"required": [
"ids"
],
"properties": {
"ids": {
"$ref": "#/components/schemas/InputIds"
},
"skip_special_tokens": {
"type": "boolean",
"default": "true",
"example": "true"
}
}
},
"DecodeResponse": {
"type": "array",
"items": {
"type": "string"
},
"example": [
"test"
]
},
"EmbedAllRequest": {
"type": "object",
"required": [
@@ -922,6 +993,29 @@
}
]
},
"InputIds": {
"oneOf": [
{
"type": "array",
"items": {
"type": "integer",
"format": "int32",
"minimum": 0
}
},
{
"type": "array",
"items": {
"type": "array",
"items": {
"type": "integer",
"format": "int32",
"minimum": 0
}
}
}
]
},
"ModelType": {
"oneOf": [
{
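
With the route in place, a client can POST token ids and get text back. Below is a hedged sketch of such a call using the `tokio`, `reqwest` (with its `json` feature), and `serde_json` crates; the port, the token ids, and the crate choices are assumptions for illustration, not part of this commit. Per the `DecodeResponse` schema above, the response body decodes to an array of strings.

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Hypothetical token ids; `skip_special_tokens` defaults to true
    // per the DecodeRequest schema.
    let body = json!({
        "ids": [101, 7592, 102],
        "skip_special_tokens": true
    });
    let decoded: Vec<String> = reqwest::Client::new()
        .post("http://localhost:3000/decode") // assumed local router address
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    println!("{decoded:?}");
    Ok(())
}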
11 changes: 11 additions & 0 deletions proto/tei.proto
@@ -30,6 +30,8 @@ service Rerank {
service Tokenize {
rpc Tokenize (EncodeRequest) returns (EncodeResponse);
rpc TokenizeStream (stream EncodeRequest) returns (stream EncodeResponse);
rpc Decode (DecodeRequest) returns (DecodeResponse);
rpc DecodeStream (stream DecodeRequest) returns (stream DecodeResponse);
}

message InfoRequest {}
@@ -166,3 +168,12 @@ message SimpleToken {
message EncodeResponse {
repeated SimpleToken tokens = 1;
}

message DecodeRequest {
repeated uint32 ids = 1;
bool skip_special_tokens = 2;
}

message DecodeResponse {
string text = 1;
}
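
The same operation is exposed over gRPC. Here is a hedged sketch of a tonic client for the new Decode RPC; it assumes client stubs generated from proto/tei.proto (tonic would name the module `tokenize_client`, mirroring the `tokenize_server` module the router uses below), and the module path, address, and ids are illustrative.

// Module path assumes the tonic_build output is exposed as `tei::v1`.
use tei::v1::tokenize_client::TokenizeClient;
use tei::v1::DecodeRequest;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumed local gRPC address; adjust to wherever the router listens.
    let mut client = TokenizeClient::connect("http://localhost:8080").await?;
    let response = client
        .decode(DecodeRequest {
            ids: vec![101, 7592, 102], // hypothetical token ids
            skip_special_tokens: true,
        })
        .await?;
    println!("{}", response.into_inner().text);
    Ok(())
}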
108 changes: 106 additions & 2 deletions router/src/grpc/server.rs
@@ -3,8 +3,8 @@ use crate::grpc::pb::tei::v1::{
EncodeResponse, RerankStreamRequest, SimpleToken, SparseValue, TokenEmbedding,
};
use crate::grpc::{
EmbedRequest, EmbedResponse, InfoRequest, InfoResponse, PredictRequest, PredictResponse,
Prediction, Rank, RerankRequest, RerankResponse,
DecodeRequest, DecodeResponse, EmbedRequest, EmbedResponse, InfoRequest, InfoResponse,
PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
};
use crate::ResponseMetadata;
use crate::{grpc, shutdown, ErrorResponse, ErrorType, Info, ModelType};
@@ -331,6 +331,17 @@ impl TextEmbeddingsService {
.collect();
Ok(EncodeResponse { tokens })
}

#[instrument(skip_all)]
async fn decode_inner(&self, request: DecodeRequest) -> Result<DecodeResponse, Status> {
let ids = request.ids;
let text = self
.infer
.decode(ids, request.skip_special_tokens)
.await
.map_err(ErrorResponse::from)?;
Ok(DecodeResponse { text })
}
}

#[tonic::async_trait]
@@ -1327,6 +1338,99 @@ impl grpc::tokenize_server::Tokenize for TextEmbeddingsService {
response_receiver,
)))
}

async fn decode(
&self,
request: Request<DecodeRequest>,
) -> Result<Response<DecodeResponse>, Status> {
let request = request.into_inner();
let tokens = self.decode_inner(request).await?;
Ok(Response::new(tokens))
}

type DecodeStreamStream = UnboundedReceiverStream<Result<DecodeResponse, Status>>;

async fn decode_stream(
&self,
request: Request<Streaming<DecodeRequest>>,
) -> Result<Response<Self::DecodeStreamStream>, Status> {
let mut request_stream = request.into_inner();

// Create bounded channel to have an upper bound of spawned tasks
// We will have at most `max_parallel_stream_requests` messages from this stream in the queue
let (encode_sender, mut encode_receiver) = mpsc::channel::<(
DecodeRequest,
oneshot::Sender<Result<DecodeResponse, Status>>,
)>(self.max_parallel_stream_requests);

// Required for the async move below
let local = self.clone();

// Background task that uses the bounded channel
tokio::spawn(async move {
while let Some((request, mut sender)) = encode_receiver.recv().await {
// Required for the async move below
let task_local = local.clone();

// Create async task for this specific input
tokio::spawn(async move {
// Select on closed to cancel work if the stream was closed
tokio::select! {
response = task_local.decode_inner(request) => {
let _ = sender.send(response);
}
_ = sender.closed() => {}
}
});
}
});

// Intermediate channels
// Required to keep the order of the requests
let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel();

tokio::spawn(async move {
// Iterate on input
while let Some(request) = request_stream.next().await {
// Create return channel
let (result_sender, result_receiver) = oneshot::channel();
// Push to intermediate channel and preserve ordering
intermediate_sender
.send(result_receiver)
.expect("`intermediate_receiver` was dropped. This is a bug.");

match request {
Ok(request) => encode_sender
.send((request, result_sender))
.await
.expect("`encode_receiver` was dropped. This is a bug."),
Err(status) => {
// Request is malformed
let _ = result_sender.send(Err(status));
}
};
}
});

// Final channel for the outputs
let (response_sender, response_receiver) = mpsc::unbounded_channel();

tokio::spawn(async move {
while let Some(result_receiver) = intermediate_receiver.recv().await {
// Select on closed to cancel work if the stream was closed
tokio::select! {
response = result_receiver => {
let _ = response_sender.send(response.expect("`result_sender` was dropped. This is a bug."));
}
_ = response_sender.closed() => {}
}
}
});

Ok(Response::new(UnboundedReceiverStream::new(
response_receiver,
)))
}
}
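
A note on the streaming design: `decode_stream` mirrors `tokenize_stream` above. Decode calls run concurrently, capped by the bounded channel of size `max_parallel_stream_requests`, while the intermediate unbounded channel queues the oneshot receivers in arrival order; the final forwarding task therefore emits responses in exactly the order the requests arrived, no matter which decode finishes first.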

pub async fn run(