model generation supports input embeds
zongwave committed Aug 20, 2024
1 parent 8abf818 commit 7db501c
Showing 2 changed files with 65 additions and 12 deletions.
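For context, a minimal sketch of the behaviour this commit enables, written against the stock transformers API rather than the Gaudi-specific run path (the model name is illustrative): token embeddings are precomputed and passed to generate() as inputs_embeds instead of input_ids.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative; any causal LM that accepts inputs_embeds
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, world", return_tensors="pt")
# Look up the token embeddings explicitly instead of handing generate() the token ids.
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])
outputs = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs["attention_mask"],
    max_new_tokens=32,
)
# When only inputs_embeds are provided, generate() typically returns just the new token ids.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))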
44 changes: 42 additions & 2 deletions examples/text-generation/run_generation.py
@@ -311,6 +311,11 @@ def setup_parser(parser):
default="none",
help="Run multi card with the specified parallel strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.",
)
parser.add_argument(
"--input_embeds",
action="store_true",
help="Whether to enable inputs_embeds or not.",
)

args = parser.parse_args()

@@ -331,6 +336,28 @@ def setup_parser(parser):
return args


def prepare_generation_embedding(model, model_name, input_tokens):
batch_size = input_tokens['input_ids'].size(0)

# Get text embeddings from the model
if model_name in ["meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-Instruct-v0.2", "microsoft/phi-2"]:
inputs_embeds = model.model.embed_tokens(input_tokens['input_ids'])
elif model_name in ["gpt2", "mosaicml/mpt-7b"]:
inputs_embeds = model.transformer.wte(input_tokens['input_ids'])
elif model_name == "tiiuae/falcon-7b":
inputs_embeds = model.transformer.word_embeddings(input_tokens['input_ids'])
else:
logger.warning(f"This test does not support input embeds for model: {model_name}")
return None

# Expand the embeddings if their batch dimension does not match the input batch size
if inputs_embeds.size(0) != batch_size:
inputs_embeds = inputs_embeds.expand(batch_size, -1, -1)

attention_mask = torch.ones_like(input_tokens['input_ids'])
return {'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}


def main():
parser = argparse.ArgumentParser()
args = setup_parser(parser)
@@ -428,9 +455,21 @@ def generate(size=None, reduce_recompile=False):
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].to(args.device)

input_data = {}
if args.input_embeds:
inputs_embeds = prepare_generation_embedding(model, args.model_name_or_path, input_tokens)
if inputs_embeds is not None:
input_data.update(inputs_embeds)
else:
args.input_embeds = False
input_data.update(input_tokens)
else:
input_data.update(input_tokens)

iteration_times = []
outputs = model.generate(
**input_tokens,
**input_data,
generation_config=generation_config,
assistant_model=assistant_model,
lazy_mode=use_lazy_mode,
@@ -519,7 +558,8 @@ def rounder(x):
with (output_dir / "results.json").open("w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=4)

stats = f"Throughput (including tokenization) = {throughput} tokens/second"
stats = f"Input embeds" if args.input_embeds else "Input tokens"
stats = stats + f"\nThroughput (including tokenization) = {throughput} tokens/second"
stats = stats + f"\nNumber of HPU graphs = {count_hpu_graphs()}"
separator = "-" * len(stats)
print()
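Note that prepare_generation_embedding above resolves the embedding table through model-specific attribute paths (model.model.embed_tokens, model.transformer.wte, model.transformer.word_embeddings). As a point of comparison only, not part of this commit, the same lookup can be sketched with the generic accessor that every transformers model exposes (torch is already imported in run_generation.py):

def prepare_generation_embedding_generic(model, input_tokens):
    # get_input_embeddings() returns the token-embedding module for any
    # transformers architecture, avoiding per-model attribute paths.
    inputs_embeds = model.get_input_embeddings()(input_tokens["input_ids"])
    attention_mask = torch.ones_like(input_tokens["input_ids"])
    return {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}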
33 changes: 23 additions & 10 deletions optimum/habana/transformers/generation/utils.py
@@ -561,7 +561,7 @@ def _prepare_generated_length(
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
if has_token_idx:
if has_token_idx and input_ids_length > 0:
generation_config.max_length = input_ids_length
else:
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
@@ -661,7 +661,7 @@ def _prepare_generation_config(
self.generation_config.static_shapes = generation_config.static_shapes
if generation_config.ignore_eos is None:
generation_config.ignore_eos = kwargs.get("ignore_eos", kwargs.get("lazy_mode", None))
self.generation_config.ignore_eos = generation_config.ignore_eos
self.generation_config.ignore_eos = generation_config.ignore_eos
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
if self.config.model_type == "falcon" and "token_type_ids" in kwargs.keys():
for key in ["token_type_ids"]:
@@ -914,14 +914,23 @@ def generate(
# only pad if bucket_size < -1. If we are bucketing (bucket_size > 0), then that is taken care in greedy_search()
if not is_greedy_or_beam_and_bucket:
# token_idx is the current index in the generation process, it is incremented each time a new token is generated
token_idx = inputs_tensor.shape[-1]
token_idx = inputs_tensor.shape[1]
model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device)
model_kwargs["token_idx_cpu"] = token_idx
if generation_config.max_new_tokens is None:
generation_config.max_new_tokens = generation_config.max_length - token_idx
inputs_tensor = torch.nn.functional.pad(
inputs_tensor, (0, generation_config.max_new_tokens), value=generation_config.pad_token_id
)
if model_input_name == "inputs_embeds" and model_kwargs.get("inputs_embeds") is not None:
generation_config.reuse_cache = False
inputs_tensor = torch.nn.functional.pad(
inputs_tensor, (0, 0, 0, generation_config.max_new_tokens), value=generation_config.pad_token_id
)
model_kwargs["inputs_embeds"] = torch.nn.functional.pad(
model_kwargs["inputs_embeds"], (0, 0, 0, generation_config.max_new_tokens), value=0
)
else:
inputs_tensor = torch.nn.functional.pad(
inputs_tensor, (0, generation_config.max_new_tokens), value=generation_config.pad_token_id
)
for other_inputs in ["attention_mask", "token_type_ids"]:
if model_kwargs.get(other_inputs) is not None:
model_kwargs[other_inputs] = torch.nn.functional.pad(
@@ -963,6 +972,10 @@ def generate(
)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
if model_input_name == "inputs_embeds" and generation_config.static_shapes:
input_ids = torch.nn.functional.pad(
input_ids, (0, generation_config.max_new_tokens), value=generation_config.pad_token_id
)

if generation_config.token_healing:
input_ids = self.heal_tokens(input_ids, tokenizer)
@@ -971,7 +984,7 @@
streamer.put(input_ids.cpu())

# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
@@ -1084,9 +1097,9 @@ def generate(
model_kwargs["num_virtual_tokens"] = num_virtual_tokens

if not self.config.is_encoder_decoder:
calculated_max_length = input_ids.shape[-1] + num_virtual_tokens
calculated_max_length = input_ids.shape[1] + num_virtual_tokens
if not generation_config.static_shapes and generation_config.max_new_tokens is not None:
calculated_max_length = input_ids.shape[-1] + generation_config.max_new_tokens + num_virtual_tokens
calculated_max_length = input_ids.shape[1] + generation_config.max_new_tokens + num_virtual_tokens
if generation_config.use_cache and generation_config.reuse_cache:
bs, _ = input_ids.shape
if not is_greedy_or_beam_and_bucket:
@@ -2016,7 +2029,7 @@ def _contrastive_search(
unfinished_sequences = unfinished_sequences & ~stopping_criteria(
input_ids,
scores,
token_idx=cur_len,
token_idx=None if "inputs_embeds" in model_kwargs else cur_len,
ignore_eos=ignore_eos,
eos_token_id=generation_config.eos_token_id,
)
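The two padding calls in the generate() hunk above differ because torch.nn.functional.pad takes (left, right) pad amounts starting from the last dimension: (0, max_new_tokens) extends the final (sequence) dimension of the 2-D input_ids tensor, while (0, 0, 0, max_new_tokens) leaves the hidden dimension of the 3-D inputs_embeds tensor untouched and extends its sequence dimension. A small self-contained check of that behaviour:

import torch
import torch.nn.functional as F

max_new_tokens = 4
input_ids = torch.ones(2, 3, dtype=torch.long)   # [batch, seq]
inputs_embeds = torch.ones(2, 3, 8)              # [batch, seq, hidden]

padded_ids = F.pad(input_ids, (0, max_new_tokens), value=0)
padded_embeds = F.pad(inputs_embeds, (0, 0, 0, max_new_tokens), value=0)

print(padded_ids.shape)     # torch.Size([2, 7])    -> sequence grows by max_new_tokens
print(padded_embeds.shape)  # torch.Size([2, 7, 8]) -> hidden size unchanged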
