Skip to content

Commit

Permalink
update to trace2cq of resnet using imagenet data
Browse files Browse the repository at this point in the history
  • Loading branch information
ssinghalTT committed Mar 6, 2025
1 parent 5b95f51 commit e8b33dc
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 2 deletions.
118 changes: 118 additions & 0 deletions models/demos/ttnn_resnet/tests/resnet50_performant_imagenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn
from models.utility_functions import (
is_wormhole_b0,
profiler,
)
from models.demos.ttnn_resnet.tests.resnet50_test_infra import create_test_infra
from models.demos.ttnn_resnet.tests.demo_utils import get_data, get_data_loader, get_batch

try:
from tracy import signpost

use_signpost = True
except ModuleNotFoundError:
use_signpost = False


def buffer_address(tensor):
    """Return the buffer address of each device tensor backing ``tensor``.

    The result is a list with one entry per tensor yielded by
    ``ttnn.get_device_tensors(tensor)``, in the order they are yielded.
    """
    return [device_tensor.buffer_address() for device_tensor in ttnn.get_device_tensors(tensor)]


# TODO: Create ttnn apis for this
# Until then, the module-level helper is attached to the ``ttnn`` namespace at
# import time so that code in this file can call ``ttnn.buffer_address(...)``.
ttnn.buffer_address = buffer_address


class ResNet50Trace2CQ:
    """ResNet-50 inference driver that replays a captured trace using two command queues.

    Input writes are issued with queue id 1 and the model ops / trace with
    queue id 0. ``op_event`` and ``write_event`` are recorded on one queue and
    waited on from the other so that an input write never overtakes the ops
    still consuming the previous input, and the ops never start before the
    write has landed.

    Typical usage::

        runner = ResNet50Trace2CQ()
        runner.initialize_resnet50_trace_2cqs_inference(device, batch_size, act_dtype, weight_dtype)
        output = runner.execute_resnet50_trace_2cqs_inference(tt_inputs_host)
        runner.release_resnet50_trace_2cqs_inference()
    """

    def __init__(self):
        # Everything below is populated by initialize_resnet50_trace_2cqs_inference().
        self.device = None
        self.test_infra = None
        self.tt_inputs_host = None
        self.tt_image_res = None
        self.input_mem_config = None
        self.input_tensor = None
        self.op_event = None
        self.write_event = None
        self.tid = None

    def _require_initialized(self):
        """Raise ``RuntimeError`` if the trace has not been set up yet."""
        if self.test_infra is None:
            raise RuntimeError(
                "initialize_resnet50_trace_2cqs_inference() must be called before using the trace"
            )

    def _write_input_to_device(self, tt_inputs_host):
        """Copy ``tt_inputs_host`` into ``self.tt_image_res`` and make queue 0 wait for the write."""
        # Do not overwrite the device input until the ops that read it have been issued.
        ttnn.wait_for_event(1, self.op_event)
        ttnn.copy_host_to_device_tensor(tt_inputs_host, self.tt_image_res, 1)
        self.write_event = ttnn.record_event(self.device, 1)
        # The ops must not start before the new input has been written.
        ttnn.wait_for_event(0, self.write_event)

    def initialize_resnet50_trace_2cqs_inference(
        self,
        device,
        device_batch_size=1,
        act_dtype=ttnn.bfloat16,
        weight_dtype=ttnn.bfloat16,
    ):
        """Build the test infra, warm the model up twice and capture the trace.

        Args:
            device: Device to run on; also stored as ``self.device``.
            device_batch_size: Batch size forwarded to ``create_test_infra``.
            act_dtype: Activation dtype forwarded to ``create_test_infra``.
            weight_dtype: Weight dtype forwarded to ``create_test_infra``.

        Raises:
            AssertionError: If the input tensor allocated after the capture does
                not get the same buffer address(es) as the input used while
                capturing.
        """
        self.test_infra = create_test_infra(
            device,
            device_batch_size,
            act_dtype,
            weight_dtype,
            ttnn.MathFidelity.LoFi,
            True,
            dealloc_input=True,
            final_output_mem_config=ttnn.L1_MEMORY_CONFIG,
        )
        self.device = device
        self.tt_inputs_host, sharded_mem_config_DRAM, self.input_mem_config = self.test_infra.setup_dram_sharded_input(
            device
        )
        self.tt_image_res = self.tt_inputs_host.to(device, sharded_mem_config_DRAM)
        # Record both events up front so the first wait in _write_input_to_device has something to wait on.
        self.op_event = ttnn.record_event(device, 0)
        self.write_event = ttnn.record_event(device, 1)

        # First run configures convs JIT
        self._write_input_to_device(self.tt_inputs_host)
        self.test_infra.input_tensor = ttnn.to_memory_config(self.tt_image_res, self.input_mem_config)
        # Kept so an identically specified input tensor can be allocated after the capture.
        spec = self.test_infra.input_tensor.spec
        self.op_event = ttnn.record_event(device, 0)
        self.test_infra.run()
        self.test_infra.validate()
        self.test_infra.output_tensor.deallocate(force=True)

        # Optimized run
        self._write_input_to_device(self.tt_inputs_host)
        self.test_infra.input_tensor = ttnn.to_memory_config(self.tt_image_res, self.input_mem_config)
        self.op_event = ttnn.record_event(device, 0)
        self.test_infra.run()
        self.test_infra.validate()

        # Capture
        self._write_input_to_device(self.tt_inputs_host)
        self.test_infra.input_tensor = ttnn.to_memory_config(self.tt_image_res, self.input_mem_config)
        self.op_event = ttnn.record_event(device, 0)
        self.test_infra.output_tensor.deallocate(force=True)
        trace_input_addr = ttnn.buffer_address(self.test_infra.input_tensor)
        self.tid = ttnn.begin_trace_capture(device, cq_id=0)
        self.test_infra.run()
        # Allocated inside the capture window; the assertion below checks that it
        # ends up at the address(es) the captured ops read their input from.
        self.input_tensor = ttnn.allocate_tensor_on_device(spec, device)
        ttnn.end_trace_capture(device, self.tid, cq_id=0)
        assert trace_input_addr == ttnn.buffer_address(
            self.input_tensor
        ), "input tensor allocated after trace capture does not reuse the traced input address"

    def execute_resnet50_trace_2cqs_inference(self, tt_inputs_host=None):
        """Run one traced inference and return the output read back from the device.

        Args:
            tt_inputs_host: Host input tensor to copy to the device. When ``None``,
                the host tensor stored during initialization is used.

        Returns:
            The result of ``ttnn.from_device`` on the test infra's output tensor
            (blocking read).

        Raises:
            RuntimeError: If called before initialization.
        """
        self._require_initialized()
        if tt_inputs_host is None:
            tt_inputs_host = self.tt_inputs_host

        self._write_input_to_device(tt_inputs_host)
        # TODO: Add in place support to ttnn to_memory_config
        self.input_tensor = ttnn.reshard(self.tt_image_res, self.input_mem_config, self.input_tensor)
        # Recorded right after the reshard is issued, before the trace, so the
        # next input write only has to wait for tt_image_res to be consumed.
        self.op_event = ttnn.record_event(self.device, 0)
        ttnn.execute_trace(self.device, self.tid, cq_id=0, blocking=False)
        return ttnn.from_device(self.test_infra.output_tensor, blocking=True)

    def release_resnet50_trace_2cqs_inference(self):
        """Release the captured trace.

        Raises:
            RuntimeError: If called before initialization.
        """
        self._require_initialized()
        ttnn.release_trace(self.device, self.tid)
4 changes: 2 additions & 2 deletions models/demos/wormhole/resnet50/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ This demo includes preprocessing, postprocessing and inference time for batch si

+ Our second demo is designed to run on the ImageNet dataset. Run it with:
```python
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest --disable-warnings models/demos/wormhole/resnet50/demo/demo.py::test_demo_imagenet
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest --disable-warnings models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py::test_perf_trace_2cqs_with_imagenet[True-16-act_dtype0-weight_dtype0-device_params0]
```
- The 16 refers to batch size here and 100 is the number of iterations(batches), hence the model will process 100 batches of size 16, total of 1600 images.
- The 16 refers to the batch size, and 500 is the number of iterations (batches); hence the model will process 500 batches of size 16, a total of 8000 images.
- Note that the first time the model is run, the ImageNet images must be downloaded from Hugging Face and stored in `models/demos/ttnn_resnet/demo/images/`; you therefore need to log in to Hugging Face with your token, either by running `huggingface-cli login` or by setting the token with `export HF_TOKEN=<token>`.
- To obtain a huggingface token visit: https://huggingface.co/docs/hub/security-tokens

Expand Down
35 changes: 35 additions & 0 deletions models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

from models.utility_functions import run_for_wormhole_b0
from models.demos.ttnn_resnet.tests.perf_e2e_resnet50 import run_perf_resnet
from models.demos.wormhole.resnet50.tests.test_resnet50_performant_imagenet import (
test_run_resnet50_trace_2cqs_inference,
)
import ttnn


@run_for_wormhole_b0()
Expand Down Expand Up @@ -120,3 +124,34 @@ def test_perf_trace_2cqs(
"resnet50_trace_2cqs",
model_location_generator,
)


@run_for_wormhole_b0()
@pytest.mark.parametrize(
    "device_params", [{"l1_small_size": 24576, "trace_region_size": 1605632, "num_command_queues": 2}], indirect=True
)
@pytest.mark.parametrize(
    "batch_size, act_dtype, weight_dtype",
    ((16, ttnn.bfloat8_b, ttnn.bfloat8_b),),
)
@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_perf_trace_2cqs_with_imagenet(
    device,
    use_program_cache,
    batch_size,
    imagenet_label_dict,
    act_dtype,
    weight_dtype,
    enable_async_mode,
    model_location_generator,
):
    """Run the ImageNet trace + 2CQ ResNet-50 test by delegating to the shared test body.

    Every fixture and parameter received here is forwarded unchanged to
    ``test_run_resnet50_trace_2cqs_inference``.
    """
    test_run_resnet50_trace_2cqs_inference(
        device=device,
        use_program_cache=use_program_cache,
        batch_size=batch_size,
        imagenet_label_dict=imagenet_label_dict,
        act_dtype=act_dtype,
        weight_dtype=weight_dtype,
        enable_async_mode=enable_async_mode,
        model_location_generator=model_location_generator,
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import ttnn
import time
import torch
from loguru import logger

from models.utility_functions import run_for_wormhole_b0
from models.demos.ttnn_resnet.tests.resnet50_performant_imagenet import ResNet50Trace2CQ
from models.demos.ttnn_resnet.tests.demo_utils import get_data, get_data_loader, get_batch
from transformers import AutoImageProcessor
from models.utility_functions import (
profiler,
)
from models.perf.perf_utils import prep_perf_report
from pdb import set_trace as bp


@run_for_wormhole_b0()
@pytest.mark.parametrize(
    "device_params", [{"l1_small_size": 24576, "trace_region_size": 1605632, "num_command_queues": 2}], indirect=True
)
@pytest.mark.parametrize(
    "batch_size, act_dtype, weight_dtype",
    ((16, ttnn.bfloat8_b, ttnn.bfloat8_b),),
)
@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_run_resnet50_trace_2cqs_inference(
    device,
    use_program_cache,
    batch_size,
    imagenet_label_dict,
    act_dtype,
    weight_dtype,
    enable_async_mode,
    model_location_generator,
):
    """Run traced 2-command-queue ResNet-50 over ImageNet batches and report accuracy and timing.

    The trace is set up once (timed as "compile"), then ``iterations`` batches are
    fetched, converted and executed (timed together as "run"). Top-1 accuracy is
    logged and the timings are passed to ``prep_perf_report``.
    """
    profiler.clear()
    with torch.no_grad():
        resnet50_trace_2cq = ResNet50Trace2CQ()

        profiler.start("compile")
        resnet50_trace_2cq.initialize_resnet50_trace_2cqs_inference(
            device,
            batch_size,
            act_dtype,
            weight_dtype,
        )
        profiler.end("compile")

        model_version = "microsoft/resnet-50"
        iterations = 500
        correct = 0

        # Release the captured trace even if data loading or inference raises.
        try:
            image_processor = AutoImageProcessor.from_pretrained(model_version)
            input_loc = str(model_location_generator("ImageNet_data"))
            data_loader = get_data_loader(input_loc, batch_size, iterations)

            profiler.start("run")
            for _ in range(iterations):
                ### TODO optimize input streamer for better e2e performance
                # NOTE: batch fetching and host-side conversion are inside the timed region.
                inputs, labels = get_batch(data_loader, image_processor)
                tt_inputs_host = ttnn.from_torch(inputs, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT)
                output = (
                    resnet50_trace_2cq.execute_resnet50_trace_2cqs_inference(tt_inputs_host)
                    .to_torch()
                    .to(torch.float)
                )
                prediction = output[:, 0, 0, :].argmax(dim=-1)
                for i in range(batch_size):
                    predicted_label = imagenet_label_dict[prediction[i].item()]
                    if imagenet_label_dict[labels[i]] == predicted_label:
                        correct += 1
            profiler.end("run")
        finally:
            resnet50_trace_2cq.release_resnet50_trace_2cqs_inference()

    accuracy = correct / (batch_size * iterations)
    logger.info("=============")
    logger.info(f"Accuracy for {batch_size}x{iterations} inputs: {accuracy}")

    first_iter_time = profiler.get("compile")
    # ensuring inference time fluctuations is not noise
    inference_time_avg = profiler.get("run") / (iterations)

    # The "compile" region also contains warm-up inference runs; subtract an
    # estimate of two of them to approximate the pure compile time.
    compile_time = first_iter_time - 2 * inference_time_avg
    prep_perf_report(
        model_name=f"ttnn_{model_version}_batch_size{batch_size}",
        batch_size=batch_size,
        inference_and_compile_time=first_iter_time,
        inference_time=inference_time_avg,
        expected_compile_time=30,
        expected_inference_time=0.004,
        comments="tests",
    )

    logger.info(
        f"ttnn_{model_version}_batch_size{batch_size} tests inference time (avg): {inference_time_avg}, FPS: {batch_size/inference_time_avg}"
    )
    logger.info(f"ttnn_{model_version}_batch_size{batch_size} compile time: {compile_time}")

0 comments on commit e8b33dc

Please sign in to comment.