Skip to content

Commit

Permalink
Add beam_search ability to sharktank LLM model
Browse files Browse the repository at this point in the history
  • Loading branch information
stbaione committed Feb 21, 2025
1 parent 888a98a commit c28bf2a
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
11 changes: 7 additions & 4 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def main():
block_seq_stride=args.block_seq_stride,
activation_dtype=args.activation_dtype,
attention_dtype=args.attention_dtype,
n_beams=args.n_beams,
)
llama_config.fake_quant = args.fake_quant

Expand Down Expand Up @@ -112,6 +113,7 @@ def generate_params_json(
"prefill_batch_sizes": prefill_bs,
"decode_batch_sizes": decode_bs,
"transformer_block_count": hp.block_count,
"n_beams": llama_config.n_beams,
"paged_kv_cache": {
"attention_head_count_kv": hp.attention_head_count_kv,
"block_seq_stride": llama_config.block_seq_stride,
Expand Down Expand Up @@ -251,14 +253,15 @@ def generate_batch_decode(bs: int):
block_dim_min = 2
block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
batch_dim_size = bs * llama_config.n_beams
tokens = torch.empty(
bs,
batch_dim_size,
1,
dtype=torch.int64,
)
seq_lens = torch.empty(bs, dtype=torch.int64)
start_positions = torch.ones(bs, dtype=torch.int64)
seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
seq_lens = torch.empty(batch_dim_size, dtype=torch.int64)
start_positions = torch.ones(batch_dim_size, dtype=torch.int64)
seq_block_ids = torch.empty(batch_dim_size, block_dim_min, dtype=torch.int64)

(
cache_state,
Expand Down
5 changes: 5 additions & 0 deletions sharktank/sharktank/layers/configs/llm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ class LlamaModelConfig:
# the program and not.
static_tables: bool = True

# The number of beams to use when generating tokens for a given prompt.
# When n_beams == 1, `greedy` selection is used;
# when n_beams > 1, `beam search` is used.
n_beams: int = 1


@dataclass
class T5Config:
Expand Down
6 changes: 6 additions & 0 deletions sharktank/sharktank/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def add_model_options(parser: argparse.ArgumentParser):
type=int,
default=512,
)
parser.add_argument(
"--n-beams",
help="Number of beams to use when generating tokens.",
type=int,
default=1,
)


def add_quantization_options(parser: argparse.ArgumentParser):
Expand Down

0 comments on commit c28bf2a

Please sign in to comment.