Skip to content

Commit

Permalink
Update binary quantization test
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Feb 24, 2025
1 parent 6808368 commit 98077a6
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions cpp/tests/preprocessing/binary_quantization.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/init.cuh>
#include <raft/stats/stddev.cuh>
#include <raft/util/itertools.hpp>
#include <thrust/reduce.h>

namespace cuvs::preprocessing::quantize::binary {
Expand All @@ -28,12 +29,21 @@ template <typename T>
struct BinaryQuantizationInputs {
int rows;
int cols;
cuvs::preprocessing::quantize::binary::set_bit_threshold threshold;
};

template <typename T>
std::ostream& operator<<(std::ostream& os, const BinaryQuantizationInputs<T>& inputs)
{
return os << "> rows:" << inputs.rows << " cols:" << inputs.cols;
os << "> dataset_size:" << inputs.rows << " dataset_dim:" << inputs.cols;
os << " threshold: ";
switch (inputs.threshold) {
case set_bit_threshold::zero: os << "zero"; break;
case set_bit_threshold::mean: os << "mean"; break;
case set_bit_threshold::sampling_median: os << "sampling_median"; break;
default: os << "unknown"; break;
}
return os;
}

template <typename T, typename QuantI>
Expand All @@ -49,6 +59,7 @@ class BinaryQuantizationTest : public ::testing::TestWithParam<BinaryQuantizatio
protected:
void testBinaryQuantization()
{
if (std::is_same_v<T, half> && params_.threshold == set_bit_threshold::mean) { GTEST_SKIP(); }
// dataset identical on host / device
auto dataset = raft::make_device_matrix_view<const T, int64_t, raft::row_major>(
(const T*)(input_.data()), rows_, cols_);
Expand All @@ -59,6 +70,7 @@ class BinaryQuantizationTest : public ::testing::TestWithParam<BinaryQuantizatio
static_assert(std::is_same_v<QuantI, uint8_t>);

cuvs::preprocessing::quantize::binary::params params;
params.threshold = params_.threshold;

const auto col_quantized = raft::div_rounding_up_safe(cols_, 8);
auto quantized_input_h = raft::make_host_matrix<QuantI, int64_t>(rows_, cols_);
Expand Down Expand Up @@ -107,13 +119,16 @@ class BinaryQuantizationTest : public ::testing::TestWithParam<BinaryQuantizatio
};

template <typename T>
const std::vector<BinaryQuantizationInputs<T>> inputs = {
{5, 5},
{100, 7},
{100, 128},
{100, 1999},
{1000, 1999},
};
const std::vector<BinaryQuantizationInputs<T>> generate_inputs()
{
const auto inputs = raft::util::itertools::product<BinaryQuantizationInputs<T>>(
{5, 100, 1000},
{7, 128, 1999},
{cuvs::preprocessing::quantize::binary::set_bit_threshold::zero,
cuvs::preprocessing::quantize::binary::set_bit_threshold::mean,
cuvs::preprocessing::quantize::binary::set_bit_threshold::sampling_median});
return inputs;
}

typedef BinaryQuantizationTest<float, uint8_t> QuantizationTest_float_uint8t;
TEST_P(QuantizationTest_float_uint8t, BinaryQuantizationTest) { this->testBinaryQuantization(); }
Expand All @@ -126,12 +141,12 @@ TEST_P(QuantizationTest_half_uint8t, BinaryQuantizationTest) { this->testBinaryQ

INSTANTIATE_TEST_CASE_P(BinaryQuantizationTest,
QuantizationTest_float_uint8t,
::testing::ValuesIn(inputs<float>));
::testing::ValuesIn(generate_inputs<float>()));
INSTANTIATE_TEST_CASE_P(BinaryQuantizationTest,
QuantizationTest_double_uint8t,
::testing::ValuesIn(inputs<double>));
::testing::ValuesIn(generate_inputs<double>()));
INSTANTIATE_TEST_CASE_P(BinaryQuantizationTest,
QuantizationTest_half_uint8t,
::testing::ValuesIn(inputs<half>));
::testing::ValuesIn(generate_inputs<half>()));

} // namespace cuvs::preprocessing::quantize::binary

0 comments on commit 98077a6

Please sign in to comment.