@@ -209,6 +209,25 @@ def to_dict(self):
209
209
}
210
210
211
211
212
class PrefixCacheTestTokenizer:
  """A minimal character-level tokenizer for prefix-cache tests.

  Encoding maps every character of a string to its Unicode code point;
  decoding maps a list of code points back into the original string. The
  mapping is deterministic and predictable, which makes this tokenizer
  convenient for test scenarios — in particular prefix caching, where
  identical string prefixes must encode to identical token prefixes.
  """

  def encode(self, s: str, **kwargs) -> list[int]:
    """Returns the Unicode code point of each character in `s`."""
    del kwargs  # Unused; accepted for tokenizer-interface compatibility.
    return list(map(ord, s))

  def decode(self, token_ids: list[int], **kwargs) -> str:
    """Reassembles a string from a list of character code points."""
    del kwargs  # Unused; accepted for tokenizer-interface compatibility.
    return "".join(map(chr, token_ids))
230
+
212
231
def get_tokenizer (
213
232
model_id : str ,
214
233
tokenizer_name : str ,
@@ -219,6 +238,9 @@ def get_tokenizer(
219
238
if tokenizer_name == "test" :
220
239
print ("Using test tokenizer" )
221
240
return "test"
241
+ elif tokenizer_name == "prefix_cache_test" :
242
+ print ("Using prefix_cache_test tokenizer" )
243
+ return PrefixCacheTestTokenizer ()
222
244
elif use_hf_tokenizer :
223
245
# Please accept agreement to access private/gated models in HF, and
224
246
# follow up instructions below to set up access token
@@ -329,6 +351,98 @@ def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
329
351
return combined_dataset , prompts_per_subject
330
352
331
353
354
def load_mock_prefix_cache_test_input_requests(
    prompt_len: int,
    output_len: int,
    common_prefix_len: int,
    num_samples: int,
) -> list[InputRequest]:
  """Generates a mock dataset for testing prefix caching.

  The prefix part of each prompt is a sub-string of a single master string.
  The length of this prefix part for each sample is drawn from a normal
  distribution with its mean set to `common_prefix_len`, and values are
  clipped to the range [0, `prompt_len`].
  The tokenizer is assumed to treat each character as a token.

  Args:
    prompt_len: The total length of each generated prompt string.
    output_len: The length of each generated output string.
    common_prefix_len: The target mean for the length of the prefix part
      of each prompt. These prefixes are derived from a shared master
      string.
    num_samples: The number of (prompt, output) pairs to generate.

  Returns:
    A list of InputRequest objects.

  Raises:
    ValueError: If `common_prefix_len` is outside [0, `prompt_len`], or if
      any of `prompt_len`, `output_len`, `num_samples` is zero or negative.
  """
  if not 0 <= common_prefix_len <= prompt_len:
    raise ValueError(
        "Target mean common_prefix_len must be between 0 and prompt_len,"
        f" inclusive. Got common_prefix_len={common_prefix_len}, "
        f"prompt_len={prompt_len}"
    )
  if any(arg <= 0 for arg in [prompt_len, output_len, num_samples]):
    raise ValueError(
        "prompt_len, output_len, and num_samples cannot be 0 or negative."
    )

  input_requests: list[InputRequest] = []

  # Generate a master string from which all prefixes will be derived.
  # This ensures that prefixes of the same length are identical,
  # and shorter prefixes are actual prefixes of longer ones.
  master_potential_prefix = "".join(
      random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=prompt_len)
  )

  # Draw a per-sample prefix length from a normal distribution centered on
  # common_prefix_len, then clip into [0, prompt_len] and round to ints.
  scale = prompt_len / 3.0  # Standard deviation for the normal distribution
  generated_prefix_lengths = np.random.normal(
      loc=common_prefix_len, scale=scale, size=num_samples
  )
  generated_prefix_lengths = (
      np.clip(generated_prefix_lengths, 0, prompt_len).round().astype(int)
  )

  for idx in range(num_samples):
    # np.clip above guarantees 0 <= prefix_len <= prompt_len, so the
    # unique-suffix length below is always non-negative; the previous
    # "negative length" safeguard branch was unreachable and is removed.
    current_actual_prefix_len = int(generated_prefix_lengths[idx])
    actual_prefix_for_sample = master_potential_prefix[
        :current_actual_prefix_len
    ]

    # Fill the remainder of the prompt with a random, sample-unique suffix
    # drawn from a disjoint alphabet so it can never extend the shared prefix.
    current_unique_len = prompt_len - current_actual_prefix_len
    unique_suffix_str = "".join(
        random.choices(
            "abcdefghijklmnopqrstuvwxyz0123456789", k=current_unique_len
        )
    )

    prompt_str = actual_prefix_for_sample + unique_suffix_str

    # Output text is only a placeholder; its content is irrelevant to the
    # prefix-cache test, only its length matters.
    output_str = "".join(random.choices("!@#$%^&*()_+", k=output_len))

    request = InputRequest(
        prompt=prompt_str,
        prompt_len=len(prompt_str),
        output=output_str,
        output_len=len(output_str),
        sample_idx=idx,
    )
    input_requests.append(request)
  return input_requests
444
+
445
+
332
446
def gen_mmlu_qa (data : Any , mmlu_method : str = "" ) -> str :
333
447
334
448
output = ""
@@ -893,6 +1007,7 @@ def parse_args() -> argparse.Namespace:
893
1007
"mmlu" ,
894
1008
"math500" ,
895
1009
"longcontext" ,
1010
+ "prefix_cache_test" ,
896
1011
],
897
1012
help = "The dataset name." ,
898
1013
)
@@ -1086,6 +1201,12 @@ def parse_args() -> argparse.Namespace:
1086
1201
choices = ["HELM" , "Harness" , "" ],
1087
1202
help = "mmlu method/format to generate shots" ,
1088
1203
)
1204
+ parser .add_argument (
1205
+ "--prefix-cache-test-common-len" ,
1206
+ type = int ,
1207
+ default = 64 ,
1208
+ help = "Common prefix length for the prefix cache test dataset." ,
1209
+ )
1089
1210
return parser .parse_args ()
1090
1211
1091
1212
@@ -1112,6 +1233,13 @@ def main(args: argparse.Namespace):
1112
1233
input_requests = mock_requests (
1113
1234
args .total_mock_requests
1114
1235
) # e.g. [("AB", 2, "AB", 3)]
1236
+ elif args .dataset == "prefix_cache_test" :
1237
+ input_requests = load_mock_prefix_cache_test_input_requests (
1238
+ prompt_len = args .max_input_length ,
1239
+ output_len = args .max_output_length ,
1240
+ common_prefix_len = args .prefix_cache_test_common_len ,
1241
+ num_samples = args .num_prompts ,
1242
+ )
1115
1243
else :
1116
1244
dataset = []
1117
1245
if args .dataset == "openorca" :
0 commit comments