diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 7c12f4e20..a05b4de1b 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -83,6 +83,7 @@ def main():
         block_seq_stride=args.block_seq_stride,
         activation_dtype=args.activation_dtype,
         attention_dtype=args.attention_dtype,
+        n_beams=args.n_beams,
     )
     llama_config.fake_quant = args.fake_quant
 
@@ -112,6 +113,7 @@ def generate_params_json(
         "prefill_batch_sizes": prefill_bs,
         "decode_batch_sizes": decode_bs,
         "transformer_block_count": hp.block_count,
+        "n_beams": llama_config.n_beams,
         "paged_kv_cache": {
             "attention_head_count_kv": hp.attention_head_count_kv,
             "block_seq_stride": llama_config.block_seq_stride,
@@ -251,14 +253,15 @@ def generate_batch_decode(bs: int):
         block_dim_min = 2
         block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
         block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
+        batch_dim_size = bs * llama_config.n_beams
         tokens = torch.empty(
-            bs,
+            batch_dim_size,
             1,
             dtype=torch.int64,
         )
-        seq_lens = torch.empty(bs, dtype=torch.int64)
-        start_positions = torch.ones(bs, dtype=torch.int64)
-        seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
+        seq_lens = torch.empty(batch_dim_size, dtype=torch.int64)
+        start_positions = torch.ones(batch_dim_size, dtype=torch.int64)
+        seq_block_ids = torch.empty(batch_dim_size, block_dim_min, dtype=torch.int64)
 
         (
             cache_state,
diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 7d1e506a0..a89b7ab38 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -193,6 +193,11 @@ class LlamaModelConfig:
     # the program and not.
     static_tables: bool = True
 
+    # The number of beams to use when generating tokens for a given prompt.
+    # When n_beams == 1, `greedy` selection is used;
+    # when n_beams > 1, `beam search` is used.
+    n_beams: int = 1
+
 
 @dataclass
 class T5Config:
diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py
index e3dba31fa..3d65a42cf 100644
--- a/sharktank/sharktank/utils/cli.py
+++ b/sharktank/sharktank/utils/cli.py
@@ -120,6 +120,12 @@ def add_model_options(parser: argparse.ArgumentParser):
         type=int,
         default=512,
     )
+    parser.add_argument(
+        "--n-beams",
+        help="Number of beams to use when generating tokens.",
+        type=int,
+        default=1,
+    )
 
 
 def add_quantization_options(parser: argparse.ArgumentParser):
diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py
index 6ea6802e3..3a953fd2d 100644
--- a/shortfin/python/shortfin_apps/llm/components/config_struct.py
+++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -141,6 +141,9 @@ class ModelParams:
     # Cache parameters.
     paged_kv_cache: PagedKVCacheParams | None = None
 
+    # Number of beams to use during token generation.
+    n_beams: int = 1
+
     # Size in bytes of the KV cache dtype.
     @property
     def attn_dtype_size(self) -> int:
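
Note (not part of the diff): a minimal sketch of what threading `n_beams` through the decode export means for input shapes, assuming the names used above. Each of the `bs` requests fans out into `n_beams` hypotheses, so every decode input is allocated with a flattened leading dimension of `bs * n_beams`. `example_decode_inputs` is a hypothetical helper for illustration, not code from the tree.

import torch

def example_decode_inputs(bs: int, n_beams: int, block_dim_min: int = 2):
    # Mirrors generate_batch_decode above: the batch dimension is flattened
    # to bs * n_beams so each beam occupies its own row.
    batch_dim_size = bs * n_beams
    tokens = torch.empty(batch_dim_size, 1, dtype=torch.int64)
    seq_lens = torch.empty(batch_dim_size, dtype=torch.int64)
    start_positions = torch.ones(batch_dim_size, dtype=torch.int64)
    seq_block_ids = torch.empty(batch_dim_size, block_dim_min, dtype=torch.int64)
    return tokens, seq_lens, start_positions, seq_block_ids

# With bs=4 and n_beams=2 the exported decode function sees 8 rows: four
# prompts, two beams each. n_beams=1 degenerates to the old greedy shapes.
tokens, *_ = example_decode_inputs(bs=4, n_beams=2)
assert tokens.shape == (8, 1)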