Skip to content

Commit

Permalink
solve part of the comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Mar 8, 2024
1 parent 97934ef commit 88a91ff
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 26 deletions.
28 changes: 14 additions & 14 deletions cpp/src/arrow/util/byte_stream_split_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ namespace arrow::util::internal {

#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2)
template <int kNumStreams>
void ByteStreamSplitDecode128B(const uint8_t* data, int64_t num_values, int64_t stride,
uint8_t* out) {
void ByteStreamSplitDecodeSimd128(const uint8_t* data, int64_t num_values, int64_t stride,
uint8_t* out) {
using simd_batch = xsimd::make_sized_batch_t<int8_t, 16>;

static_assert(kNumStreams == 4 || kNumStreams == 8, "Invalid number of streams.");
Expand Down Expand Up @@ -92,8 +92,8 @@ void ByteStreamSplitDecode128B(const uint8_t* data, int64_t num_values, int64_t
}

template <int kNumStreams>
void ByteStreamSplitEncode128B(const uint8_t* raw_values, const int64_t num_values,
uint8_t* output_buffer_raw) {
void ByteStreamSplitEncodeSimd128(const uint8_t* raw_values, const int64_t num_values,
uint8_t* output_buffer_raw) {
using simd_batch = xsimd::make_sized_batch_t<int8_t, 16>;
using simd_arch = typename simd_batch::arch_type;

Expand Down Expand Up @@ -125,10 +125,10 @@ void ByteStreamSplitEncode128B(const uint8_t* raw_values, const int64_t num_valu
// Example run for 32-bit variables:
// Step 0, copy:
// 0: ABCD ABCD ABCD ABCD 1: ABCD ABCD ABCD ABCD ...
// Step 1: simd_batch<int8_t, 8>::xip_lo and simd_batch<int8_t, 8>::xip_hi:
// Step 1: simd_batch<int8_t, 8>::zip_lo and simd_batch<int8_t, 8>::zip_hi:
// 0: AABB CCDD AABB CCDD 1: AABB CCDD AABB CCDD ...
// Step 2: apply simd_batch<int8_t, 8>::zip_lo and simd_batch<int8_t, 8>::zip_hi again:
// 0: AAAA BBBB CCCC DDDD 1: AAAA BBBB CCCC DDDD ...
// Step 3: simd_batch<int8_t, 8>::xip_lo and simd_batch<int8_t, 8>::xip_hi:
// Step 3: simd_batch<int8_t, 8>::zip_lo and simd_batch<int8_t, 8>::zip_hi:
// 0: AAAA AAAA BBBB BBBB 1: CCCC CCCC DDDD DDDD ...
// Step 4: simd_batch<int64_t, 2>::zip_lo and simd_batch<int64_t, 2>::zip_hi:
// 0: AAAA AAAA AAAA AAAA 1: BBBB BBBB BBBB BBBB ...
Expand Down Expand Up @@ -223,7 +223,7 @@ void ByteStreamSplitDecodeAvx2(const uint8_t* data, int64_t num_values, int64_t

const int64_t size = num_values * kNumStreams;
if (size < kBlockSize) // Back to SSE for small size
return ByteStreamSplitDecode128B<kNumStreams>(data, num_values, stride, out);
return ByteStreamSplitDecodeSimd128<kNumStreams>(data, num_values, stride, out);
const int64_t num_blocks = size / kBlockSize;

// First handle suffix.
Expand Down Expand Up @@ -305,13 +305,13 @@ void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const int64_t num_valu
constexpr int kBlockSize = sizeof(__m256i) * kNumStreams;

if constexpr (kNumStreams == 8) // Back to SSE, currently no path for double.
return ByteStreamSplitEncode128B<kNumStreams>(raw_values, num_values,
output_buffer_raw);
return ByteStreamSplitEncodeSimd128<kNumStreams>(raw_values, num_values,
output_buffer_raw);

const int64_t size = num_values * kNumStreams;
if (size < kBlockSize) // Back to SSE for small size
return ByteStreamSplitEncode128B<kNumStreams>(raw_values, num_values,
output_buffer_raw);
return ByteStreamSplitEncodeSimd128<kNumStreams>(raw_values, num_values,
output_buffer_raw);
const int64_t num_blocks = size / kBlockSize;
const __m256i* raw_values_simd = reinterpret_cast<const __m256i*>(raw_values);
__m256i* output_buffer_streams[kNumStreams];
Expand Down Expand Up @@ -378,7 +378,7 @@ void inline ByteStreamSplitDecodeSimd(const uint8_t* data, int64_t num_values,
#if defined(ARROW_HAVE_AVX2)
return ByteStreamSplitDecodeAvx2<kNumStreams>(data, num_values, stride, out);
#elif defined(ARROW_HAVE_SSE4_2) || defined(ARROW_HAVE_NEON)
return ByteStreamSplitDecode128B<kNumStreams>(data, num_values, stride, out);
return ByteStreamSplitDecodeSimd128<kNumStreams>(data, num_values, stride, out);
#else
#error "ByteStreamSplitDecodeSimd not implemented"
#endif
Expand All @@ -391,8 +391,8 @@ void inline ByteStreamSplitEncodeSimd(const uint8_t* raw_values, const int64_t n
return ByteStreamSplitEncodeAvx2<kNumStreams>(raw_values, num_values,
output_buffer_raw);
#elif defined(ARROW_HAVE_SSE4_2) || defined(ARROW_HAVE_NEON)
return ByteStreamSplitEncode128B<kNumStreams>(raw_values, num_values,
output_buffer_raw);
return ByteStreamSplitEncodeSimd128<kNumStreams>(raw_values, num_values,
output_buffer_raw);
#else
#error "ByteStreamSplitEncodeSimd not implemented"
#endif
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/arrow/util/byte_stream_split_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,8 @@ class TestByteStreamSplitSpecialized : public ::testing::Test {
#if defined(ARROW_HAVE_SIMD_SPLIT)
encode_funcs_.push_back({"simd", &ByteStreamSplitEncodeSimd<kWidth>});
decode_funcs_.push_back({"simd", &ByteStreamSplitDecodeSimd<kWidth>});
#endif
#if defined(ARROW_HAVE_SSE4_2)
encode_funcs_.push_back({"sse2", &ByteStreamSplitEncode128B<kWidth>});
decode_funcs_.push_back({"sse2", &ByteStreamSplitDecode128B<kWidth>});
encode_funcs_.push_back({"simd128", &ByteStreamSplitEncodeSimd128<kWidth>});
decode_funcs_.push_back({"simd128", &ByteStreamSplitDecodeSimd128<kWidth>});
#endif
#if defined(ARROW_HAVE_AVX2)
encode_funcs_.push_back({"avx2", &ByteStreamSplitEncodeAvx2<kWidth>});
Expand Down
16 changes: 8 additions & 8 deletions cpp/src/parquet/encoding_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,22 +417,22 @@ BENCHMARK(BM_ByteStreamSplitEncode_Double_Scalar)->Range(MIN_RANGE, MAX_RANGE);
#if defined(ARROW_HAVE_SSE4_2)
static void BM_ByteStreamSplitDecode_Float_Sse2(benchmark::State& state) {
BM_ByteStreamSplitDecode<float>(
state, ::arrow::util::internal::ByteStreamSplitDecode128B<sizeof(float)>);
state, ::arrow::util::internal::ByteStreamSplitDecodeSimd128<sizeof(float)>);
}

static void BM_ByteStreamSplitDecode_Double_Sse2(benchmark::State& state) {
BM_ByteStreamSplitDecode<double>(
state, ::arrow::util::internal::ByteStreamSplitDecode128B<sizeof(double)>);
state, ::arrow::util::internal::ByteStreamSplitDecodeSimd128<sizeof(double)>);
}

static void BM_ByteStreamSplitEncode_Float_Sse2(benchmark::State& state) {
BM_ByteStreamSplitEncode<float>(
state, ::arrow::util::internal::ByteStreamSplitEncode128B<sizeof(float)>);
state, ::arrow::util::internal::ByteStreamSplitEncodeSimd128<sizeof(float)>);
}

static void BM_ByteStreamSplitEncode_Double_Sse2(benchmark::State& state) {
BM_ByteStreamSplitEncode<double>(
state, ::arrow::util::internal::ByteStreamSplitEncode128B<sizeof(double)>);
state, ::arrow::util::internal::ByteStreamSplitEncodeSimd128<sizeof(double)>);
}

BENCHMARK(BM_ByteStreamSplitDecode_Float_Sse2)->Range(MIN_RANGE, MAX_RANGE);
Expand Down Expand Up @@ -471,22 +471,22 @@ BENCHMARK(BM_ByteStreamSplitEncode_Double_Avx2)->Range(MIN_RANGE, MAX_RANGE);
#if defined(ARROW_HAVE_NEON)
static void BM_ByteStreamSplitDecode_Float_Neon(benchmark::State& state) {
BM_ByteStreamSplitDecode<float>(
state, ::arrow::util::internal::ByteStreamSplitDecode128B<sizeof(float)>);
state, ::arrow::util::internal::ByteStreamSplitDecodeSimd128<sizeof(float)>);
}

static void BM_ByteStreamSplitDecode_Double_Neon(benchmark::State& state) {
BM_ByteStreamSplitDecode<double>(
state, ::arrow::util::internal::ByteStreamSplitDecode128B<sizeof(double)>);
state, ::arrow::util::internal::ByteStreamSplitDecodeSimd128<sizeof(double)>);
}

static void BM_ByteStreamSplitEncode_Float_Neon(benchmark::State& state) {
BM_ByteStreamSplitEncode<float>(
state, ::arrow::util::internal::ByteStreamSplitEncode128B<sizeof(float)>);
state, ::arrow::util::internal::ByteStreamSplitEncodeSimd128<sizeof(float)>);
}

static void BM_ByteStreamSplitEncode_Double_Neon(benchmark::State& state) {
BM_ByteStreamSplitEncode<double>(
state, ::arrow::util::internal::ByteStreamSplitEncode128B<sizeof(double)>);
state, ::arrow::util::internal::ByteStreamSplitEncodeSimd128<sizeof(double)>);
}

BENCHMARK(BM_ByteStreamSplitDecode_Float_Neon)->Range(MIN_RANGE, MAX_RANGE);
Expand Down

0 comments on commit 88a91ff

Please sign in to comment.