[Pipelines] Refactor next_token to return dict[str, TextGenerationResponse]

This PR refactors the `next_token` interface to enable variable-length token
responses from the pipeline on a per-request basis.

Instead of returning a `list[dict[str, TextResponse]]` and implicitly
signaling request completion through which keys appear in each dictionary,
`next_token` now returns a `dict[str, TextGenerationResponse]`, in which the
dictionary keys align with the request ids provided.

The newly introduced `TextGenerationResponse` includes a variable-length
token array and explicitly states the final status
(`TextGenerationStatus`: ACTIVE, MAXIMUM_LENGTH, or END_OF_SEQUENCE). This
hardens the interface's completion tracking, enables more complex decoding
strategies, and gives the server the opportunity to return a corrected
`finish_reason`.
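
For illustration, here is a minimal sketch of how a caller might consume the new return shape, mirroring the generate.py change below. The `pipeline`, `tokenizer`, and `contexts` names are placeholders for this sketch, not part of this diff:

    async def drain_one_step(pipeline, tokenizer, batch, contexts, num_steps=8):
        # Keys of the returned dict align with the request ids in `batch`.
        responses = pipeline.next_token(batch, num_steps=num_steps)
        for request_id, response in responses.items():
            # Each response carries a variable-length token array ...
            for text_response in response.tokens:
                text = await tokenizer.decode(
                    contexts[request_id], text_response.next_token
                )
                print(text, end="", flush=True)
            # ... and an explicit final status instead of an implicit missing key.
            if response.is_done:
                print(f"\n[{request_id}] finished: {response.final_status}")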

MODULAR_ORIG_COMMIT_REV_ID: fa8c7ff5af9470d25d3f86c966c633f783cc37ec
KCaverly authored and modularbot committed Feb 27, 2025
1 parent 98333f2 commit e9be43e
Showing 5 changed files with 89 additions and 50 deletions.
30 changes: 15 additions & 15 deletions src/max/entrypoints/cli/generate.py
@@ -76,22 +76,22 @@ async def stream_text_to_console(
num_steps=num_steps,
)

for response in responses:
if req_id not in response:
# next_token is expected to omit the return if
# it encounters eos.
for request_idx, response in responses.items():
if response.is_done:
generate_again = False
break

encoded_text = response[req_id].next_token
response_text = await tokenizer.decode(context, encoded_text)
if metrics:
if first_token:
first_token = False
metrics.signpost("first_token")
metrics.new_token()
if print_tokens:
print(response_text, end="", flush=True)

for text_response in response.tokens:
encoded_text = text_response.next_token
response_text = await tokenizer.decode(
context, encoded_text
)
if metrics:
if first_token:
first_token = False
metrics.signpost("first_token")
metrics.new_token()
if print_tokens:
print(response_text, end="", flush=True)

# Yield to the event loop. If at no other point (e.g.
# tokenizer.decode which we await earlier does not yield to the
29 changes: 21 additions & 8 deletions src/max/pipelines/hf_pipeline.py
@@ -34,6 +34,8 @@
from .interfaces import (
EmbeddingsGenerator,
EmbeddingsResponse,
TextGenerationResponse,
TextGenerationStatus,
TextResponse,
TokenGenerator,
)
@@ -119,7 +121,7 @@ def __init__(

def next_token(
self, batch: dict[str, TextContext], num_steps: int
) -> list[dict[str, TextResponse]]:
) -> dict[str, TextGenerationResponse]:
"""Provided a batch, process batch inputs, execute the graph for num_steps in a multi-step scenario,
then decode the tokens holistically and return the list of decoded tokens.
"""
@@ -183,8 +185,10 @@ def next_token(
generated_tokens = generated_tokens.cpu()

# Prepare the response, pruning away completed requests as we go.
res: list[dict[str, TextResponse]] = [{} for i in range(num_steps)]
res: dict[str, TextGenerationResponse] = {}
for batch_idx, (request_id, context) in enumerate(batch.items()):
status = TextGenerationStatus.ACTIVE
res[request_id] = TextGenerationResponse([], status)
for step in range(num_steps):
next_token_id = generated_tokens[batch_idx, step].item()

@@ -200,13 +204,22 @@ def next_token(
if context.max_length is None
else context.max_length
)
if (
next_token_id in self._eos_token_id
or context.current_length > max_length
):
break

res[step][request_id] = TextResponse(next_token)
if next_token_id in self._eos_token_id:
status = TextGenerationStatus.END_OF_SEQUENCE
res[request_id].update_status(status)
elif context.current_length > max_length:
status = TextGenerationStatus.MAXIMUM_LENGTH
res[request_id].update_status(status)
elif context.current_length == max_length:
res[request_id].append_token(TextResponse(next_token))
status = TextGenerationStatus.MAXIMUM_LENGTH
res[request_id].update_status(status)
else:
res[request_id].append_token(TextResponse(next_token))

if status.is_done:
break

return res

20 changes: 16 additions & 4 deletions src/max/pipelines/interfaces/response.py
@@ -117,11 +117,23 @@ class TextGenerationResponse:
def __init__(
self, tokens: list[TextResponse], final_status: TextGenerationStatus
) -> None:
self.tokens = tokens
self.final_status = final_status
self._tokens = tokens
self._final_status = final_status

@property
def is_done(self) -> bool:
return self.final_status.is_done
return self._final_status.is_done

@property
def tokens(self) -> list[TextResponse]:
return self._tokens

@property
def final_status(self) -> TextGenerationStatus:
return self._final_status

def append_token(self, token: TextResponse) -> None:
self.tokens.append(token)
self._tokens.append(token)

def update_status(self, status: TextGenerationStatus) -> None:
self._final_status = status
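
The hunks above rely on a `TextGenerationStatus` enum with an `is_done` property, which is defined earlier in response.py and outside this diff. Based on the statuses named in the commit message, a plausible sketch (an assumption, not the actual definition; the string values in particular are guesses) is:

    from enum import Enum


    class TextGenerationStatus(str, Enum):
        ACTIVE = "active"
        END_OF_SEQUENCE = "end_of_sequence"
        MAXIMUM_LENGTH = "maximum_length"

        @property
        def is_done(self) -> bool:
            # Generation is complete for any terminal (non-ACTIVE) status.
            return self is not TextGenerationStatus.ACTIVE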
4 changes: 2 additions & 2 deletions src/max/pipelines/interfaces/text_generation.py
@@ -28,7 +28,7 @@
runtime_checkable,
)

from .response import TextResponse
from .response import TextGenerationResponse


class TokenGeneratorRequestFunction(TypedDict):
@@ -258,7 +258,7 @@ class TokenGenerator(Generic[TokenGeneratorContext], Protocol):

def next_token(
self, batch: dict[str, TokenGeneratorContext], num_steps: int
) -> list[dict[str, TextResponse]]:
) -> dict[str, TextGenerationResponse]:
"""Computes the next token response for a single batch.
Args:
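
To make the updated `TokenGenerator` protocol concrete, here is a toy implementation sketch. The `ScriptedContext` type and its `pop_token` helper are purely hypothetical, the import path is inferred from the file layout, and only the `TextGenerationResponse` calls mirror the pipeline changes in this commit:

    from dataclasses import dataclass, field
    from typing import Optional

    from max.pipelines.interfaces import (
        TextGenerationResponse,
        TextGenerationStatus,
        TextResponse,
    )


    @dataclass
    class ScriptedContext:
        """Hypothetical stand-in for a TokenGeneratorContext."""

        remaining: list[int] = field(default_factory=list)

        def pop_token(self) -> Optional[int]:
            return self.remaining.pop(0) if self.remaining else None


    class ScriptedGenerator:
        """Toy TokenGenerator that replays pre-recorded token ids per request."""

        def next_token(
            self, batch: dict[str, ScriptedContext], num_steps: int
        ) -> dict[str, TextGenerationResponse]:
            res: dict[str, TextGenerationResponse] = {}
            for request_id, context in batch.items():
                response = TextGenerationResponse([], TextGenerationStatus.ACTIVE)
                for _ in range(num_steps):
                    token = context.pop_token()
                    if token is None:
                        # No more scripted tokens: mark the request finished.
                        response.update_status(
                            TextGenerationStatus.END_OF_SEQUENCE
                        )
                        break
                    response.append_token(TextResponse(token))
                res[request_id] = response
            return res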
56 changes: 35 additions & 21 deletions src/max/pipelines/pipeline.py
@@ -42,7 +42,13 @@

from .config import PipelineConfig
from .context import InputContext
from .interfaces import LogProbabilities, TextResponse, TokenGenerator
from .interfaces import (
LogProbabilities,
TextGenerationResponse,
TextGenerationStatus,
TextResponse,
TokenGenerator,
)
from .kv_cache import KVCacheManager, KVCacheParams
from .sampling import token_sampler

@@ -565,7 +571,7 @@ def next_token(
self,
batch: dict[str, T],
num_steps: int,
) -> list[dict[str, TextResponse]]:
) -> dict[str, TextGenerationResponse]:
"""Provided a batch, process batch inputs, execute the graph for num_steps in a multi-step scenario,
then decode the tokens holistically and return the list of decoded tokens.
"""
@@ -695,11 +701,12 @@ def next_token(
tracer.pop() # pops kv_manager.step

# Prepare the response, pruning away completed requests as we go.
res: list[dict[str, TextResponse]] = [{} for _ in range(num_steps)]
res: dict[str, TextGenerationResponse] = {}
tracer.push("prepare_response")
for batch_index, (request_id, context) in enumerate(batch.items()):
step = 0
while step < num_steps:
status = TextGenerationStatus.ACTIVE
res[request_id] = TextGenerationResponse([], status)
for step in range(num_steps):
# Convert to a Python scalar to improve serialization performance.
next_token = int(generated_tokens_host[batch_index, step])

@@ -715,30 +722,37 @@ def next_token(
default=context.max_length,
)

# The current length is incremented above, during context.update.
# If we are already at the max length, exiting here would cause us
# to miss updating the request. As such, we overrun by 1, ensuring
# that the context object tracks special tokens like eos_token_id
# appropriately for benchmarking and other uses, while not returning
# them in the request.
if (
next_token in self._eos_token_id
or context.current_length > max_length
):
step += 1
break

# Set up TextResponse
log_probs: Optional[LogProbabilities] = None
if compute_log_probabilities and (
log_probs_for_step := batch_log_probabilities[step]
):
log_probs = log_probs_for_step[batch_index]

# Removing the positional arguments here goes about 100us faster.
res[step][request_id] = TextResponse(next_token, log_probs)
# Update the status.
# If it's EOS, don't add it to the token array.
if next_token in self._eos_token_id:
status = TextGenerationStatus.END_OF_SEQUENCE
res[request_id].update_status(status)
elif context.current_length == max_length:
status = TextGenerationStatus.MAXIMUM_LENGTH
res[request_id].append_token(
TextResponse(next_token, log_probs)
)
res[request_id].update_status(status)
# In practice, this should not be hit: once the context object
# reaches max_length, we break out of this loop.
# TODO: Explore cleaning up max length checks.
elif context.current_length > max_length:
status = TextGenerationStatus.MAXIMUM_LENGTH
res[request_id].update_status(status)
else:
res[request_id].append_token(
TextResponse(next_token, log_probs)
)

step += 1
if status.is_done:
break

return res
