Commit

test_llama_chunked_generation.py bug fix
ipotkonjak-tt committed Mar 6, 2025
1 parent 6b9b81b commit e3df4df
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions models/demos/llama3/tests/test_llama_accuracy.py
@@ -76,7 +76,7 @@ def get_accuracy_thresholds(base_model_name: str, device_name: str, optimization
     "optimizations",
     [
         pytest.param(LlamaOptimizations.accuracy, id="accuracy"),
-        # pytest.param(LlamaOptimizations.performance, id="performance"),
+        pytest.param(LlamaOptimizations.performance, id="performance"),
     ],
 )
 @pytest.mark.parametrize(
@@ -102,7 +102,7 @@ def get_accuracy_thresholds(base_model_name: str, device_name: str, optimization
     "use_reference_file",
     [
         pytest.param(True, id="reference_file"),
-        # pytest.param(False, id="reference_text"),
+        pytest.param(False, id="reference_text"),
     ],
 )
 def test_tt_model_acc(
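
Note: re-enabling those two pytest.param entries restores the full test matrix. Stacked @pytest.mark.parametrize decorators take the cross product of their values, so the accuracy/performance axis now crosses the reference_file/reference_text axis again. A minimal, self-contained sketch of the pattern (the enum and test body are illustrative stand-ins, not the repo's code):

# Sketch of the parametrize pattern re-enabled above; the enum and test
# body are stand-ins, not the repo's implementation.
import pytest
from enum import Enum, auto

class LlamaOptimizations(Enum):
    accuracy = auto()
    performance = auto()

@pytest.mark.parametrize(
    "optimizations",
    [
        pytest.param(LlamaOptimizations.accuracy, id="accuracy"),
        pytest.param(LlamaOptimizations.performance, id="performance"),
    ],
)
@pytest.mark.parametrize(
    "use_reference_file",
    [
        pytest.param(True, id="reference_file"),
        pytest.param(False, id="reference_text"),
    ],
)
def test_matrix(optimizations, use_reference_file):
    # Stacked parametrize decorators yield the cross product, so pytest
    # now collects four variants instead of one:
    #   reference_file-accuracy, reference_file-performance,
    #   reference_text-accuracy, reference_text-performance
    assert isinstance(use_reference_file, bool)
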
6 changes: 4 additions & 2 deletions models/demos/llama3/tests/test_llama_chunked_generation.py
@@ -126,7 +126,7 @@ def test_chunked_prefill_single_user(
         weight_cache_path=model_args.weight_cache_path(dtype),
         paged_attention_config=paged_attention_config,
     )
-    generator = LlamaGenerator(tt_model, model_args, mesh_device)
+    generator = LlamaGenerator([tt_model], [model_args], mesh_device)
 
     logger.info("Model and caches loaded.")
 
@@ -150,13 +150,15 @@ def test_chunked_prefill_single_user(
     logger.info("Running TT model")
     for last_token_idx in range(prefill_chunk_size - 10, seq_len, prefill_chunk_size):
         logger.info(f"Running TT model for last_token_idx: {last_token_idx}")
-        tt_output_torch = generator.prefill_forward_single_user_text(
+        tt_output_device = generator.prefill_forward_single_user_text(
             tt_prefill_input,
             page_table=static_page_table,
             user_id=0,
             last_token_idx=last_token_idx,
             kv_cache=tt_kv_cache,
         )
 
+        tt_output_torch = tt_model.process_output_prefill(tt_output_device, last_token_idx=(last_token_idx % 32))
+        tt_output_torch = tt_output_torch.reshape(batch_size, 1, -1)
 
         ref_output_slice = ref_output[:, last_token_idx : last_token_idx + 1, :]
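
Note: the bug fix has two parts. LlamaGenerator now expects lists of models and model args rather than bare objects, and prefill_forward_single_user_text returns a device tensor that must be converted back to a torch tensor via process_output_prefill before comparison with the reference output; the % 32 presumably folds the token index into a 32-position output tile. A hedged, self-contained sketch of the corrected call sequence, using stand-in classes so it runs without tt-metal hardware (only the method names, the list-wrapping, and the modulo come from the diff):

# Hedged sketch of the corrected call sequence; FakeTTModel and
# FakeGenerator are stand-ins, not the repo's classes.
import torch

class FakeTTModel:
    def process_output_prefill(self, tt_output_device, last_token_idx):
        # Stand-in for the device-to-host conversion: the real method reads
        # the device tensor back and extracts logits at last_token_idx.
        return tt_output_device[:, last_token_idx, :]

class FakeGenerator:
    def __init__(self, models, model_args, mesh_device=None):
        # After the fix the generator takes *lists*, mirroring
        # LlamaGenerator([tt_model], [model_args], mesh_device).
        assert isinstance(models, list) and isinstance(model_args, list)
        self.model = models[0]

    def prefill_forward_single_user_text(self, tokens, **kwargs):
        # Stand-in forward pass: random logits shaped [batch, tile, vocab].
        return torch.randn(1, 32, 128)

batch_size, seq_len, prefill_chunk_size = 1, 128, 64
tt_model = FakeTTModel()
generator = FakeGenerator([tt_model], [{}])

for last_token_idx in range(prefill_chunk_size - 10, seq_len, prefill_chunk_size):
    tt_output_device = generator.prefill_forward_single_user_text(None)
    # The output tile holds 32 token positions, hence the modulo indexing.
    tt_output_torch = tt_model.process_output_prefill(
        tt_output_device, last_token_idx=(last_token_idx % 32)
    )
    tt_output_torch = tt_output_torch.reshape(batch_size, 1, -1)
    assert tt_output_torch.shape == (batch_size, 1, 128)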
