
Add beam_search ability to sharktank LLM model #991

Closed
wants to merge 4 commits

Conversation

@stbaione (Contributor) commented Feb 21, 2025

Description

Pretty simple changes to enable beam_search in sharktank. I'm still working on the server implementation, but figured we can unblock any potential beam search work on the sharktank side.

The idea is pretty simple:

  1. Run normal prefill. The sharktank LLM model already returns the raw logits, and it seemed best to keep it that way.
  2. In shortfin (subsequent PR), greedily select a token from prefill, then expand the token inputs for decode from [bs, 1] to [bs * k, 1] and replicate the cache k times for each batch entry. An alternative I saw online is to add a beam dimension to the inputs (i.e. tokens: [bs, k, 1]), but my research indicated that expanding along the bs dimension lets all the beams be processed in parallel, avoids reshapes that may otherwise be needed, appears to be the most common implementation, and keeps the changes minimal. From decode's perspective everything else stays the same; one dimension is just scaled (see the sketch after this list).
So, if k == 2, bs == 1 and our prompt fits within one cache page (seq_l, start_positions, and seq_block_ids also scale the same way as tokens):

Prefill inputs:

tokens: [bs, seq_l] => [1, seq_l]
cache_pages: [511]

Decode inputs:

tokens: [bs * k, 1] => [2, 1]
cache_pages: [511, 510], where 510 is a replica of 511
  3. Run decode. The work of applying log_softmax, selecting the top k tokens, tracking cumulative log probs, tracking beams, etc. will be contained on the shortfin side.
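
A minimal sketch of the expansion described in step 2, assuming PyTorch tensors for the token ids and plain Python lists for the cache page ids; the helper name `expand_for_beams` is illustrative and not part of this PR:

```python
import torch

def expand_for_beams(tokens: torch.Tensor, cache_pages: list, k: int):
    """Illustrative only: scale decode inputs from [bs, 1] to [bs * k, 1]."""
    # [bs, 1] -> [bs * k, 1]; each prompt's row is repeated k times contiguously.
    beam_tokens = tokens.repeat_interleave(k, dim=0)
    # Each beam gets its own copy of the prompt's cache pages. In practice the
    # replicas would point at freshly allocated pages (e.g. [511] -> [511], [510]).
    beam_pages = [list(pages) for pages in cache_pages for _ in range(k)]
    return beam_tokens, beam_pages

# k == 2, bs == 1, as in the example above:
tokens = torch.tensor([[42]])  # token selected greedily from the prefill logits
beam_tokens, beam_pages = expand_for_beams(tokens, [[511]], k=2)
print(beam_tokens.shape)       # torch.Size([2, 1])
```

seq_l, start_positions, and seq_block_ids would scale the same way, as noted above.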

@stbaione stbaione requested a review from rsuderman February 21, 2025 18:02
@rsuderman (Contributor) left a comment


Shouldn't there be logic for exploring multiple hypotheses? Beam search is a mechanism for how sequentially proposed logits are combined and won't require any modelling changes.

tokens = torch.empty(
    bs,
    batch_dim_size,
@rsuderman (Contributor) commented:

Decoding behavior should not need to change the model. The entire beam search behavior should be externally controlled by shortfin as the invocation gives a ranking per token.

@stbaione (Contributor, Author) replied:

My thinking with this change was that it would make beam search more efficient on the shortfin side, by allowing all beams to be processed in parallel.

So, if I had a batch of 4 prompts, with n_beams == 3, I could run decode on all 12 beams at once.

If that's not the expected behavior, I can definitely drop these sharktank changes altogether, which will probably simplify the shortfin code too.
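
For concreteness, a quick shape check of that claim (the numbers and the use of `repeat_interleave` here are purely illustrative, not code from this PR):

```python
import torch

bs, n_beams = 4, 3
# One token per prompt, selected greedily from the prefill logits: [4, 1]
prefill_tokens = torch.randint(0, 32_000, (bs, 1))
# Expand along the batch dimension so every beam decodes in a single invocation.
decode_tokens = prefill_tokens.repeat_interleave(n_beams, dim=0)
assert decode_tokens.shape == (bs * n_beams, 1)  # [12, 1]
```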

@stbaione (Contributor, Author) commented:

> Shouldn't there be logic for exploring multiple hypotheses? Beam search is a mechanism for how sequentially proposed logits are combined and won't require any modelling changes.

I was going to put up a follow-up PR with the logic for that on the shortfin side; I'm still working on that code. I thought this was a necessary change on the sharktank side, which is why I pushed it separately.
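
For context, a rough sketch of what that shortfin-side step might look like, kept entirely outside the model as suggested above; the function name, shapes, and bookkeeping here are hypothetical and not taken from this PR or from shortfin:

```python
import torch

def beam_step(logits: torch.Tensor, cum_log_probs: torch.Tensor, k: int):
    """One model-external beam search step (hypothetical sketch).

    logits:        [bs * k, vocab] raw logits returned by decode.
    cum_log_probs: [bs * k] running log-prob score for each live beam.
    Returns, per prompt, the k surviving beams: the parent beam each one
    extends, the token it appends, and its updated cumulative score.
    """
    bs_k, vocab = logits.shape
    bs = bs_k // k
    log_probs = torch.log_softmax(logits, dim=-1)      # [bs * k, vocab]
    scores = cum_log_probs.unsqueeze(-1) + log_probs   # [bs * k, vocab]
    scores = scores.view(bs, k * vocab)                # [bs, k * vocab]
    top_scores, top_idx = scores.topk(k, dim=-1)       # [bs, k]
    parent_beam = torch.div(top_idx, vocab, rounding_mode="floor")
    next_token = top_idx % vocab
    # On the very first decode step all k replicas of a prompt are identical,
    # so a real implementation would mask the duplicates (or score only one
    # replica) to avoid keeping k copies of the same beam.
    return parent_beam, next_token, top_scores
```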

@stbaione stbaione closed this Feb 21, 2025