Add beam_search ability to sharktank LLM model #991

Closed · wants to merge 4 commits
11 changes: 7 additions & 4 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -83,6 +83,7 @@ def main():
block_seq_stride=args.block_seq_stride,
activation_dtype=args.activation_dtype,
attention_dtype=args.attention_dtype,
n_beams=args.n_beams,
)
llama_config.fake_quant = args.fake_quant

@@ -112,6 +113,7 @@ def generate_params_json(
"prefill_batch_sizes": prefill_bs,
"decode_batch_sizes": decode_bs,
"transformer_block_count": hp.block_count,
"n_beams": llama_config.n_beams,
"paged_kv_cache": {
"attention_head_count_kv": hp.attention_head_count_kv,
"block_seq_stride": llama_config.block_seq_stride,
@@ -251,14 +253,15 @@ def generate_batch_decode(bs: int):
block_dim_min = 2
block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
batch_dim_size = bs * llama_config.n_beams
tokens = torch.empty(
bs,
batch_dim_size,
Contributor commented:
Decoding behavior should not need to change the model. The entire beam search behavior should be externally controlled by shortfin as the invocation gives a ranking per token.
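A minimal sketch of the flow the reviewer describes, with beam selection kept entirely outside the model: the decode invocation only returns per-token scores, and the serving layer decides which beams survive. Names and shapes below are illustrative, not actual shortfin APIs.

import torch

def beam_search_step(logits: torch.Tensor, beam_scores: torch.Tensor, n_beams: int):
    """One external beam-search step over the model's per-token scores.

    logits:      [n_beams, vocab_size] scores returned by a decode invocation.
    beam_scores: [n_beams] cumulative log-probabilities of the live beams.
    """
    log_probs = torch.log_softmax(logits, dim=-1)        # [n_beams, vocab_size]
    total = beam_scores.unsqueeze(-1) + log_probs        # extend every beam by every token
    top_scores, top_idx = torch.topk(total.reshape(-1), n_beams)
    vocab_size = logits.shape[-1]
    parent_beams = top_idx // vocab_size                 # which beam each survivor extends
    next_tokens = top_idx % vocab_size                   # which token it appends
    return parent_beams, next_tokens, top_scores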

Contributor (author) replied:
My thinking with this change was that it would make beam search more efficient on the shortfin side, by allowing all beams to be processed in parallel.

So, if I had a batch of 4 prompts, with n_beams == 3, I could run decode on all 12 beams at once.

If that's not expected behavior, I can definitely drop these sharktank changes altogether, which will probably simplify the shortfin code too.

1,
dtype=torch.int64,
)
seq_lens = torch.empty(bs, dtype=torch.int64)
start_positions = torch.ones(bs, dtype=torch.int64)
seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
seq_lens = torch.empty(batch_dim_size, dtype=torch.int64)
start_positions = torch.ones(batch_dim_size, dtype=torch.int64)
seq_block_ids = torch.empty(batch_dim_size, block_dim_min, dtype=torch.int64)

(
cache_state,
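To make the batching discussed in the review thread concrete: with a batch of 4 prompts and n_beams = 3, the decode inputs exported above are sized for 12 flattened sequences, mirroring batch_dim_size = bs * llama_config.n_beams. A small sketch under those numbers, not code from the PR:

import torch

bs, n_beams = 4, 3
batch_dim_size = bs * n_beams                 # 12 beams decoded in one invocation

# Shapes mirror the decode signature above: one new token per beam per step.
tokens = torch.zeros(batch_dim_size, 1, dtype=torch.int64)
seq_lens = torch.zeros(batch_dim_size, dtype=torch.int64)
start_positions = torch.ones(batch_dim_size, dtype=torch.int64)

# Row i carries beam i % n_beams of prompt i // n_beams.
prompt_index = torch.arange(batch_dim_size) // n_beams
beam_index = torch.arange(batch_dim_size) % n_beams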
5 changes: 5 additions & 0 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -193,6 +193,11 @@ class LlamaModelConfig:
# the program and not.
static_tables: bool = True

# The number of beams to use when generating tokens for a given prompt.
# When n_beams == 1, `greedy` selection is used;
# when n_beams > 1, `beam search` is used.
n_beams: int = 1


@dataclass
class T5Config:
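The n_beams semantics documented above reduce to a simple selection rule; a hedged illustration, not code from this PR:

import torch

def select_next(logits: torch.Tensor, n_beams: int) -> torch.Tensor:
    # logits: [batch, vocab_size] scores for the next token.
    if n_beams == 1:
        return torch.argmax(logits, dim=-1)              # greedy selection
    return torch.topk(logits, n_beams, dim=-1).indices   # candidates for beam search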
6 changes: 6 additions & 0 deletions sharktank/sharktank/utils/cli.py
@@ -120,6 +120,12 @@ def add_model_options(parser: argparse.ArgumentParser):
type=int,
default=512,
)
parser.add_argument(
"--n-beams",
help="Number of beams to use when generating tokens.",
type=int,
default=1,
)


def add_quantization_options(parser: argparse.ArgumentParser):
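A small usage sketch for the new flag; it assumes the other options registered by add_model_options carry defaults, as the ones visible in this hunk do:

import argparse
from sharktank.utils import cli

parser = argparse.ArgumentParser()
cli.add_model_options(parser)

# --n-beams defaults to 1 (greedy); a larger value enables beam search.
args = parser.parse_args(["--n-beams", "4"])
assert args.n_beams == 4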
3 changes: 3 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -141,6 +141,9 @@ class ModelParams:
# Cache parameters.
paged_kv_cache: PagedKVCacheParams | None = None

# Number of beams to use during token generation.
n_beams: int = 1

# Size in bytes of the KV cache dtype.
@property
def attn_dtype_size(self) -> int:
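For reference, shortfin's ModelParams above is populated from the JSON emitted by generate_params_json in the first file of this diff; a sketch of the relevant subset, with purely illustrative values:

# Illustrative subset of the exported config consumed by ModelParams.
model_params = {
    "prefill_batch_sizes": [4],
    "decode_batch_sizes": [4],
    "transformer_block_count": 32,
    "n_beams": 3,                      # new field added by this PR
    "paged_kv_cache": {
        "attention_head_count_kv": 8,
        "block_seq_stride": 16,
    },
}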