#13235: topk op fix (#18211)
### Ticket
#13235

### Problem description
A few combinations of the flags and inputs (`k`, `largest`, `sorted`) were not supported.

### What's changed
The original code improperly formatted `k` in the main function and in the
kernels, and it contained a lot of repetitive code doing the same thing.
This change fixes the aforementioned bugs and refactors the duplicated
logic into reusable functions.
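
As a minimal sketch of the now-supported combinations (assuming a locally attached device; the shape and `k` values are illustrative, not taken from the test suite), using the same `ttnn.topk` call pattern as the updated tests:

```python
import torch
import ttnn

# Illustrative values only: each (k, largest, sorted) combination below
# is expected to work after this fix.
device = ttnn.open_device(device_id=0)
input = torch.randn([1, 1, 32, 64], dtype=torch.bfloat16)
ttnn_input = ttnn.from_torch(input, ttnn.bfloat16, layout=ttnn.Layout.TILE, device=device)

for k in (2, 32, 50):
    for largest in (True, False):
        for sorted_flag in (True, False):
            values, indices = ttnn.topk(ttnn_input, k, dim=-1, largest=largest, sorted=sorted_flag)
            assert list(values.shape) == [1, 1, 32, k]

ttnn.close_device(device)
```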

### Checklist
- [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes.
  Tests relevant to the ticket pass, but a few tests related to cache size do not:
  https://github.com/tenstorrent/tt-metal/actions/runs/13653198214

  This will be resolved in the next ticket, which references this bug for topk: #18357

- [x] New/Existing tests provide coverage for changes
aczajkowskiTT authored Mar 4, 2025
1 parent ac5c168 commit e03382a
Showing 16 changed files with 738 additions and 415 deletions.
54 changes: 52 additions & 2 deletions tests/ttnn/unit_tests/operations/test_reduction.py
@@ -3,9 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
-
import torch
-
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
@@ -203,6 +201,58 @@ def test_sum_4d_tensor_dims(device, batch_size, c, h, w, dim, keepdim):
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


+# returns larger padded tensor instead of desired shape
+@pytest.mark.parametrize("dim1", [1])
+@pytest.mark.parametrize("dim2", [1])
+@pytest.mark.parametrize("dim3", [8])
+@pytest.mark.parametrize("dim4", [1])
+@pytest.mark.parametrize("dim5", [128])
+@pytest.mark.parametrize("dim6", [64])
+# @pytest.mark.parametrize("dim", [0, 1, 2, 3, 4, 5]) transpose cannot handle N-D tensor for all dims
+@pytest.mark.parametrize("dim", [4, 5])
+@pytest.mark.parametrize("k", [50, 64])
+@pytest.mark.parametrize("largest", [True])
+@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
+def test_6d_topk(device, dim1, dim2, dim3, dim4, dim5, dim6, dim, k, largest, dtype):
+    torch.manual_seed(2005)
+    shape = [dim1, dim2, dim3, dim4, dim5, dim6]
+    torch_dtype = torch.bfloat16
+
+    input = torch.randn(shape, dtype=torch_dtype)
+    pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=largest, sorted=True)
+
+    ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)
+    ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=dim, largest=largest, sorted=True)
+
+    desired_shape = [dim1, dim2, dim3, dim4, dim5, dim6]
+    desired_shape[dim] = k
+
+    assert list(ttnn_topk_values.shape) == desired_shape
+    assert list(ttnn_topk_indices.shape) == desired_shape
+
+    ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
+    ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices).to(torch.int64)
+
+    if dtype == ttnn.bfloat8_b:
+        pcc_values = 0.99
+    else:
+        pcc_values = 1.0
+
+    # pcc is not a good measure for the raw indices
+    # if index 49 and index 8 are tied, the order of the indices can be different
+    # but the values associated with the indices should be the same
+    # if index 7 and 8 are tied, but swapped, the pcc will be better than if index 49 and 8 are tied but swapped
+    # rounding may also cause more ties than expected
+    # the bigger we get, the tighter the distribution of the top K elements, so the pcc will be worse as stability/rounding will cause more ties
+    # use cosine similarity on the gathered indices as this will show the top elements are all about the same
+    ttnn_torch_gather_from_indices = torch.gather(input, dim, ttnn_torch_indices.to(torch.int64))
+    cosine = torch.nn.CosineSimilarity(dim=dim)
+    ttnn_torch_cosine = torch.mean(cosine(pyt_topk_values, ttnn_torch_gather_from_indices))
+
+    assert ttnn_torch_cosine > 0.99, "Cosine similarity between topk values and gather from indices is less than 0.99"
+    assert_with_pcc(pyt_topk_values, ttnn_torch_values, pcc_values)


@pytest.mark.parametrize("c", [3])
@pytest.mark.parametrize("h", [31])
@pytest.mark.parametrize("w", [32])
74 changes: 45 additions & 29 deletions tests/ttnn/unit_tests/operations/test_topk.py
@@ -1,50 +1,42 @@
-# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
-
import torch
-
import ttnn
-
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_grayskull


-def run_topk_test(N, C, H, W, k, largest, dtype, device):
+def run_topk_test(N, C, H, W, k, dtype, dim, sorted, largest, device):
    torch.manual_seed(2005)
    shape = [N, C, H, W]
    torch_dtype = torch.bfloat16

-    input = torch.randn(shape, dtype=torch_dtype)
-    pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=-1, largest=largest, sorted=True)
-
+    input = torch.randn(shape, dtype=torch_dtype) * 0.9
+    pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=largest, sorted=True)
    ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)
-    ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=-1, largest=largest, sorted=True)
-
-    assert list(ttnn_topk_values.padded_shape) == [N, C, H, k]
-    assert list(ttnn_topk_indices.padded_shape) == [N, C, H, k]
-
+    ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=dim, largest=largest, sorted=sorted)
+    desired_shape = [N, C, H, W]
+    desired_shape[dim] = k
+    assert list(ttnn_topk_values.shape) == desired_shape
+    assert list(ttnn_topk_indices.shape) == desired_shape
    ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
    ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices).to(torch.int64)

    if dtype == ttnn.bfloat8_b:
        pcc_values = 0.99
    else:
        pcc_values = 1.0

    # pcc is not a good measure for the raw indices
    # if index 49 and index 8 are tied, the order of the indices can be different
    # but the values associated with the indices should be the same
    # if index 7 and 8 are tied, but swapped, the pcc will be better than if index 49 and 8 are tied but swapped
    # rounding may also cause more ties than expected
-    # the bigger we get, the tighter the distribution of the top 32 elements, so the pcc will be worse as stability/rounding will cause more ties
+    # the bigger we get, the tighter the distribution of the top K elements, so the pcc will be worse as stability/rounding will cause more ties
    # use cosine similarity on the gathered indices as this will show the top elements are all about the same
-    ttnn_torch_gather_from_indices = torch.gather(input, -1, ttnn_torch_indices.to(torch.int64))
-    cosine = torch.nn.CosineSimilarity(dim=-1)
+    ttnn_torch_gather_from_indices = torch.gather(input, dim, ttnn_torch_indices.to(torch.int64))
+    cosine = torch.nn.CosineSimilarity(dim=dim)
    ttnn_torch_cosine = torch.mean(cosine(pyt_topk_values, ttnn_torch_gather_from_indices))

    assert ttnn_torch_cosine > 0.99, "Cosine similarity between topk values and gather from indices is less than 0.99"
    assert_with_pcc(pyt_topk_values, ttnn_torch_values, pcc_values)
@@ -64,15 +56,39 @@ def run_topk_test(N, C, H, W, k, largest, dtype, device):
    ],
)
@pytest.mark.parametrize(
-    "N, C, H, W, k,",
+    "N, C, H, W, dim, k",
    (
-        (1, 1, 32, 64, 32),
-        (1, 1, 32, 8192, 32),
-        (1, 1, 2048, 64, 32),
-        (1, 1, 32, 32768, 32),
-        (1, 1, 8192, 64, 32),
+        (1, 1, 32, 8192, 3, 50),  # passed
+        (1, 1, 64, 64, 2, 32),  # passed
+        (1, 1, 64, 64, 2, 64),  # passed
+        (1, 1, 32, 8192, 3, 50),  # passed
+        (1, 2048, 1, 64, 1, 32),  # skipped
+        (1, 1, 32, 64, 3, 2),  # passed
+        (1, 1, 32, 64, 3, 4),  # passed
+        (1, 1, 32, 8192, 3, 6),  # passed
+        (1, 2048, 1, 64, 1, 8),  # passed
+        (1, 1, 32, 32768, 3, 16),  # passed
    ),
)
-@pytest.mark.parametrize("largest", (True, False))
-def test_topk(N, C, H, W, k, largest, dtype, device):
-    run_topk_test(N, C, H, W, k, largest, dtype, device)
+@pytest.mark.parametrize(
+    "sorted",
+    [
+        True,
+        False,
+    ],
+)
+@pytest.mark.parametrize(
+    "largest",
+    [
+        True,
+        False,
+    ],
+)
+def test_topk(N, C, H, W, dim, k, dtype, sorted, largest, device):
+    if dim == 0 or dim == 1:
+        # As of now, when we try to get top-k for dim = 0 or 1, we get the following error from transpose_op.cpp's validate():
+        # input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32
+        # this is because transpose.cpp always typecasts bf8 to bf16,
+        # and when dim = 0 or 1, transpose converts it into TransposeOpDim::HC, and this dim doesn't support bf16 or fp32
+        pytest.skip()
+    run_topk_test(N, C, H, W, k, dtype, dim, sorted, largest, device)
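
The comment block in `run_topk_test` motivates validating the values gathered through the returned indices instead of comparing the raw indices. A minimal pure-PyTorch sketch of the tie issue (illustrative only; the tensor and `alt_indices` below are made-up values, not part of this commit):

```python
import torch

# Two equal values tie for the top-2 slots; either index order is valid.
x = torch.tensor([[0.5, 0.9, 0.9, 0.1]])
values, indices = torch.topk(x, k=2, dim=-1)  # indices may be [[1, 2]]

# A different but equally correct implementation could return [[2, 1]];
# comparing raw indices (e.g., with pcc) would then report a mismatch.
alt_indices = torch.tensor([[2, 1]])

# Gathering through either index set recovers the same top values, so
# comparing gathered values (cosine similarity here) is robust to ties.
gathered = torch.gather(x, -1, alt_indices)
cosine = torch.nn.CosineSimilarity(dim=-1)
print(cosine(values, gathered))  # tensor([1.])
```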
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
@@ -552,6 +552,7 @@ set(TTNN_OP_SRCS
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/prod/prod.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/topk.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/halo/halo.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/sliding_window.cpp
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/reduction/reduction_pybind.cpp
@@ -9,10 +9,10 @@
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions_pybind.hpp"
#include "ttnn/operations/reduction/argmax/argmax_pybind.hpp"
-#include "ttnn/operations/reduction/topk/topk_pybind.hpp"
#include "ttnn/operations/reduction/moe/moe_pybind.hpp"
#include "ttnn/operations/reduction/prod/prod_pybind.hpp"
#include "ttnn/operations/reduction/sampling/sampling_pybind.hpp"
+#include "ttnn/operations/reduction/topk/topk_pybind.hpp"

namespace ttnn::operations::reduction {

@@ -29,10 +29,10 @@ void py_module(py::module& module) {

    // Special reductions
    detail::bind_reduction_argmax_operation(module);
-    detail::bind_reduction_topk_operation(module);
    detail::bind_reduction_moe_operation(module);
    detail::bind_reduction_prod_operation(module, ttnn::prod);
    detail::bind_reduction_sampling_operation(module);
+    detail::bind_reduction_topk_operation(module);
}

} // namespace ttnn::operations::reduction