Commit 30172b4

[V1] Optimize handling of sampling metadata and req_ids list (vllm-project#13244)

Signed-off-by: Nick Hill <nhill@redhat.com>
njhill authored Feb 18, 2025
1 parent a4d577b commit 30172b4
Showing 15 changed files with 255 additions and 298 deletions.
9 changes: 2 additions & 7 deletions tests/v1/sample/test_rejection_sampler.py
@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
batch_size = len(spec_tokens)
return SamplingMetadata(
temperature=0.0,
temperature=torch.tensor([]),
all_greedy=True,
all_random=False,
rejection_sampling=True,
spec_token_ids=spec_tokens,
top_p=None,
top_k=None,
no_top_p=False,
no_top_k=False,
min_p=torch.empty(batch_size, ),
no_min_p=True,
generators={},
max_num_logprobs=0,
no_penalties=False,
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
presence_penalties=torch.tensor([]),
repetition_penalties=torch.tensor([]),
output_token_ids=[],
min_tokens=[],
stop_token_ids=[],
min_tokens={},
logit_bias=[None] * batch_size,
)

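The hunk above shows the reshaped SamplingMetadata: the rejection_sampling and no_top_p/no_top_k/no_min_p flags are gone, and the parallel min_tokens/stop_token_ids lists collapse into one dict. A minimal sketch of the new shapes (the indices, counts and token IDs below are illustrative, not from the diff):

# Sketch only: illustrates the reshaped fields, not code from this commit.
from typing import Dict, Optional, Set, Tuple

import torch

# Old layout: parallel per-request lists plus boolean "disabled" flags, e.g.
#   min_tokens: List[int], stop_token_ids: List[Set[int]], top_p + no_top_p.
# New layout: one dict keyed by request index, and Optional tensors that are
# simply None when no request in the batch uses the parameter.
min_tokens: Dict[int, Tuple[int, Set[int]]] = {
    0: (16, {2}),    # request 0: at least 16 tokens; stop/EOS token id 2
    1: (0, set()),   # request 1: no min-tokens constraint
}
top_p: Optional[torch.Tensor] = None   # None == "no request uses top-p"
top_k: Optional[torch.Tensor] = None
min_p: Optional[torch.Tensor] = None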
44 changes: 18 additions & 26 deletions tests/v1/sample/test_sampler.py
@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
temperature=torch.full((batch_size, ), 0.0),
all_greedy=True,
all_random=False,
rejection_sampling=False,
top_p=torch.empty(batch_size, ),
top_k=torch.empty(batch_size, ),
no_top_p=True,
no_top_k=True,
min_p=torch.empty(batch_size, ),
no_min_p=True,
top_p=None,
top_k=None,
min_p=None,
generators={},
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
output_token_ids=output_token_ids,
spec_token_ids=[],
spec_token_ids=None,
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
no_penalties=True,
min_tokens=[],
stop_token_ids=[],
min_tokens={},
logit_bias=[None] * batch_size,
)
return fake_sampling_metadata
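Since top_p, top_k and min_p are now Optional, a sampler can skip an entire stage with a single None check instead of consulting a separate no_top_p/no_top_k flag. A hedged sketch (this helper is an illustration, not the vLLM sampler; it assumes every entry of top_k is a valid k in [1, vocab_size]):

from typing import Optional

import torch


def maybe_apply_top_k(logits: torch.Tensor,
                      top_k: Optional[torch.Tensor]) -> torch.Tensor:
    """Illustrative per-request top-k; skipped entirely when top_k is None."""
    if top_k is None:          # no request in the batch uses top-k
        return logits
    # k-th largest logit per row, taken at each request's own k.
    sorted_logits = logits.sort(dim=-1, descending=True).values
    kth = sorted_logits.gather(1, (top_k.long() - 1).unsqueeze(1))
    # Keep everything >= the k-th value; mask the rest.
    return logits.masked_fill(logits < kth, float("-inf"))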
@@ -104,33 +99,30 @@ def _create_default_sampling_metadata(
def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
) -> Dict[int, Tuple[int, Set[int]]]:
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids: List[Set[int]] = []
min_tokens: List[int] = []
min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens.append(
min_tokens[index] = (
np.random.randint(num_output_tokens + 1,
2 * num_output_tokens))
stop_token_ids.append(
2 * num_output_tokens),
set(
np.random.randint(0, vocab_size - 1)
for _ in range(np.random.randint(0, vocab_size))))

else:
min_tokens.append(np.random.randint(0, num_output_tokens))
stop_token_ids.append(set())
return (min_tokens, stop_token_ids)
min_tokens[index] = (np.random.randint(0,
num_output_tokens), set())
return min_tokens


def _create_weighted_output_token_list(
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
output_token_ids_for_batch.extend(
[token_id for _ in range(index + 1)])
output_token_ids.append(output_token_ids_for_batch)
return (output_token_ids, sorted_token_ids_in_output)
return output_token_ids, sorted_token_ids_in_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
batch_indices_for_min_token_penalty = np.random.randint(
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
min_tokens = _generate_min_token_penalties_and_stop_tokens(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
batch_indices_for_min_token_penalty)
sampling_metadata.min_tokens = min_tokens
sampling_metadata.stop_token_ids = stop_token_ids
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
for token_id in range(VOCAB_SIZE):
if token_id in stop_token_ids[batch_idx]:
_, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
if token_id in stop_token_ids:
assert logits[batch_idx][token_id] == -float("inf")
else:
assert logits[batch_idx][token_id] != -float("inf")
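The test now recovers each request's (min_tokens, stop_token_ids) pair with min_tokens.get(batch_idx, (0, set())). A sketch of how a penalty step can consume that dict, consistent with the assertions above (the helper name and signature are assumptions, not taken from this commit):

from typing import Dict, List, Set, Tuple

import torch


def apply_min_token_penalties(
        logits: torch.Tensor,
        output_token_ids: List[List[int]],
        min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None:
    """Mask stop tokens for any request that hasn't reached min_tokens yet."""
    for req_idx, (min_toks, stop_token_ids) in min_tokens.items():
        if len(output_token_ids[req_idx]) < min_toks:
            for token_id in stop_token_ids:
                logits[req_idx, token_id] = -float("inf")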
47 changes: 21 additions & 26 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pytest
@@ -41,7 +41,7 @@ def _remove_requests(
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
return (req_ids_to_remove, req_indices_to_remove_list)
return req_ids_to_remove, req_indices_to_remove_list


def _construct_expected_sampling_metadata(
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
top_p = [0.0 for _ in range(num_reqs)]
min_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
min_tokens = [0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
for req in reqs:
if req.req_id not in req_ids_retained:
@@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
top_p[index_in_input_batch] = req.sampling_params.top_p
min_p[index_in_input_batch] = req.sampling_params.min_p
temperature[index_in_input_batch] = req.sampling_params.temperature
stop_token_ids[
index_in_input_batch] = req.sampling_params.all_stop_token_ids
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
min_tokens[index_in_input_batch] = (
req.sampling_params.min_tokens,
req.sampling_params.all_stop_token_ids)
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False,
all_random=True,
rejection_sampling=False,
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
top_k=torch.tensor(top_k, dtype=torch.int, device=device),
no_top_p=all(x == 1.0 for x in top_p),
no_top_k=all(x == 0 for x in top_k),
min_p=torch.tensor(min_p, dtype=torch.float, device=device),
no_min_p=all(x == 0.0 for x in min_p),
top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
top_p, dtype=torch.float, device=device),
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
top_k, dtype=torch.int, device=device),
min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
min_p, dtype=torch.float, device=device),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
@@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
spec_token_ids=[],
spec_token_ids=None,
min_tokens=min_tokens,
stop_token_ids=stop_token_ids,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch.condense(req_indices_to_remove)

# Generate the sampling metadata
sampling_metadata = input_batch.make_sampling_metadata(
req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
sampling_metadata = input_batch._make_sampling_metadata()

# Create expected output.
expected_sampling_metadata = _construct_expected_sampling_metadata(
@@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch.req_id_to_index,
device=torch.device(device))

def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
return (t1 is None
and t2 is None) or (t1 is not None and t2 is not None
and torch.allclose(t1, t2))

# Assert the actual and expected output.
assert torch.allclose(expected_sampling_metadata.temperature,
sampling_metadata.temperature)
assert torch.allclose(expected_sampling_metadata.top_p,
sampling_metadata.top_p)
assert torch.allclose(expected_sampling_metadata.top_k,
sampling_metadata.top_k)
assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
assert torch.allclose(
expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties,
@@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
assert expected_sampling_metadata.stop_token_ids == \
sampling_metadata.stop_token_ids
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
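The expected metadata above only materializes top_p/top_k/min_p tensors when some request deviates from the defaults (1.0, 0 and 0.0 respectively), and the new same() helper compares those Optional tensors. A small sketch of the collapse-to-None idea (the helper name is an assumption):

from typing import List, Optional

import torch


def collapse_if_default(values: List[float], default: float, *,
                        dtype: torch.dtype,
                        device: torch.device) -> Optional[torch.Tensor]:
    """Return None when every request in the batch uses the default value."""
    if all(v == default for v in values):
        return None                      # e.g. top_p all 1.0, top_k all 0
    return torch.tensor(values, dtype=dtype, device=device)


# Example: a batch where nobody uses top-p but one request uses min-p.
top_p = collapse_if_default([1.0, 1.0], 1.0,
                            dtype=torch.float, device=torch.device("cpu"))
min_p = collapse_if_default([0.0, 0.1], 0.0,
                            dtype=torch.float, device=torch.device("cpu"))
assert top_p is None and min_p is not None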
33 changes: 22 additions & 11 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -5,6 +5,7 @@
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


@@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
return req_id in model_runner.requests


def _is_sampling_metadata_changed(model_runner,
sampling_metadata_before: SamplingMetadata):
return model_runner.input_batch.sampling_metadata is not (
sampling_metadata_before)


def test_update_states_new_request(model_runner):
req_id = "req_0"

# new req
scheduler_output = _schedule_new_request(req_id)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)

@@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert not _is_req_added(model_runner, req_id)
assert not _is_req_scheduled(model_runner, req_id)

@@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={},
finished_req_ids=set(),
free_encoder_input_ids=[],
)

@@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner):
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)

@@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is False
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert not _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)

@@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)

assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])
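These tests no longer assert on a batch_changed return value; instead they check whether input_batch.sampling_metadata is a different object after _update_states, which presumes the batch caches the metadata and rebuilds it only when requests are added, removed or reordered. A rough sketch of that pattern (an assumption about the design, not the actual InputBatch/GPUModelRunner code):

class CachedMetadataBatch:

    def __init__(self) -> None:
        self._dirty = True
        self._sampling_metadata = None

    def add_request(self, req) -> None:
        # ... update per-request tensors / index maps ...
        self._dirty = True

    def remove_request(self, req_id: str) -> None:
        # ... free the slot, mark for condensing ...
        self._dirty = True

    @property
    def sampling_metadata(self):
        if self._dirty:
            self._sampling_metadata = self._make_sampling_metadata()
            self._dirty = False
        # An unchanged batch returns the identical object, so callers (and the
        # tests above) can use "is not" to detect a rebuild.
        return self._sampling_metadata

    def _make_sampling_metadata(self):
        # Placeholder: the real method would assemble a SamplingMetadata here.
        return object()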
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/utils.py
@@ -45,14 +45,14 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
vocab_size, num_seqs)
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, vocab_size)
logits[logits > 0] /= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits > 0]
logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits <= 0]
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
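Dropping the trailing-underscore (in-place) unsqueeze_ matters once the penalty tensors are reused across steps rather than rebuilt each time, because the in-place form reshapes the caller's tensor as a side effect; that persistence rationale is an inference from the rest of this commit. A minimal repro:

import torch

penalties = torch.ones(4)              # shape [4], reused across steps

penalties.unsqueeze_(dim=1)            # in-place: caller now sees shape [4, 1]
assert penalties.shape == (4, 1)

penalties = torch.ones(4)
expanded = penalties.unsqueeze(dim=1)  # out-of-place: caller keeps shape [4]
assert penalties.shape == (4,) and expanded.shape == (4, 1)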
6 changes: 4 additions & 2 deletions vllm/v1/core/scheduler.py
@@ -195,8 +195,10 @@ def schedule(self) -> "SchedulerOutput":
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids[:num_scheduled_spec_tokens])
request.spec_token_ids)

# Encoder-related.
if encoder_inputs_to_schedule:
@@ -567,7 +569,7 @@ def update_from_output(
outputs.append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids or [],
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
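In the first scheduler hunk, del request.spec_token_ids[n:] truncates the request's list in place (no sliced copy per step) and the same list object is handed to scheduled_spec_decode_tokens; in the second, the "or []" fallback is dropped, presumably because new_token_ids is always a list at this point. A tiny sketch of the in-place trim (values are illustrative):

spec_token_ids = [11, 22, 33, 44]
num_scheduled_spec_tokens = 2

# del lst[n:] truncates in place -- no new list is allocated, and any other
# reference to the same list (e.g. the scheduler output) sees the trim.
del spec_token_ids[num_scheduled_spec_tokens:]
assert spec_token_ids == [11, 22]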