Skip to content

Commit

Permalink
Changed to use the ttnn operation instead of invoke
Browse files Browse the repository at this point in the history
  • Loading branch information
dgomezTT committed Mar 5, 2025
1 parent 32bb5ce commit 7852b03
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,41 @@

namespace ttnn::graph::arguments::test {

class TestGraphCaptureArgumentsMorehDot : public TTNNFixtureWithTensor {
protected:
tt::tt_metal::IGraphProcessor::RunMode Mode = tt::tt_metal::IGraphProcessor::RunMode::NORMAL;
};
class TestGraphCaptureArgumentsMorehDot : public TTNNFixtureWithTensor {};

TEST_P(TestGraphCaptureArgumentsMorehDot, MorehDot) {
auto tt_input1 = CreateTensor();
auto tt_input2 = CreateTensor();
ttnn::graph::GraphProcessor::begin_graph_capture(Mode);
ttnn::operations::moreh::moreh_dot::MorehDot::invoke(
tt_input1, tt_input2, std::nullopt, DataType::BFLOAT16, std::nullopt, std::nullopt);
ttnn::graph::GraphProcessor::begin_graph_capture(tt::tt_metal::IGraphProcessor::RunMode::NORMAL);
ttnn::moreh_dot(tt_input1, tt_input2, std::nullopt, DataType::BFLOAT16, std::nullopt, std::nullopt);
auto trace = ttnn::graph::GraphProcessor::end_graph_capture();
auto operations = ttnn::graph::extract_arguments(trace);

auto operation1 = operations[0];
auto operation0 = operations[0];
EXPECT_EQ(operation0.operation_name, "ttnn::moreh_dot");
EXPECT_EQ(operation0.arguments.size(), 6);
EXPECT_EQ(
operation0.arguments[0],
"Tensor(storage=DeviceStorage(memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_"
"type=BufferType::L1,shard_spec=std::nullopt)),tensor_spec=TensorSpec(logical_shape=Shape([1, 1, 1, "
"32]),tensor_layout=TensorLayout(dtype=BFLOAT16,page_config=PageConfig(config=TilePageConfig(tile=Tile(tile_"
"shape={32, 32},face_shape={16, "
"16},num_faces=4))),memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type="
"BufferType::L1,shard_spec=std::nullopt),alignment=Alignment([32, 32]))))");
EXPECT_EQ(
operation0.arguments[1],
"Tensor(storage=DeviceStorage(memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_"
"type=BufferType::L1,shard_spec=std::nullopt)),tensor_spec=TensorSpec(logical_shape=Shape([1, 1, 1, "
"32]),tensor_layout=TensorLayout(dtype=BFLOAT16,page_config=PageConfig(config=TilePageConfig(tile=Tile(tile_"
"shape={32, 32},face_shape={16, "
"16},num_faces=4))),memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type="
"BufferType::L1,shard_spec=std::nullopt),alignment=Alignment([32, 32]))))");
EXPECT_EQ(operation0.arguments[2], "[ unsupported type , std::__1::reference_wrapper<std::__1::nullopt_t const>]");
EXPECT_EQ(operation0.arguments[3], "BFLOAT16");
EXPECT_EQ(operation0.arguments[4], "[ unsupported type , std::__1::reference_wrapper<std::__1::nullopt_t const>]");
EXPECT_EQ(operation0.arguments[5], "[ unsupported type , std::__1::reference_wrapper<std::__1::nullopt_t const>]");

auto operation1 = operations[1];
EXPECT_EQ(operation1.operation_name, "ttnn::prim::moreh_dot");
EXPECT_EQ(operation1.arguments.size(), 6);
EXPECT_EQ(
Expand Down Expand Up @@ -56,7 +76,7 @@ TEST_P(TestGraphCaptureArgumentsMorehDot, MorehDot) {
"std::__1::reference_wrapper<std::__1::optional<std::__1::variant<ttnn::GrayskullComputeKernelConfig, "
"ttnn::WormholeComputeKernelConfig>> const>]");

auto operation2 = operations[1];
auto operation2 = operations[2];
EXPECT_EQ(operation2.operation_name, "MorehDotOperation");
EXPECT_EQ(operation2.arguments.size(), 2);
EXPECT_EQ(
Expand All @@ -69,7 +89,7 @@ TEST_P(TestGraphCaptureArgumentsMorehDot, MorehDot) {
"[ unsupported type , "
"std::__1::reference_wrapper<ttnn::operations::moreh::moreh_dot::MorehDotOperation::tensor_args_t const>]");

auto operation3 = operations[2];
auto operation3 = operations[3];
EXPECT_EQ(operation3.operation_name, "tt::tt_metal::create_device_tensor");
EXPECT_EQ(operation3.arguments.size(), 5);
EXPECT_EQ(operation3.arguments[0], "Shape([1, 1, 1, 1])");
Expand All @@ -86,7 +106,7 @@ INSTANTIATE_TEST_SUITE_P(
TestGraphCaptureArgumentsMorehDot_MorehDot,
TestGraphCaptureArgumentsMorehDot,
::testing::Values(CreateTensorParameters{
.input_shape = ttnn::Shape(tt::tt_metal::Array4D{1, 1, 1, 32}),
.input_shape = ttnn::Shape({1, 1, 1, 32}),
.dtype = DataType::BFLOAT16,
.layout = TILE_LAYOUT,
.mem_cfg = L1_MEMORY_CONFIG}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,31 @@

namespace ttnn::graph::arguments::test {

class TestGraphCaptureArgumentsTranspose : public TTNNFixtureWithTensor {
protected:
tt::tt_metal::IGraphProcessor::RunMode Mode = tt::tt_metal::IGraphProcessor::RunMode::NORMAL;
};
class TestGraphCaptureArgumentsTranspose : public TTNNFixtureWithTensor {};

TEST_P(TestGraphCaptureArgumentsTranspose, Transpose) {
auto tt_input = CreateTensor();
tt_input.reshape(ttnn::Shape{1, 2048, 4, 128});
ttnn::graph::GraphProcessor::begin_graph_capture(Mode);
ttnn::operations::data_movement::ExecuteTranspose::invoke(tt_input, 1, 2);
ttnn::graph::GraphProcessor::begin_graph_capture(tt::tt_metal::IGraphProcessor::RunMode::NORMAL);
ttnn::transpose(tt_input, 1, 2);
auto trace = ttnn::graph::GraphProcessor::end_graph_capture();
auto operations = ttnn::graph::extract_arguments(trace);

auto operation1 = operations[0];
auto operation0 = operations[0];
EXPECT_EQ(operation0.operation_name, "ttnn::transpose");
EXPECT_EQ(operation0.arguments.size(), 3);
EXPECT_EQ(
operation0.arguments[0],
"Tensor(storage=DeviceStorage(memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_"
"type=BufferType::L1,shard_spec=std::nullopt)),tensor_spec=TensorSpec(logical_shape=Shape([1, 1, 2048, "
"512]),tensor_layout=TensorLayout(dtype=BFLOAT16,page_config=PageConfig(config=RowMajorPageConfig(tile=Tile("
"tile_shape={32, 32},face_shape={16, "
"16},num_faces=4))),memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type="
"BufferType::L1,shard_spec=std::nullopt),alignment=Alignment([1]))))");
EXPECT_EQ(operation0.arguments[1], "1");
EXPECT_EQ(operation0.arguments[2], "2");

auto operation1 = operations[1];
EXPECT_EQ(operation1.operation_name, "ttnn::prim::permute");
EXPECT_EQ(operation1.arguments.size(), 5);
EXPECT_EQ(
Expand All @@ -46,7 +57,7 @@ TEST_P(TestGraphCaptureArgumentsTranspose, Transpose) {
EXPECT_EQ(operation1.arguments[3], "[ unsupported type , std::__1::reference_wrapper<std::__1::nullopt_t const>]");
EXPECT_EQ(operation1.arguments[4], "0");

auto operation2 = operations[1];
auto operation2 = operations[2];
EXPECT_EQ(operation2.operation_name, "PermuteDeviceOperation");
EXPECT_EQ(operation2.arguments.size(), 2);
EXPECT_EQ(
Expand All @@ -59,7 +70,7 @@ TEST_P(TestGraphCaptureArgumentsTranspose, Transpose) {
"[ unsupported type , "
"std::__1::reference_wrapper<ttnn::operations::data_movement::PermuteDeviceOperation::tensor_args_t const>]");

auto operation3 = operations[2];
auto operation3 = operations[3];
EXPECT_EQ(operation3.operation_name, "tt::tt_metal::create_device_tensor");
EXPECT_EQ(operation3.arguments.size(), 5);
EXPECT_EQ(operation3.arguments[0], "Shape([1, 2048, 1, 512])");
Expand All @@ -76,7 +87,7 @@ INSTANTIATE_TEST_SUITE_P(
TestGraphCaptureArgumentsTranspose_Transpose,
TestGraphCaptureArgumentsTranspose,
::testing::Values(CreateTensorParameters{
.input_shape = ttnn::Shape(tt::tt_metal::Array4D{1, 1, 2048, 512}),
.input_shape = ttnn::Shape({1, 1, 2048, 512}),
.dtype = DataType::BFLOAT16,
.layout = ROW_MAJOR_LAYOUT,
.mem_cfg = L1_MEMORY_CONFIG}));
Expand Down

0 comments on commit 7852b03

Please sign in to comment.