From 7852b039b800b39749106f3e870d01f8efbd87f9 Mon Sep 17 00:00:00 2001 From: Diego Gomez Date: Wed, 5 Mar 2025 15:29:38 +0000 Subject: [PATCH] Changed to use the ttnn operation instead of invoke --- .../test_graph_capture_arguments_morehdot.cpp | 42 ++++++++++++++----- ...test_graph_capture_arguments_transpose.cpp | 31 +++++++++----- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/tests/ttnn/unit_tests/gtests/test_graph_capture_arguments_morehdot.cpp b/tests/ttnn/unit_tests/gtests/test_graph_capture_arguments_morehdot.cpp index 8a425b88ce0..b862df915a2 100644 --- a/tests/ttnn/unit_tests/gtests/test_graph_capture_arguments_morehdot.cpp +++ b/tests/ttnn/unit_tests/gtests/test_graph_capture_arguments_morehdot.cpp @@ -14,21 +14,41 @@ namespace ttnn::graph::arguments::test { -class TestGraphCaptureArgumentsMorehDot : public TTNNFixtureWithTensor { -protected: - tt::tt_metal::IGraphProcessor::RunMode Mode = tt::tt_metal::IGraphProcessor::RunMode::NORMAL; -}; +class TestGraphCaptureArgumentsMorehDot : public TTNNFixtureWithTensor {}; TEST_P(TestGraphCaptureArgumentsMorehDot, MorehDot) { auto tt_input1 = CreateTensor(); auto tt_input2 = CreateTensor(); - ttnn::graph::GraphProcessor::begin_graph_capture(Mode); - ttnn::operations::moreh::moreh_dot::MorehDot::invoke( - tt_input1, tt_input2, std::nullopt, DataType::BFLOAT16, std::nullopt, std::nullopt); + ttnn::graph::GraphProcessor::begin_graph_capture(tt::tt_metal::IGraphProcessor::RunMode::NORMAL); + ttnn::moreh_dot(tt_input1, tt_input2, std::nullopt, DataType::BFLOAT16, std::nullopt, std::nullopt); auto trace = ttnn::graph::GraphProcessor::end_graph_capture(); auto operations = ttnn::graph::extract_arguments(trace); - auto operation1 = operations[0]; + auto operation0 = operations[0]; + EXPECT_EQ(operation0.operation_name, "ttnn::moreh_dot"); + EXPECT_EQ(operation0.arguments.size(), 6); + EXPECT_EQ( + operation0.arguments[0], + "Tensor(storage=DeviceStorage(memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_" + "type=BufferType::L1,shard_spec=std::nullopt)),tensor_spec=TensorSpec(logical_shape=Shape([1, 1, 1, " + "32]),tensor_layout=TensorLayout(dtype=BFLOAT16,page_config=PageConfig(config=TilePageConfig(tile=Tile(tile_" + "shape={32, 32},face_shape={16, " + "16},num_faces=4))),memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=" + "BufferType::L1,shard_spec=std::nullopt),alignment=Alignment([32, 32]))))"); + EXPECT_EQ( + operation0.arguments[1], + "Tensor(storage=DeviceStorage(memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_" + "type=BufferType::L1,shard_spec=std::nullopt)),tensor_spec=TensorSpec(logical_shape=Shape([1, 1, 1, " + "32]),tensor_layout=TensorLayout(dtype=BFLOAT16,page_config=PageConfig(config=TilePageConfig(tile=Tile(tile_" + "shape={32, 32},face_shape={16, " + "16},num_faces=4))),memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=" + "BufferType::L1,shard_spec=std::nullopt),alignment=Alignment([32, 32]))))"); + EXPECT_EQ(operation0.arguments[2], "[ unsupported type , std::__1::reference_wrapper]"); + EXPECT_EQ(operation0.arguments[3], "BFLOAT16"); + EXPECT_EQ(operation0.arguments[4], "[ unsupported type , std::__1::reference_wrapper]"); + EXPECT_EQ(operation0.arguments[5], "[ unsupported type , std::__1::reference_wrapper]"); + + auto operation1 = operations[1]; EXPECT_EQ(operation1.operation_name, "ttnn::prim::moreh_dot"); EXPECT_EQ(operation1.arguments.size(), 6); EXPECT_EQ( @@ -56,7 +76,7 @@ TEST_P(TestGraphCaptureArgumentsMorehDot, MorehDot) { "std::__1::reference_wrapper> const>]"); - auto operation2 = operations[1]; + auto operation2 = operations[2]; EXPECT_EQ(operation2.operation_name, "MorehDotOperation"); EXPECT_EQ(operation2.arguments.size(), 2); EXPECT_EQ( @@ -69,7 +89,7 @@ TEST_P(TestGraphCaptureArgumentsMorehDot, MorehDot) { "[ unsupported type , " "std::__1::reference_wrapper]"); - auto operation3 = operations[2]; + auto operation3 = operations[3]; EXPECT_EQ(operation3.operation_name, "tt::tt_metal::create_device_tensor"); EXPECT_EQ(operation3.arguments.size(), 5); EXPECT_EQ(operation3.arguments[0], "Shape([1, 1, 1, 1])"); @@ -86,7 +106,7 @@ INSTANTIATE_TEST_SUITE_P( TestGraphCaptureArgumentsMorehDot_MorehDot, TestGraphCaptureArgumentsMorehDot, ::testing::Values(CreateTensorParameters{ - .input_shape = ttnn::Shape(tt::tt_metal::Array4D{1, 1, 1, 32}), + .input_shape = ttnn::Shape({1, 1, 1, 32}), .dtype = DataType::BFLOAT16, .layout = TILE_LAYOUT, .mem_cfg = L1_MEMORY_CONFIG})); diff --git a/tests/ttnn/unit_tests/gtests/test_graph_capture_arguments_transpose.cpp b/tests/ttnn/unit_tests/gtests/test_graph_capture_arguments_transpose.cpp index e94f4845d4f..ed1ab12ab98 100644 --- a/tests/ttnn/unit_tests/gtests/test_graph_capture_arguments_transpose.cpp +++ b/tests/ttnn/unit_tests/gtests/test_graph_capture_arguments_transpose.cpp @@ -14,20 +14,31 @@ namespace ttnn::graph::arguments::test { -class TestGraphCaptureArgumentsTranspose : public TTNNFixtureWithTensor { -protected: - tt::tt_metal::IGraphProcessor::RunMode Mode = tt::tt_metal::IGraphProcessor::RunMode::NORMAL; -}; +class TestGraphCaptureArgumentsTranspose : public TTNNFixtureWithTensor {}; TEST_P(TestGraphCaptureArgumentsTranspose, Transpose) { auto tt_input = CreateTensor(); tt_input.reshape(ttnn::Shape{1, 2048, 4, 128}); - ttnn::graph::GraphProcessor::begin_graph_capture(Mode); - ttnn::operations::data_movement::ExecuteTranspose::invoke(tt_input, 1, 2); + ttnn::graph::GraphProcessor::begin_graph_capture(tt::tt_metal::IGraphProcessor::RunMode::NORMAL); + ttnn::transpose(tt_input, 1, 2); auto trace = ttnn::graph::GraphProcessor::end_graph_capture(); auto operations = ttnn::graph::extract_arguments(trace); - auto operation1 = operations[0]; + auto operation0 = operations[0]; + EXPECT_EQ(operation0.operation_name, "ttnn::transpose"); + EXPECT_EQ(operation0.arguments.size(), 3); + EXPECT_EQ( + operation0.arguments[0], + "Tensor(storage=DeviceStorage(memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_" + "type=BufferType::L1,shard_spec=std::nullopt)),tensor_spec=TensorSpec(logical_shape=Shape([1, 1, 2048, " + "512]),tensor_layout=TensorLayout(dtype=BFLOAT16,page_config=PageConfig(config=RowMajorPageConfig(tile=Tile(" + "tile_shape={32, 32},face_shape={16, " + "16},num_faces=4))),memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=" + "BufferType::L1,shard_spec=std::nullopt),alignment=Alignment([1]))))"); + EXPECT_EQ(operation0.arguments[1], "1"); + EXPECT_EQ(operation0.arguments[2], "2"); + + auto operation1 = operations[1]; EXPECT_EQ(operation1.operation_name, "ttnn::prim::permute"); EXPECT_EQ(operation1.arguments.size(), 5); EXPECT_EQ( @@ -46,7 +57,7 @@ TEST_P(TestGraphCaptureArgumentsTranspose, Transpose) { EXPECT_EQ(operation1.arguments[3], "[ unsupported type , std::__1::reference_wrapper]"); EXPECT_EQ(operation1.arguments[4], "0"); - auto operation2 = operations[1]; + auto operation2 = operations[2]; EXPECT_EQ(operation2.operation_name, "PermuteDeviceOperation"); EXPECT_EQ(operation2.arguments.size(), 2); EXPECT_EQ( @@ -59,7 +70,7 @@ TEST_P(TestGraphCaptureArgumentsTranspose, Transpose) { "[ unsupported type , " "std::__1::reference_wrapper]"); - auto operation3 = operations[2]; + auto operation3 = operations[3]; EXPECT_EQ(operation3.operation_name, "tt::tt_metal::create_device_tensor"); EXPECT_EQ(operation3.arguments.size(), 5); EXPECT_EQ(operation3.arguments[0], "Shape([1, 2048, 1, 512])"); @@ -76,7 +87,7 @@ INSTANTIATE_TEST_SUITE_P( TestGraphCaptureArgumentsTranspose_Transpose, TestGraphCaptureArgumentsTranspose, ::testing::Values(CreateTensorParameters{ - .input_shape = ttnn::Shape(tt::tt_metal::Array4D{1, 1, 2048, 512}), + .input_shape = ttnn::Shape({1, 1, 2048, 512}), .dtype = DataType::BFLOAT16, .layout = ROW_MAJOR_LAYOUT, .mem_cfg = L1_MEMORY_CONFIG}));