
Commit 4aafd76

jetstream authors committed
Merge pull request #268 from AI-Hypercomputer:yuyan-prefix-cache-benchmark
PiperOrigin-RevId: 755894145
2 parents 219e5a1 + bbfb5bd commit 4aafd76

File tree: 3 files changed, +238 −1 lines changed

benchmarks/README.md

+21 −1
@@ -75,6 +75,7 @@ python benchmark_serving.py \
```

## Benchmark with openorca dataset (openorca is used by MLPerf inference for LLaMA2 models)

```
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer ~/maxtext/assets/tokenizer.llama2 \
@@ -93,6 +94,7 @@ python JetStream/benchmarks/benchmark_serving.py \
The benchmark has better performance if it first conducts a warmup of the JetStream server. We currently support `sampled` and `full` warmup modes. `full` mode warms up the JetStream server with all the input requests; `sampled` mode warms up the JetStream server with a sampling of the input requests across different bucket sizes of input lengths.

Example to run the benchmark with `full` warmup mode:

```
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer ~/maxtext/assets/tokenizer.llama2 \
@@ -115,7 +117,25 @@ python eval_accuracy.py outputs.json
```

With openorca dataset and llama2-chat models (used by MLPerf), here are the reference accuracy numbers:

```
llama2-7b-chat {'rouge1': 42.0706, 'rouge2': 19.8021, 'rougeL': 26.8474, 'rougeLsum': 39.5952, 'gen_len': 1146679, 'gen_num': 998}
llama2-70b-chat {'rouge1': 44.4312, 'rouge2': 22.0352, 'rougeL': 28.6162}
```

## Benchmark prefix cache

Benchmark with mock input requests that share a common prefix; use this to test prefix caching.

Every prompt has length `--max-input-length`, and the per-prompt common-prefix length is drawn from a normal distribution with mean `--prefix-cache-test-common-len`.

```
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer prefix_cache_test \
--dataset prefix_cache_test \
--warmup-mode full \
--num-prompts 100 \
--max-input-length 16000 \
--prefix-cache-test-common-len 9000 \
--max-output-length 50
```
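
For intuition, the prefix lengths are sampled roughly as in this sketch, which mirrors `load_mock_prefix_cache_test_input_requests` in `benchmark_serving.py` (the `prompt_len / 3` standard deviation comes from that function; the concrete numbers here are illustrative):

```
import numpy as np

prompt_len, common_len, num_samples = 16000, 9000, 100
# Mean at common_len, clipped into [0, prompt_len], rounded to integer lengths.
prefix_lens = np.random.normal(loc=common_len, scale=prompt_len / 3.0, size=num_samples)
prefix_lens = np.clip(prefix_lens, 0, prompt_len).round().astype(int)
```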

benchmarks/benchmark_prefix_cache.sh

+89
@@ -0,0 +1,89 @@
#!/bin/bash
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -e

NUM_PROMPTS=${NUM_PROMPTS:-100}
MAX_OUTPUT_LENGTH=${MAX_OUTPUT_LENGTH:-50}

# Test combinations of prompt lengths and common prefix lengths.
# Each length should be at most max_input_length minus 1, since one
# position is reserved for the BOS token.
BENCHMARK_PROMPT_LENGTHS=${BENCHMARK_PROMPT_LENGTHS:-8000,16000}
BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS=${BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS:-4000,6000,8000,10000,12000,14000,16000}
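# The defaults above can be overridden from the environment, e.g.
# (hypothetical values):
#   NUM_PROMPTS=200 BENCHMARK_PROMPT_LENGTHS=8000 ./benchmark_prefix_cache.sh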

benchmark_serving_with_prefix_cache() {
  echo "Starting prefix cache benchmark..."
  echo "Benchmark serving script: ${BENCHMARK_SERVING_SCRIPT_PATH}"
  echo "Prompt lengths to test: ${BENCHMARK_PROMPT_LENGTHS}"
  echo "Common prefix lengths to test: ${BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS}"
  echo "Number of prompts per run: ${NUM_PROMPTS}"
  echo "Max output length per prompt: ${MAX_OUTPUT_LENGTH}"
  echo "Base output directory for results: ${OUTPUTS_DIR_BASE}"
  echo "Warmup mode: ${WARMUP_MODE}"

  # Convert comma-separated strings to arrays for iteration
  IFS=',' read -r -a prompt_lengths_arr <<< "$BENCHMARK_PROMPT_LENGTHS"
  IFS=',' read -r -a common_prefix_lengths_arr <<< "$BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS"

  for prompt_len in "${prompt_lengths_arr[@]}"; do
    for common_len in "${common_prefix_lengths_arr[@]}"; do
      if [ "${common_len}" -gt "${prompt_len}" ]; then
        echo "Skipping: Common prefix length ${common_len} is greater than prompt length ${prompt_len}."
        continue
      fi

      echo "----------------------------------------------------------------------"
      echo "Running benchmark: Prompt Length=${prompt_len}, Common Prefix Length=${common_len}"
      echo "----------------------------------------------------------------------"
      echo "Warm up twice"
      echo "----------------------------------------------------------------------"

      # With warmup-mode full, it will run twice
      python3 ./benchmark_serving.py \
        --tokenizer "prefix_cache_test" \
        --dataset "prefix_cache_test" \
        --num-prompts 10 \
        --max-output-length "${MAX_OUTPUT_LENGTH}" \
        --warmup-mode "full" \
        --max-input-length "${prompt_len}" \
        --prefix-cache-test-common-len "${common_len}"

      echo "Warm up done"
      echo "----------------------------------------------------------------------"

      python3 ./benchmark_serving.py \
        --tokenizer "prefix_cache_test" \
        --dataset "prefix_cache_test" \
        --num-prompts "${NUM_PROMPTS}" \
        --max-output-length "${MAX_OUTPUT_LENGTH}" \
        --warmup-mode "none" \
        --max-input-length "${prompt_len}" \
        --prefix-cache-test-common-len "${common_len}"

      echo "Benchmark finished for Prompt Length=${prompt_len}, Common Prefix Length=${common_len}"
      echo "----------------------------------------------------------------------"
      echo
    done
  done
  echo "All benchmark runs completed."
}

main() {
  benchmark_serving_with_prefix_cache
  echo "Script finished."
  exit 0
}

main "$@"
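
Note the per-combination structure: a small warmup pass (`--warmup-mode full` over 10 prompts, which replays those requests twice) precedes the measured pass (`--warmup-mode none`), keeping warmup cost out of the measured numbers while leaving the server, and its prefix cache, warm for the timed run.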

benchmarks/benchmark_serving.py

+128
@@ -209,6 +209,25 @@ def to_dict(self):
  }


class PrefixCacheTestTokenizer:
  """A simple tokenizer for testing prefix caching.

  This tokenizer converts each character in a string to its integer ordinal
  value during encoding, and converts a list of integer ordinals back to
  a string during decoding. It's designed for testing scenarios, particularly
  those involving prefix caching, where a basic, predictable tokenizer is
  needed.
  """

  def encode(self, s: str, **kwargs) -> list[int]:
    del kwargs
    return [ord(c) for c in s]

  def decode(self, token_ids: list[int], **kwargs) -> str:
    del kwargs
    return "".join([chr(token_id) for token_id in token_ids])


def get_tokenizer(
    model_id: str,
    tokenizer_name: str,
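
For illustration (not part of the commit), the new tokenizer simply round-trips text through character ordinals:

```
tokenizer = PrefixCacheTestTokenizer()
ids = tokenizer.encode("JetStream")
assert ids == [74, 101, 116, 83, 116, 114, 101, 97, 109]  # ord() of each character
assert tokenizer.decode(ids) == "JetStream"
```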
@@ -219,6 +238,9 @@ def get_tokenizer(
  if tokenizer_name == "test":
    print("Using test tokenizer")
    return "test"
  elif tokenizer_name == "prefix_cache_test":
    print("Using prefix_cache_test tokenizer")
    return PrefixCacheTestTokenizer()
  elif use_hf_tokenizer:
    # Please accept the agreement to access private/gated models in HF, and
    # follow the instructions below to set up an access token
@@ -329,6 +351,98 @@ def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
  return combined_dataset, prompts_per_subject


def load_mock_prefix_cache_test_input_requests(
    prompt_len: int,
    output_len: int,
    common_prefix_len: int,
    num_samples: int,
) -> list[InputRequest]:
  """Generates a mock dataset for testing prefix cache.

  The prefix part of each prompt is a sub-string of a single master string.
  The length of this prefix part for each sample is drawn from a normal
  distribution with its mean set to `common_prefix_len`, and values are
  clipped to the range [0, `prompt_len`].
  The tokenizer is assumed to treat each character as a token.

  Args:
    prompt_len: The total length of each generated prompt string.
    output_len: The length of each generated output string.
    common_prefix_len: The target mean for the length of the prefix part
      of each prompt. These prefixes are derived from a shared master string.
    num_samples: The number of (prompt, output) pairs to generate.

  Returns:
    A list of InputRequest objects.
  """
  if not 0 <= common_prefix_len <= prompt_len:
    raise ValueError(
        "Target mean common_prefix_len must be between 0 and prompt_len,"
        f" inclusive. Got common_prefix_len={common_prefix_len}, "
        f"prompt_len={prompt_len}"
    )
  if any(arg <= 0 for arg in [prompt_len, output_len, num_samples]):
    raise ValueError(
        "prompt_len, output_len, and num_samples cannot be 0 or negative."
    )

  input_requests: list[InputRequest] = []

  # Generate a master string from which all prefixes will be derived.
  # This ensures that prefixes of the same length are identical,
  # and shorter prefixes are actual prefixes of longer ones.
  master_potential_prefix = "".join(
      random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=prompt_len)
  )

  # Generate prefix lengths for each sample from a normal distribution
  scale = prompt_len / 3.0  # Standard deviation for the normal distribution

  generated_prefix_lengths = np.random.normal(
      loc=common_prefix_len, scale=scale, size=num_samples
  )
  generated_prefix_lengths = (
      np.clip(generated_prefix_lengths, 0, prompt_len).round().astype(int)
  )

  for idx in range(num_samples):
    current_actual_prefix_len = generated_prefix_lengths[idx]

    actual_prefix_for_sample = master_potential_prefix[
        :current_actual_prefix_len
    ]

    current_unique_len = prompt_len - current_actual_prefix_len
    # This should not happen if generated_prefix_lengths is clipped correctly
    if current_unique_len < 0:
      current_unique_len = 0  # Safeguard
      current_actual_prefix_len = prompt_len
      actual_prefix_for_sample = master_potential_prefix[
          :current_actual_prefix_len
      ]

    unique_suffix_str = "".join(
        random.choices(
            "abcdefghijklmnopqrstuvwxyz0123456789", k=current_unique_len
        )
    )

    prompt_str = actual_prefix_for_sample + unique_suffix_str

    output_str = "".join(random.choices("!@#$%^&*()_+", k=output_len))

    request = InputRequest(
        prompt=prompt_str,
        prompt_len=len(prompt_str),
        output=output_str,
        output_len=len(output_str),
        sample_idx=idx,
    )
    input_requests.append(request)
  return input_requests


def gen_mmlu_qa(data: Any, mmlu_method: str = "") -> str:

  output = ""
@@ -893,6 +1007,7 @@ def parse_args() -> argparse.Namespace:
          "mmlu",
          "math500",
          "longcontext",
          "prefix_cache_test",
      ],
      help="The dataset name.",
  )
@@ -1086,6 +1201,12 @@ def parse_args() -> argparse.Namespace:
      choices=["HELM", "Harness", ""],
      help="mmlu method/format to generate shots",
  )
  parser.add_argument(
      "--prefix-cache-test-common-len",
      type=int,
      default=64,
      help="Common prefix length for the prefix cache test dataset.",
  )
  return parser.parse_args()

@@ -1112,6 +1233,13 @@ def main(args: argparse.Namespace):
    input_requests = mock_requests(
        args.total_mock_requests
    )  # e.g. [("AB", 2, "AB", 3)]
  elif args.dataset == "prefix_cache_test":
    input_requests = load_mock_prefix_cache_test_input_requests(
        prompt_len=args.max_input_length,
        output_len=args.max_output_length,
        common_prefix_len=args.prefix_cache_test_common_len,
        num_samples=args.num_prompts,
    )
  else:
    dataset = []
    if args.dataset == "openorca":
