From 801301ee22ce802fd000f9f4b919abb47ae1d6c3 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Fri, 16 Aug 2024 14:40:56 -0700 Subject: [PATCH 01/10] GH-43633: [R] Add tests for packages that might be tricky to roundtrip data to Tables + Parquet files (#43634) ### Rationale for this change Add coverage for objects that might have issues roundtripping to Arrow Tables or Parquet files ### What changes are included in this PR? A new test file + a crossbow job that ensures these other packages are installed so the tests run. ### Are these changes tested? The changes are tests ### Are there any user-facing changes? No * GitHub Issue: #43633 Authored-by: Jonathan Keane Signed-off-by: Jonathan Keane --- dev/tasks/r/github.linux.extra.packages.yml | 53 +++++++++ dev/tasks/tasks.yml | 4 + .../testthat/test-extra-package-roundtrip.R | 105 ++++++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 dev/tasks/r/github.linux.extra.packages.yml create mode 100644 r/tests/testthat/test-extra-package-roundtrip.R diff --git a/dev/tasks/r/github.linux.extra.packages.yml b/dev/tasks/r/github.linux.extra.packages.yml new file mode 100644 index 0000000000000..bb486c72a06a9 --- /dev/null +++ b/dev/tasks/r/github.linux.extra.packages.yml @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +{% import 'macros.jinja' as macros with context %} + +{{ macros.github_header() }} + +jobs: + extra-packages: + name: "extra package roundtrip tests" + runs-on: ubuntu-latest + strategy: + fail-fast: false + env: + ARROW_R_DEV: "FALSE" + ARROW_R_FORCE_EXTRA_PACKAGE_TESTS: TRUE + steps: + {{ macros.github_checkout_arrow()|indent }} + + - uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true + - uses: r-lib/actions/setup-pandoc@v2 + - uses: r-lib/actions/setup-r-dependencies@v2 + with: + working-directory: 'arrow/r' + extra-packages: | + any::data.table + any::rcmdcheck + any::readr + any::units + - name: Build arrow package + run: | + R CMD build --no-build-vignettes arrow/r + R CMD INSTALL --install-tests --no-test-load --no-byte-compile arrow_*.tar.gz + - name: run tests + run: | + testthat::test_package("arrow", filter = "extra-package-roundtrip") + shell: Rscript {0} diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 6e1f7609a980f..a9da7eb2889a0 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -1309,6 +1309,10 @@ tasks: ci: github template: r/github.linux.rchk.yml + test-r-extra-packages: + ci: github + template: r/github.linux.extra.packages.yml + test-r-linux-as-cran: ci: github template: r/github.linux.cran.yml diff --git a/r/tests/testthat/test-extra-package-roundtrip.R b/r/tests/testthat/test-extra-package-roundtrip.R new file mode 100644 index 0000000000000..09a87ef19d561 --- /dev/null +++ b/r/tests/testthat/test-extra-package-roundtrip.R @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +skip_on_cran() + +# Any additional package that we test here that is not already in DESCRIPTION should be +# added to dev/tasks/r/github.linux.extra.packages.yml in the r-lib/actions/setup-r-dependencies@v2 +# step so that they are installed + available in that CI job. + +# So that we can force these in CI +load_or_skip <- function(pkg) { + if (identical(tolower(Sys.getenv("ARROW_R_FORCE_EXTRA_PACKAGE_TESTS")), "true")) { + # because of this indirection on the package name we also avoid a CHECK note and + # we don't otherwise need to Suggest this + requireNamespace(pkg, quietly = TRUE) + } else { + skip_if(!requireNamespace(pkg, quietly = TRUE)) + } + attachNamespace(pkg) +} + +library(dplyr) + +test_that("readr read csvs roundtrip", { + load_or_skip("readr") + + tbl <- example_data[, c("dbl", "lgl", "false", "chr")] + + tf <- tempfile() + on.exit(unlink(tf)) + write.csv(tbl, tf, row.names = FALSE) + + # we should still be able to turn this into a table + new_df <- read_csv(tf, show_col_types = FALSE) + expect_equal(new_df, as_tibble(arrow_table(new_df))) + + # we should still be able to turn this into a table + new_df <- read_csv(tf, show_col_types = FALSE, lazy = TRUE) + expect_equal(new_df, as_tibble(arrow_table(new_df))) + + # and can roundtrip to a parquet file + pq_tmp_file <- tempfile() + write_parquet(new_df, pq_tmp_file) + new_df_read <- read_parquet(pq_tmp_file) + + # we should still be able to turn this into a table + expect_equal(new_df, new_df_read) +}) + +test_that("data.table objects roundtrip", { + load_or_skip("data.table") + + # https://github.com/Rdatatable/data.table/blob/83fd2c05ce2d8555ceb8ba417833956b1b574f7e/R/cedta.R#L25-L27 + .datatable.aware=TRUE + + DT <- as.data.table(example_data) + + # Table -> collect which is what writing + reading to parquet uses under the hood to roundtrip + tab <- as_arrow_table(DT) + DT_read <- collect(tab) + + # we should still be able to turn this into a table + # the .internal.selfref attribute is automatically ignored by testthat/waldo + expect_equal(DT, DT_read) + + # and we can set keys + indices + create new columns + setkey(DT, chr) + setindex(DT, dbl) + DT[, dblshift := data.table::shift(dbl, 1)] + + # Table -> collect + tab <- as_arrow_table(DT) + DT_read <- collect(tab) + + # we should still be able to turn this into a table + expect_equal(DT, DT_read) +}) + +test_that("units roundtrip", { + load_or_skip("units") + + tbl <- example_data + units(tbl$dbl) <- "s" + + # Table -> collect which is what writing + reading to parquet uses under the hood to roundtrip + tab <- as_arrow_table(tbl) + tbl_read <- collect(tab) + + # we should still be able to turn this into a table + expect_equal(tbl, tbl_read) +}) From 8836535785ba3dd4ba335818a34e0479929b70e6 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 17 Aug 2024 11:20:16 +0900 Subject: [PATCH 02/10] GH-43702: [C++][FS][Azure] Use the latest Azurite and update the bundled Azure SDK for C++ to azure-identity_1.9.0 (#43723) ### Rationale for this change Some our CI jobs (such as conda based jobs) use recent Azure SDK for C++ and they require latest Azurite. We need to update Azurite for these jobs. I wanted to use the latest Azurite on all environments but I didn't. Because I want to keep using `apt install nodejs` on old Ubuntu for easy to maintain. ### What changes are included in this PR? * Use the latest Azurite if possible * Use `--skipApiVersionCheck` for old Azurite * Update the bundled Azure SDK for C++ * This is not required. It's for detecting this problem in many CI jobs. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * GitHub Issue: fix #41505 * GitHub Issue: #43702 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/scripts/install_azurite.sh | 24 ++++++++++++++++++------ cpp/src/arrow/filesystem/azurefs_test.cc | 5 ++++- cpp/thirdparty/versions.txt | 4 ++-- python/pyarrow/tests/conftest.py | 3 +++ 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/ci/scripts/install_azurite.sh b/ci/scripts/install_azurite.sh index dda5e99405b7f..b8b1618bed314 100755 --- a/ci/scripts/install_azurite.sh +++ b/ci/scripts/install_azurite.sh @@ -19,20 +19,32 @@ set -e -# Pin azurite to 3.29.0 due to https://github.com/apache/arrow/issues/41505 +node_version="$(node --version)" +echo "node version = ${node_version}" + +case "${node_version}" in + v12*) + # Pin azurite to 3.29.0 due to https://github.com/apache/arrow/issues/41505 + azurite_version=v3.29.0 + ;; + *) + azurite_version=latest + ;; +esac + case "$(uname)" in Darwin) - npm install -g azurite@v3.29.0 + npm install -g azurite@${azurite_version} which azurite ;; MINGW*) choco install nodejs.install - npm install -g azurite@v3.29.0 + npm install -g azurite@${azurite_version} ;; Linux) - npm install -g azurite@v3.29.0 + npm install -g azurite@${azurite_version} which azurite ;; esac -echo "node version = $(node --version)" -echo "azurite version = $(azurite --version)" \ No newline at end of file + +echo "azurite version = $(azurite --version)" diff --git a/cpp/src/arrow/filesystem/azurefs_test.cc b/cpp/src/arrow/filesystem/azurefs_test.cc index 36646f417cbe1..5ff241b17ff58 100644 --- a/cpp/src/arrow/filesystem/azurefs_test.cc +++ b/cpp/src/arrow/filesystem/azurefs_test.cc @@ -198,7 +198,10 @@ class AzuriteEnv : public AzureEnvImpl { self->temp_dir_->path().Join("debug.log")); auto server_process = bp::child( boost::this_process::environment(), exe_path, "--silent", "--location", - self->temp_dir_->path().ToString(), "--debug", self->debug_log_path_.ToString()); + self->temp_dir_->path().ToString(), "--debug", self->debug_log_path_.ToString(), + // For old Azurite. We can't install the latest Azurite with + // old Node.js on old Ubuntu. + "--skipApiVersionCheck"); if (!server_process.valid() || !server_process.running()) { server_process.terminate(); server_process.wait(); diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index 16689c17fba22..30fa24a209482 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -54,8 +54,8 @@ ARROW_AWS_LC_BUILD_SHA256_CHECKSUM=ae96a3567161552744fc0cae8b4d68ed88b1ec0f3d3c9 ARROW_AWSSDK_BUILD_VERSION=1.10.55 ARROW_AWSSDK_BUILD_SHA256_CHECKSUM=2d552fb1a84bef4a9b65e34aa7031851ed2aef5319e02cc6e4cb735c48aa30de # Despite the confusing version name this is still the whole Azure SDK for C++ including core, keyvault, storage-common, etc. -ARROW_AZURE_SDK_BUILD_VERSION=azure-core_1.10.3 -ARROW_AZURE_SDK_BUILD_SHA256_CHECKSUM=dd624c2f86adf474d2d0a23066be6e27af9cbd7e3f8d9d8fd7bf981e884b7b48 +ARROW_AZURE_SDK_BUILD_VERSION=azure-identity_1.9.0 +ARROW_AZURE_SDK_BUILD_SHA256_CHECKSUM=97065bfc971ac8df450853ce805f820f52b59457bd7556510186a1569502e4a1 ARROW_BOOST_BUILD_VERSION=1.81.0 ARROW_BOOST_BUILD_SHA256_CHECKSUM=9e0ffae35528c35f90468997bc8d99500bf179cbae355415a89a600c38e13574 ARROW_BROTLI_BUILD_VERSION=v1.0.9 diff --git a/python/pyarrow/tests/conftest.py b/python/pyarrow/tests/conftest.py index 343b602995db6..e1919497b5116 100644 --- a/python/pyarrow/tests/conftest.py +++ b/python/pyarrow/tests/conftest.py @@ -263,6 +263,9 @@ def azure_server(tmpdir_factory): tmpdir = tmpdir_factory.getbasetemp() # We only need blob service emulator, not queue or table. args = ['azurite-blob', "--location", tmpdir, "--blobPort", str(port)] + # For old Azurite. We can't install the latest Azurite with old + # Node.js on old Ubuntu. + args += ["--skipApiVersionCheck"] proc = None try: proc = subprocess.Popen(args, env=env) From 49be60f5c424cca40bbc5a6d1948ad7e800afaab Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 17 Aug 2024 11:50:46 +0900 Subject: [PATCH 03/10] GH-43175: [C++] Skip not Emscripten ready tests in CSV tests (#43724) ### Rationale for this change We can't use thread nor `%z` on Emacripten. Some CSV tests use them. ### What changes are included in this PR? Skip CSV tests that use thread or `%z`. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * GitHub Issue: #43175 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/scripts/cpp_test.sh | 2 +- cpp/src/arrow/csv/column_decoder_test.cc | 11 +++++++++++ cpp/src/arrow/csv/converter_test.cc | 5 +++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ci/scripts/cpp_test.sh b/ci/scripts/cpp_test.sh index 2c640f2c1fb6a..7912bf23e491c 100755 --- a/ci/scripts/cpp_test.sh +++ b/ci/scripts/cpp_test.sh @@ -80,7 +80,7 @@ case "$(uname)" in ;; esac -if [ "${ARROW_EMSCRIPTEN:-OFF}" = "ON" ]; then +if [ "${ARROW_EMSCRIPTEN:-OFF}" = "ON" ]; then n_jobs=1 # avoid spurious fails on emscripten due to loading too many big executables fi diff --git a/cpp/src/arrow/csv/column_decoder_test.cc b/cpp/src/arrow/csv/column_decoder_test.cc index ebac7a3da2fcf..567732647179e 100644 --- a/cpp/src/arrow/csv/column_decoder_test.cc +++ b/cpp/src/arrow/csv/column_decoder_test.cc @@ -175,6 +175,9 @@ class NullColumnDecoderTest : public ColumnDecoderTest { } void TestThreaded() { +#ifndef ARROW_ENABLE_THREADING + GTEST_SKIP() << "Test requires threading support"; +#endif constexpr int NITERS = 10; auto type = int32(); MakeDecoder(type); @@ -257,6 +260,10 @@ class TypedColumnDecoderTest : public ColumnDecoderTest { } void TestThreaded() { +#ifndef ARROW_ENABLE_THREADING + GTEST_SKIP() << "Test requires threading support"; +#endif + constexpr int NITERS = 10; auto type = uint32(); MakeDecoder(type, default_options); @@ -305,6 +312,10 @@ class InferringColumnDecoderTest : public ColumnDecoderTest { } void TestThreaded() { +#ifndef ARROW_ENABLE_THREADING + GTEST_SKIP() << "Test requires threading support"; +#endif + constexpr int NITERS = 10; auto type = float64(); MakeDecoder(default_options); diff --git a/cpp/src/arrow/csv/converter_test.cc b/cpp/src/arrow/csv/converter_test.cc index ea4e171d57e71..657e8d813ca1b 100644 --- a/cpp/src/arrow/csv/converter_test.cc +++ b/cpp/src/arrow/csv/converter_test.cc @@ -625,6 +625,11 @@ TEST(TimestampConversion, UserDefinedParsers) { } TEST(TimestampConversion, UserDefinedParsersWithZone) { +#ifdef __EMSCRIPTEN__ + GTEST_SKIP() << "Test temporarily disabled due to emscripten bug " + "https://github.com/emscripten-core/emscripten/issues/20467"; +#endif + auto options = ConvertOptions::Defaults(); auto type = timestamp(TimeUnit::SECOND, "America/Phoenix"); From fbac12c353cb6ead58a5ee765b37bd1bc46cd672 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Sat, 17 Aug 2024 17:16:39 -0500 Subject: [PATCH 04/10] MINOR: [R] Fix a package namespace warning (#43737) Oops, I should have caught this in #43633 Removes `data.table::` since the namespace is loaded. Also fix some linting errors and free up space on the force tests run. Authored-by: Jonathan Keane Signed-off-by: Jonathan Keane --- .github/workflows/r.yml | 3 +++ r/tests/testthat/test-extra-package-roundtrip.R | 16 ++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index c4899ddcc49e5..bf7eb99e7e990 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -133,6 +133,9 @@ jobs: with: fetch-depth: 0 submodules: recursive + - name: Free up disk space + run: | + ci/scripts/util_free_space.sh - name: Cache Docker Volumes uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: diff --git a/r/tests/testthat/test-extra-package-roundtrip.R b/r/tests/testthat/test-extra-package-roundtrip.R index 09a87ef19d561..092288dffb955 100644 --- a/r/tests/testthat/test-extra-package-roundtrip.R +++ b/r/tests/testthat/test-extra-package-roundtrip.R @@ -24,7 +24,7 @@ skip_on_cran() # So that we can force these in CI load_or_skip <- function(pkg) { if (identical(tolower(Sys.getenv("ARROW_R_FORCE_EXTRA_PACKAGE_TESTS")), "true")) { - # because of this indirection on the package name we also avoid a CHECK note and + # because of this indirection on the package name we also avoid a CHECK note and # we don't otherwise need to Suggest this requireNamespace(pkg, quietly = TRUE) } else { @@ -46,11 +46,11 @@ test_that("readr read csvs roundtrip", { # we should still be able to turn this into a table new_df <- read_csv(tf, show_col_types = FALSE) - expect_equal(new_df, as_tibble(arrow_table(new_df))) + expect_equal(new_df, as_tibble(arrow_table(new_df))) # we should still be able to turn this into a table new_df <- read_csv(tf, show_col_types = FALSE, lazy = TRUE) - expect_equal(new_df, as_tibble(arrow_table(new_df))) + expect_equal(new_df, as_tibble(arrow_table(new_df))) # and can roundtrip to a parquet file pq_tmp_file <- tempfile() @@ -65,11 +65,11 @@ test_that("data.table objects roundtrip", { load_or_skip("data.table") # https://github.com/Rdatatable/data.table/blob/83fd2c05ce2d8555ceb8ba417833956b1b574f7e/R/cedta.R#L25-L27 - .datatable.aware=TRUE + .datatable.aware <- TRUE DT <- as.data.table(example_data) - # Table -> collect which is what writing + reading to parquet uses under the hood to roundtrip + # Table to collect which is what writing + reading to parquet uses under the hood to roundtrip tab <- as_arrow_table(DT) DT_read <- collect(tab) @@ -80,9 +80,9 @@ test_that("data.table objects roundtrip", { # and we can set keys + indices + create new columns setkey(DT, chr) setindex(DT, dbl) - DT[, dblshift := data.table::shift(dbl, 1)] + DT[, dblshift := shift(dbl, 1)] - # Table -> collect + # Table to collect tab <- as_arrow_table(DT) DT_read <- collect(tab) @@ -96,7 +96,7 @@ test_that("units roundtrip", { tbl <- example_data units(tbl$dbl) <- "s" - # Table -> collect which is what writing + reading to parquet uses under the hood to roundtrip + # Table to collect which is what writing + reading to parquet uses under the hood to roundtrip tab <- as_arrow_table(tbl) tbl_read <- collect(tab) From b7e618f088540a45e2ddab39696ce3d543821763 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 18 Aug 2024 10:42:53 +0900 Subject: [PATCH 05/10] GH-43738: [GLib] Add `GArrowAzureFileSytem` (#43739) ### Rationale for this change The bindings for `arrow::fs::AzureFileSytem` is missing. ### What changes are included in this PR? Add the bindings for `arrow::fs::AzureFileSytem`. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. * GitHub Issue: #43738 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- c_glib/arrow-glib/file-system.cpp | 16 ++++++++++++++++ c_glib/arrow-glib/file-system.h | 12 ++++++++++++ 2 files changed, 28 insertions(+) diff --git a/c_glib/arrow-glib/file-system.cpp b/c_glib/arrow-glib/file-system.cpp index b6efa2b872635..9ba494e405957 100644 --- a/c_glib/arrow-glib/file-system.cpp +++ b/c_glib/arrow-glib/file-system.cpp @@ -56,6 +56,8 @@ G_BEGIN_DECLS * #GArrowS3FileSystem is a class for S3-backed file system. * * #GArrowGCSFileSystem is a class for GCS-backed file system. + * + * #GArrowAzureFileSystem is a class for Azure-backed file system. */ /* arrow::fs::FileInfo */ @@ -1561,6 +1563,18 @@ garrow_gcs_file_system_class_init(GArrowGCSFileSystemClass *klass) { } +G_DEFINE_TYPE(GArrowAzureFileSystem, garrow_azure_file_system, GARROW_TYPE_FILE_SYSTEM) + +static void +garrow_azure_file_system_init(GArrowAzureFileSystem *file_system) +{ +} + +static void +garrow_azure_file_system_class_init(GArrowAzureFileSystemClass *klass) +{ +} + G_END_DECLS GArrowFileInfo * @@ -1592,6 +1606,8 @@ garrow_file_system_new_raw(std::shared_ptr *arrow_file_sy file_system_type = GARROW_TYPE_S3_FILE_SYSTEM; } else if (type_name == "gcs") { file_system_type = GARROW_TYPE_GCS_FILE_SYSTEM; + } else if (type_name == "abfs") { + file_system_type = GARROW_TYPE_AZURE_FILE_SYSTEM; } else if (type_name == "mock") { file_system_type = GARROW_TYPE_MOCK_FILE_SYSTEM; } diff --git a/c_glib/arrow-glib/file-system.h b/c_glib/arrow-glib/file-system.h index 2e500672e145c..9a903c6af68cf 100644 --- a/c_glib/arrow-glib/file-system.h +++ b/c_glib/arrow-glib/file-system.h @@ -337,4 +337,16 @@ struct _GArrowGCSFileSystemClass GArrowFileSystemClass parent_class; }; +#define GARROW_TYPE_AZURE_FILE_SYSTEM (garrow_azure_file_system_get_type()) +GARROW_AVAILABLE_IN_18_0 +G_DECLARE_DERIVABLE_TYPE(GArrowAzureFileSystem, + garrow_azure_file_system, + GARROW, + AZURE_FILE_SYSTEM, + GArrowFileSystem) +struct _GArrowAzureFileSystemClass +{ + GArrowFileSystemClass parent_class; +}; + G_END_DECLS From 5ef7e01053c526389acefddd6f961bf1fd9d274b Mon Sep 17 00:00:00 2001 From: Jin Chengcheng Date: Sun, 18 Aug 2024 15:28:52 +0800 Subject: [PATCH 06/10] GH-43506: [Java] Fix TestFragmentScanOptions result not match (#43639) ### Rationale for this change JNI test was not tested in CI. So the test failed but passed the CI. The parseChar function should return char but return bool, a typo error. ### What changes are included in this PR? ### Are these changes tested? Yes ### Are there any user-facing changes? No * GitHub Issue: #43506 Authored-by: Chengcheng Jin Signed-off-by: David Li --- java/dataset/src/main/cpp/jni_wrapper.cc | 2 +- .../dataset/TestFragmentScanOptions.java | 80 ++++++++++++------- 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc index 63b8dd73f4720..49cc85251c8e9 100644 --- a/java/dataset/src/main/cpp/jni_wrapper.cc +++ b/java/dataset/src/main/cpp/jni_wrapper.cc @@ -368,7 +368,7 @@ std::shared_ptr LoadArrowBufferFromByteBuffer(JNIEnv* env, jobjec inline bool ParseBool(const std::string& value) { return value == "true" ? true : false; } -inline bool ParseChar(const std::string& key, const std::string& value) { +inline char ParseChar(const std::string& key, const std::string& value) { if (value.size() != 1) { JniThrow("Option " + key + " should be a char, but is " + value); } diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java index d598190528811..ed6344f0f9cb7 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java @@ -51,6 +51,16 @@ public class TestFragmentScanOptions { + private CsvFragmentScanOptions create( + ArrowSchema cSchema, + Map convertOptionsMap, + Map readOptions, + Map parseOptions) { + CsvConvertOptions convertOptions = new CsvConvertOptions(convertOptionsMap); + convertOptions.setArrowSchema(cSchema); + return new CsvFragmentScanOptions(convertOptions, readOptions, parseOptions); + } + @Test public void testCsvConvertOptions() throws Exception { final Schema schema = @@ -63,24 +73,29 @@ public void testCsvConvertOptions() throws Exception { String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv"; BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); try (ArrowSchema cSchema = ArrowSchema.allocateNew(allocator); + ArrowSchema cSchema2 = ArrowSchema.allocateNew(allocator); CDataDictionaryProvider provider = new CDataDictionaryProvider()) { Data.exportSchema(allocator, schema, provider, cSchema); - CsvConvertOptions convertOptions = new CsvConvertOptions(ImmutableMap.of("delimiter", ";")); - convertOptions.setArrowSchema(cSchema); - CsvFragmentScanOptions fragmentScanOptions = - new CsvFragmentScanOptions(convertOptions, ImmutableMap.of(), ImmutableMap.of()); + Data.exportSchema(allocator, schema, provider, cSchema2); + CsvFragmentScanOptions fragmentScanOptions1 = + create(cSchema, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of("delimiter", ";")); + CsvFragmentScanOptions fragmentScanOptions2 = + create(cSchema2, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of("delimiter", ";")); ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .fragmentScanOptions(fragmentScanOptions) + .fragmentScanOptions(fragmentScanOptions1) .build(); try (DatasetFactory datasetFactory = new FileSystemDatasetFactory( - allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path); + allocator, + NativeMemoryPool.getDefault(), + FileFormat.CSV, + path, + Optional.of(fragmentScanOptions2)); Dataset dataset = datasetFactory.finish(); Scanner scanner = dataset.newScan(options); ArrowReader reader = scanner.scanBatches()) { - assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowCount = 0; while (reader.loadNextBatch()) { @@ -106,30 +121,38 @@ public void testCsvConvertOptionsDelimiterNotSet() throws Exception { String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv"; BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); try (ArrowSchema cSchema = ArrowSchema.allocateNew(allocator); + ArrowSchema cSchema2 = ArrowSchema.allocateNew(allocator); CDataDictionaryProvider provider = new CDataDictionaryProvider()) { Data.exportSchema(allocator, schema, provider, cSchema); - CsvConvertOptions convertOptions = new CsvConvertOptions(ImmutableMap.of()); - convertOptions.setArrowSchema(cSchema); - CsvFragmentScanOptions fragmentScanOptions = - new CsvFragmentScanOptions(convertOptions, ImmutableMap.of(), ImmutableMap.of()); + Data.exportSchema(allocator, schema, provider, cSchema2); + CsvFragmentScanOptions fragmentScanOptions1 = + create(cSchema, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of()); + CsvFragmentScanOptions fragmentScanOptions2 = + create(cSchema2, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of()); ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .fragmentScanOptions(fragmentScanOptions) + .fragmentScanOptions(fragmentScanOptions1) .build(); try (DatasetFactory datasetFactory = new FileSystemDatasetFactory( - allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path); + allocator, + NativeMemoryPool.getDefault(), + FileFormat.CSV, + path, + Optional.of(fragmentScanOptions2)); Dataset dataset = datasetFactory.finish(); Scanner scanner = dataset.newScan(options); ArrowReader reader = scanner.scanBatches()) { - - assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowCount = 0; while (reader.loadNextBatch()) { - final ValueIterableVector idVector = - (ValueIterableVector) reader.getVectorSchemaRoot().getVector("Id"); - assertThat(idVector.getValueIterable(), IsIterableContainingInOrder.contains(1, 2, 3)); + final ValueIterableVector idVector = + (ValueIterableVector) + reader.getVectorSchemaRoot().getVector("Id;Name;Language"); + assertThat( + idVector.getValueIterable(), + IsIterableContainingInOrder.contains( + new Text("1;Juno;Java"), new Text("2;Peter;Python"), new Text("3;Celin;C++"))); rowCount += reader.getVectorSchemaRoot().getRowCount(); } assertEquals(3, rowCount); @@ -157,13 +180,12 @@ public void testCsvConvertOptionsNoOption() throws Exception { assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowCount = 0; while (reader.loadNextBatch()) { - final ValueIterableVector idVector = - (ValueIterableVector) - reader.getVectorSchemaRoot().getVector("Id;Name;Language"); + final ValueIterableVector idVector = + (ValueIterableVector) reader.getVectorSchemaRoot().getVector("Id;Name;Language"); assertThat( idVector.getValueIterable(), IsIterableContainingInOrder.contains( - "1;Juno;Java\n" + "2;Peter;Python\n" + "3;Celin;C++")); + new Text("1;Juno;Java"), new Text("2;Peter;Python"), new Text("3;Celin;C++"))); rowCount += reader.getVectorSchemaRoot().getRowCount(); } assertEquals(3, rowCount); @@ -174,7 +196,10 @@ public void testCsvConvertOptionsNoOption() throws Exception { public void testCsvReadParseAndReadOptions() throws Exception { final Schema schema = new Schema( - Collections.singletonList(Field.nullable("Id;Name;Language", new ArrowType.Utf8())), + Arrays.asList( + Field.nullable("Id", new ArrowType.Int(64, true)), + Field.nullable("Name", new ArrowType.Utf8()), + Field.nullable("Language", new ArrowType.Utf8())), null); String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv"; BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); @@ -202,12 +227,9 @@ public void testCsvReadParseAndReadOptions() throws Exception { assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowCount = 0; while (reader.loadNextBatch()) { - final ValueIterableVector idVector = - (ValueIterableVector) reader.getVectorSchemaRoot().getVector("Id;Name;Language"); - assertThat( - idVector.getValueIterable(), - IsIterableContainingInOrder.contains( - new Text("2;Peter;Python"), new Text("3;Celin;C++"))); + final ValueIterableVector idVector = + (ValueIterableVector) reader.getVectorSchemaRoot().getVector("Id"); + assertThat(idVector.getValueIterable(), IsIterableContainingInOrder.contains(2L, 3L)); rowCount += reader.getVectorSchemaRoot().getRowCount(); } assertEquals(2, rowCount); From 1ae38d0d42c1ae5800e42b613f22593673b7370c Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Sun, 18 Aug 2024 08:48:55 -0500 Subject: [PATCH 07/10] GH-43735: [R] AWS SDK fails to build on one of CRAN's M1 builders (#43736) Trying to replicate the issue's on CRAN's M1 machine so that we can fix them. * GitHub Issue: #43735 Lead-authored-by: Jonathan Keane Co-authored-by: Sutou Kouhei Signed-off-by: Jonathan Keane --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 12 +++ dev/tasks/r/github.macos.cran.yml | 82 +++++++++++++++++++++ dev/tasks/tasks.yml | 4 + 3 files changed, 98 insertions(+) create mode 100644 dev/tasks/r/github.macos.cran.yml diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index bc3a3a2249d13..63e2c036c9a6f 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -4965,8 +4965,20 @@ macro(build_awssdk) set(AWSSDK_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/awssdk_ep-install") set(AWSSDK_INCLUDE_DIR "${AWSSDK_PREFIX}/include") + # The AWS SDK has a few warnings around shortening lengths + set(AWS_C_FLAGS "${EP_C_FLAGS}") + set(AWS_CXX_FLAGS "${EP_CXX_FLAGS}") + if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL + "Clang") + # Negate warnings that AWS SDK cannot build under + string(APPEND AWS_C_FLAGS " -Wno-error=shorten-64-to-32") + string(APPEND AWS_CXX_FLAGS " -Wno-error=shorten-64-to-32") + endif() + set(AWSSDK_COMMON_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} + -DCMAKE_C_FLAGS=${AWS_C_FLAGS} + -DCMAKE_CXX_FLAGS=${AWS_CXX_FLAGS} -DCPP_STANDARD=${CMAKE_CXX_STANDARD} -DCMAKE_INSTALL_PREFIX=${AWSSDK_PREFIX} -DCMAKE_PREFIX_PATH=${AWSSDK_PREFIX} diff --git a/dev/tasks/r/github.macos.cran.yml b/dev/tasks/r/github.macos.cran.yml new file mode 100644 index 0000000000000..33965988e213a --- /dev/null +++ b/dev/tasks/r/github.macos.cran.yml @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +{% import 'macros.jinja' as macros with context %} + +{{ macros.github_header() }} + +jobs: + macos-cran: + name: "macOS similar to CRAN" + runs-on: macOS-latest + strategy: + fail-fast: false + + steps: + {{ macros.github_checkout_arrow()|indent }} + + - name: Configure dependencies (macos) + run: | + brew install openssl + # disable sccache on macos as it times out for unknown reasons + # see GH-33721 + # brew install sccache + # remove cmake so that we can test our cmake downloading abilities + brew uninstall cmake + - uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true + # CRAN builders have the entire bin here added to the path. This sometimes + # includes things like GNU libtool which name-collide with what we expect + - name: Add R.framework/Resources/bin to the path + run: echo "/Library/Frameworks/R.framework/Resources/bin" >> $GITHUB_PATH + - name : Check whether libtool in R is used + run: | + if [ "$(which libtool)" != "/Library/Frameworks/R.framework/Resources/bin/libtool" ]; then + echo "libtool provided by R isn't found: $(which libtool)" + exit 1 + fi + - name: Install dependencies + uses: r-lib/actions/setup-r-dependencies@v2 + with: + cache: false # cache does not work on across branches + working-directory: arrow/r + extra-packages: | + any::rcmdcheck + any::sys + - name: Install + env: + _R_CHECK_CRAN_INCOMING_: false + CXX: "clang++ -mmacos-version-min=14.6" + CFLAGS: "-falign-functions=8 -g -O2 -Wall -pedantic -Wconversion -Wno-sign-conversion -Wstrict-prototypes" + CXXFLAGS: "-g -O2 -Wall -pedantic -Wconversion -Wno-sign-conversion" + NOT_CRAN: false + run: | + sccache --start-server || echo 'sccache not found' + cd arrow/r + R CMD INSTALL . --install-tests + - name: Run the tests + run: R -e 'if(tools::testInstalledPackage("arrow") != 0L) stop("There was a test failure.")' + - name: Dump test logs + run: cat arrow-tests/testthat.Rout* + if: failure() + - name: Save the test output + uses: actions/upload-artifact@v2 + with: + name: test-output + path: arrow-tests/testthat.Rout* + if: always() diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index a9da7eb2889a0..fe02fe9ce68b2 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -1319,6 +1319,10 @@ tasks: params: MATRIX: {{ "${{ matrix.r_image }}" }} + test-r-macos-as-cran: + ci: github + template: r/github.macos.cran.yml + test-r-arrow-backwards-compatibility: ci: github template: r/github.linux.arrow.version.back.compat.yml From 5e68513d62b0d216e916de6a1ad2db04f5d1a7bf Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Mon, 19 Aug 2024 18:39:05 +0800 Subject: [PATCH 08/10] GH-43495: [C++][Compute] Widen the row offset of the row table to 64-bit (#43389) ### Rationale for this change The row table uses `uint32_t` as the row offset within the row data buffer, effectively limiting the row data from growing beyond 4GB. This is quite restrictive, and the impact is described in more detail in #43495. This PR proposes to widen the row offset from 32-bit to 64-bit to address this limitation. #### Benefits Currently, the row table has three major limitations: 1. The overall data size cannot exceed 4GB. 2. The size of a single row cannot exceed 4GB. 3. The number of rows cannot exceed 2^32. This enhancement will eliminate the first limitation. Meanwhile, the second and third limitations are less likely to occur. Thus, this change will enable a significant range of use cases that are currently unsupported. #### Overhead Of course, this will introduce some overhead: 1. An extra 4 bytes of memory consumption for each row due to the offset size difference from 32-bit to 64-bit. 2. A wider offset type requires a few more SIMD instructions in each 8-row processing iteration. In my opinion, this overhead is justified by the benefits listed above. ### What changes are included in this PR? Change the row offset of the row table from 32-bit to 64-bit. Relative code in row comparison/encoding and swiss join has been updated accordingly. ### Are these changes tested? Test included. ### Are there any user-facing changes? Users could potentially see higher memory consumption when using acero's hash join and hash aggregation. However, on the other hand, certain use cases used to fail are now able to complete. * GitHub Issue: #43495 Authored-by: Ruoxi Sun Signed-off-by: Antoine Pitrou --- cpp/src/arrow/acero/hash_join_node_test.cc | 192 ++++++++++ cpp/src/arrow/acero/swiss_join.cc | 26 +- cpp/src/arrow/acero/swiss_join_avx2.cc | 126 +++++-- cpp/src/arrow/compute/row/compare_internal.cc | 39 +- cpp/src/arrow/compute/row/compare_internal.h | 27 +- .../compute/row/compare_internal_avx2.cc | 172 ++++----- cpp/src/arrow/compute/row/compare_test.cc | 333 +++++++++++++----- cpp/src/arrow/compute/row/encode_internal.cc | 47 ++- cpp/src/arrow/compute/row/encode_internal.h | 7 +- .../arrow/compute/row/encode_internal_avx2.cc | 10 +- cpp/src/arrow/compute/row/row_internal.cc | 38 +- cpp/src/arrow/compute/row/row_internal.h | 37 +- cpp/src/arrow/compute/row/row_test.cc | 66 ++-- cpp/src/arrow/testing/random.cc | 19 +- cpp/src/arrow/testing/random.h | 6 + 15 files changed, 802 insertions(+), 343 deletions(-) diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index f7b442cc3c624..88f9a9e71b768 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -30,6 +30,7 @@ #include "arrow/compute/kernels/test_util.h" #include "arrow/compute/light_array_internal.h" #include "arrow/testing/extension_type.h" +#include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" @@ -40,6 +41,10 @@ using testing::UnorderedElementsAreArray; namespace arrow { +using arrow::gen::Constant; +using arrow::random::kSeedMax; +using arrow::random::RandomArrayGenerator; +using compute::and_; using compute::call; using compute::default_exec_context; using compute::ExecBatchBuilder; @@ -3253,5 +3258,192 @@ TEST(HashJoin, ManyJoins) { ASSERT_OK_AND_ASSIGN(std::ignore, DeclarationToTable(std::move(root))); } +namespace { + +void AssertRowCountEq(Declaration source, int64_t expected) { + Declaration count{"aggregate", + {std::move(source)}, + AggregateNodeOptions{/*aggregates=*/{{"count_all", "count(*)"}}}}; + ASSERT_OK_AND_ASSIGN(auto batches, DeclarationToExecBatches(std::move(count))); + ASSERT_EQ(batches.batches.size(), 1); + ASSERT_EQ(batches.batches[0].values.size(), 1); + ASSERT_TRUE(batches.batches[0].values[0].is_scalar()); + ASSERT_EQ(batches.batches[0].values[0].scalar()->type->id(), Type::INT64); + ASSERT_TRUE(batches.batches[0].values[0].scalar_as().is_valid); + ASSERT_EQ(batches.batches[0].values[0].scalar_as().value, expected); +} + +} // namespace + +// GH-43495: Test that both the key and the payload of the right side (the build side) are +// fixed length and larger than 4GB, and the 64-bit offset in the hash table can handle it +// correctly. +TEST(HashJoin, LARGE_MEMORY_TEST(BuildSideOver4GBFixedLength)) { + constexpr int64_t k5GB = 5ll * 1024 * 1024 * 1024; + constexpr int fixed_length = 128; + const auto type = fixed_size_binary(fixed_length); + constexpr uint8_t byte_no_match_min = static_cast('A'); + constexpr uint8_t byte_no_match_max = static_cast('y'); + constexpr uint8_t byte_match = static_cast('z'); + const auto value_match = + std::make_shared(std::string(fixed_length, byte_match)); + constexpr int16_t num_rows_per_batch_left = 128; + constexpr int16_t num_rows_per_batch_right = 4096; + const int64_t num_batches_left = 8; + const int64_t num_batches_right = + k5GB / (num_rows_per_batch_right * type->byte_width()); + + // Left side composed of num_batches_left identical batches of num_rows_per_batch_left + // rows of value_match-es. + BatchesWithSchema batches_left; + { + // A column with num_rows_per_batch_left value_match-es. + ASSERT_OK_AND_ASSIGN(auto column, + Constant(value_match)->Generate(num_rows_per_batch_left)); + + // Use the column as both the key and the payload. + ExecBatch batch({column, column}, num_rows_per_batch_left); + batches_left = + BatchesWithSchema{std::vector(num_batches_left, std::move(batch)), + schema({field("l_key", type), field("l_payload", type)})}; + } + + // Right side composed of num_batches_right identical batches of + // num_rows_per_batch_right rows containing only 1 value_match. + BatchesWithSchema batches_right; + { + // A column with (num_rows_per_batch_right - 1) non-value_match-es (possibly null) and + // 1 value_match. + auto non_matches = RandomArrayGenerator(kSeedMax).FixedSizeBinary( + num_rows_per_batch_right - 1, fixed_length, + /*null_probability =*/0.01, /*min_byte=*/byte_no_match_min, + /*max_byte=*/byte_no_match_max); + ASSERT_OK_AND_ASSIGN(auto match, Constant(value_match)->Generate(1)); + ASSERT_OK_AND_ASSIGN(auto column, Concatenate({non_matches, match})); + + // Use the column as both the key and the payload. + ExecBatch batch({column, column}, num_rows_per_batch_right); + batches_right = + BatchesWithSchema{std::vector(num_batches_right, std::move(batch)), + schema({field("r_key", type), field("r_payload", type)})}; + } + + Declaration left{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_left.schema), + std::move(batches_left.batches))}; + + Declaration right{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_right.schema), + std::move(batches_right.batches))}; + + HashJoinNodeOptions join_opts(JoinType::INNER, /*left_keys=*/{"l_key"}, + /*right_keys=*/{"r_key"}); + Declaration join{"hashjoin", {std::move(left), std::move(right)}, join_opts}; + + ASSERT_OK_AND_ASSIGN(auto batches_result, DeclarationToExecBatches(std::move(join))); + Declaration result{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_result.schema), + std::move(batches_result.batches))}; + + // The row count of hash join should be (number of value_match-es in left side) * + // (number of value_match-es in right side). + AssertRowCountEq(result, + num_batches_left * num_rows_per_batch_left * num_batches_right); + + // All rows should be value_match-es. + auto predicate = and_({equal(field_ref("l_key"), literal(value_match)), + equal(field_ref("l_payload"), literal(value_match)), + equal(field_ref("r_key"), literal(value_match)), + equal(field_ref("r_payload"), literal(value_match))}); + Declaration filter{"filter", {result}, FilterNodeOptions{std::move(predicate)}}; + AssertRowCountEq(std::move(filter), + num_batches_left * num_rows_per_batch_left * num_batches_right); +} + +// GH-43495: Test that both the key and the payload of the right side (the build side) are +// var length and larger than 4GB, and the 64-bit offset in the hash table can handle it +// correctly. +TEST(HashJoin, LARGE_MEMORY_TEST(BuildSideOver4GBVarLength)) { + constexpr int64_t k5GB = 5ll * 1024 * 1024 * 1024; + const auto type = utf8(); + constexpr int value_no_match_length_min = 128; + constexpr int value_no_match_length_max = 129; + constexpr int value_match_length = 130; + const auto value_match = + std::make_shared(std::string(value_match_length, 'X')); + constexpr int16_t num_rows_per_batch_left = 128; + constexpr int16_t num_rows_per_batch_right = 4096; + const int64_t num_batches_left = 8; + const int64_t num_batches_right = + k5GB / (num_rows_per_batch_right * value_no_match_length_min); + + // Left side composed of num_batches_left identical batches of num_rows_per_batch_left + // rows of value_match-es. + BatchesWithSchema batches_left; + { + // A column with num_rows_per_batch_left value_match-es. + ASSERT_OK_AND_ASSIGN(auto column, + Constant(value_match)->Generate(num_rows_per_batch_left)); + + // Use the column as both the key and the payload. + ExecBatch batch({column, column}, num_rows_per_batch_left); + batches_left = + BatchesWithSchema{std::vector(num_batches_left, std::move(batch)), + schema({field("l_key", type), field("l_payload", type)})}; + } + + // Right side composed of num_batches_right identical batches of + // num_rows_per_batch_right rows containing only 1 value_match. + BatchesWithSchema batches_right; + { + // A column with (num_rows_per_batch_right - 1) non-value_match-es (possibly null) and + // 1 value_match. + auto non_matches = + RandomArrayGenerator(kSeedMax).String(num_rows_per_batch_right - 1, + /*min_length=*/value_no_match_length_min, + /*max_length=*/value_no_match_length_max, + /*null_probability =*/0.01); + ASSERT_OK_AND_ASSIGN(auto match, Constant(value_match)->Generate(1)); + ASSERT_OK_AND_ASSIGN(auto column, Concatenate({non_matches, match})); + + // Use the column as both the key and the payload. + ExecBatch batch({column, column}, num_rows_per_batch_right); + batches_right = + BatchesWithSchema{std::vector(num_batches_right, std::move(batch)), + schema({field("r_key", type), field("r_payload", type)})}; + } + + Declaration left{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_left.schema), + std::move(batches_left.batches))}; + + Declaration right{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_right.schema), + std::move(batches_right.batches))}; + + HashJoinNodeOptions join_opts(JoinType::INNER, /*left_keys=*/{"l_key"}, + /*right_keys=*/{"r_key"}); + Declaration join{"hashjoin", {std::move(left), std::move(right)}, join_opts}; + + ASSERT_OK_AND_ASSIGN(auto batches_result, DeclarationToExecBatches(std::move(join))); + Declaration result{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_result.schema), + std::move(batches_result.batches))}; + + // The row count of hash join should be (number of value_match-es in left side) * + // (number of value_match-es in right side). + AssertRowCountEq(result, + num_batches_left * num_rows_per_batch_left * num_batches_right); + + // All rows should be value_match-es. + auto predicate = and_({equal(field_ref("l_key"), literal(value_match)), + equal(field_ref("l_payload"), literal(value_match)), + equal(field_ref("r_key"), literal(value_match)), + equal(field_ref("r_payload"), literal(value_match))}); + Declaration filter{"filter", {result}, FilterNodeOptions{std::move(predicate)}}; + AssertRowCountEq(std::move(filter), + num_batches_left * num_rows_per_batch_left * num_batches_right); +} + } // namespace acero } // namespace arrow diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index 732deb72861d6..40a4b5886e4bb 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -122,7 +122,7 @@ void RowArrayAccessor::Visit(const RowTableImpl& rows, int column_id, int num_ro if (!is_fixed_length_column) { int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); const uint8_t* row_ptr_base = rows.data(2); - const uint32_t* row_offsets = rows.offsets(); + const RowTableImpl::offset_type* row_offsets = rows.offsets(); uint32_t field_offset_within_row, field_length; if (varbinary_column_id == 0) { @@ -173,7 +173,7 @@ void RowArrayAccessor::Visit(const RowTableImpl& rows, int column_id, int num_ro // Case 4: This is a fixed length column in a varying length row // const uint8_t* row_ptr_base = rows.data(2) + field_offset_within_row; - const uint32_t* row_offsets = rows.offsets(); + const RowTableImpl::offset_type* row_offsets = rows.offsets(); for (int i = 0; i < num_rows; ++i) { uint32_t row_id = row_ids[i]; const uint8_t* row_ptr = row_ptr_base + row_offsets[row_id]; @@ -473,17 +473,10 @@ Status RowArrayMerge::PrepareForMerge(RowArray* target, (*first_target_row_id)[sources.size()] = num_rows; } - if (num_bytes > std::numeric_limits::max()) { - return Status::Invalid( - "There are more than 2^32 bytes of key data. Acero cannot " - "process a join of this magnitude"); - } - // Allocate target memory // target->rows_.Clean(); - RETURN_NOT_OK(target->rows_.AppendEmpty(static_cast(num_rows), - static_cast(num_bytes))); + RETURN_NOT_OK(target->rows_.AppendEmpty(static_cast(num_rows), num_bytes)); // In case of varying length rows, // initialize the first row offset for each range of rows corresponding to a @@ -565,15 +558,15 @@ void RowArrayMerge::CopyVaryingLength(RowTableImpl* target, const RowTableImpl& int64_t first_target_row_offset, const int64_t* source_rows_permutation) { int64_t num_source_rows = source.length(); - uint32_t* target_offsets = target->mutable_offsets(); - const uint32_t* source_offsets = source.offsets(); + RowTableImpl::offset_type* target_offsets = target->mutable_offsets(); + const RowTableImpl::offset_type* source_offsets = source.offsets(); // Permutation of source rows is optional. // if (!source_rows_permutation) { int64_t target_row_offset = first_target_row_offset; for (int64_t i = 0; i < num_source_rows; ++i) { - target_offsets[first_target_row_id + i] = static_cast(target_row_offset); + target_offsets[first_target_row_id + i] = target_row_offset; target_row_offset += source_offsets[i + 1] - source_offsets[i]; } // We purposefully skip outputting of N+1 offset, to allow concurrent @@ -593,7 +586,10 @@ void RowArrayMerge::CopyVaryingLength(RowTableImpl* target, const RowTableImpl& int64_t source_row_id = source_rows_permutation[i]; const uint64_t* source_row_ptr = reinterpret_cast( source.data(2) + source_offsets[source_row_id]); - uint32_t length = source_offsets[source_row_id + 1] - source_offsets[source_row_id]; + int64_t length = source_offsets[source_row_id + 1] - source_offsets[source_row_id]; + // Though the row offset is 64-bit, the length of a single row must be 32-bit as + // required by current row table implementation. + DCHECK_LE(length, std::numeric_limits::max()); // Rows should be 64-bit aligned. // In that case we can copy them using a sequence of 64-bit read/writes. @@ -604,7 +600,7 @@ void RowArrayMerge::CopyVaryingLength(RowTableImpl* target, const RowTableImpl& *target_row_ptr++ = *source_row_ptr++; } - target_offsets[first_target_row_id + i] = static_cast(target_row_offset); + target_offsets[first_target_row_id + i] = target_row_offset; target_row_offset += length; } } diff --git a/cpp/src/arrow/acero/swiss_join_avx2.cc b/cpp/src/arrow/acero/swiss_join_avx2.cc index 0888dd8938455..e42b0b40445bf 100644 --- a/cpp/src/arrow/acero/swiss_join_avx2.cc +++ b/cpp/src/arrow/acero/swiss_join_avx2.cc @@ -23,6 +23,9 @@ namespace arrow { namespace acero { +// TODO(GH-43693): The functions in this file are not wired anywhere. We may consider +// actually utilizing them or removing them. + template int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int num_rows, const uint32_t* row_ids, @@ -45,48 +48,78 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu if (!is_fixed_length_column) { int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); const uint8_t* row_ptr_base = rows.data(2); - const uint32_t* row_offsets = rows.offsets(); + const RowTableImpl::offset_type* row_offsets = rows.offsets(); + static_assert( + sizeof(RowTableImpl::offset_type) == sizeof(int64_t), + "RowArrayAccessor::Visit_avx2 only supports 64-bit RowTableImpl::offset_type"); if (varbinary_column_id == 0) { // Case 1: This is the first varbinary column // __m256i field_offset_within_row = _mm256_set1_epi32(rows.metadata().fixed_length); __m256i varbinary_end_array_offset = - _mm256_set1_epi32(rows.metadata().varbinary_end_array_offset); + _mm256_set1_epi64x(rows.metadata().varbinary_end_array_offset); for (int i = 0; i < num_rows / unroll; ++i) { + // Load 8 32-bit row ids. __m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); - __m256i row_offset = _mm256_i32gather_epi32( - reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); + // Gather the lower/higher 4 64-bit row offsets based on the lower/higher 4 32-bit + // row ids. + __m256i row_offset_lo = + _mm256_i32gather_epi64(row_offsets, _mm256_castsi256_si128(row_id), + sizeof(RowTableImpl::offset_type)); + __m256i row_offset_hi = + _mm256_i32gather_epi64(row_offsets, _mm256_extracti128_si256(row_id, 1), + sizeof(RowTableImpl::offset_type)); + // Gather the lower/higher 4 32-bit field lengths based on the lower/higher 4 + // 64-bit row offsets. + __m128i field_length_lo = _mm256_i64gather_epi32( + reinterpret_cast(row_ptr_base), + _mm256_add_epi64(row_offset_lo, varbinary_end_array_offset), 1); + __m128i field_length_hi = _mm256_i64gather_epi32( + reinterpret_cast(row_ptr_base), + _mm256_add_epi64(row_offset_hi, varbinary_end_array_offset), 1); + // The final 8 32-bit field lengths, subtracting the field offset within row. __m256i field_length = _mm256_sub_epi32( - _mm256_i32gather_epi32( - reinterpret_cast(row_ptr_base), - _mm256_add_epi32(row_offset, varbinary_end_array_offset), 1), - field_offset_within_row); + _mm256_set_m128i(field_length_hi, field_length_lo), field_offset_within_row); process_8_values_fn(i * unroll, row_ptr_base, - _mm256_add_epi32(row_offset, field_offset_within_row), + _mm256_add_epi64(row_offset_lo, field_offset_within_row), + _mm256_add_epi64(row_offset_hi, field_offset_within_row), field_length); } } else { // Case 2: This is second or later varbinary column // __m256i varbinary_end_array_offset = - _mm256_set1_epi32(rows.metadata().varbinary_end_array_offset + - sizeof(uint32_t) * (varbinary_column_id - 1)); + _mm256_set1_epi64x(rows.metadata().varbinary_end_array_offset + + sizeof(uint32_t) * (varbinary_column_id - 1)); auto row_ptr_base_i64 = reinterpret_cast(row_ptr_base); for (int i = 0; i < num_rows / unroll; ++i) { + // Load 8 32-bit row ids. __m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); - __m256i row_offset = _mm256_i32gather_epi32( - reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); - __m256i end_array_offset = - _mm256_add_epi32(row_offset, varbinary_end_array_offset); - - __m256i field_offset_within_row_A = _mm256_i32gather_epi64( - row_ptr_base_i64, _mm256_castsi256_si128(end_array_offset), 1); - __m256i field_offset_within_row_B = _mm256_i32gather_epi64( - row_ptr_base_i64, _mm256_extracti128_si256(end_array_offset, 1), 1); + // Gather the lower/higher 4 64-bit row offsets based on the lower/higher 4 32-bit + // row ids. + __m256i row_offset_lo = + _mm256_i32gather_epi64(row_offsets, _mm256_castsi256_si128(row_id), + sizeof(RowTableImpl::offset_type)); + // Gather the lower/higher 4 32-bit field lengths based on the lower/higher 4 + // 64-bit row offsets. + __m256i row_offset_hi = + _mm256_i32gather_epi64(row_offsets, _mm256_extracti128_si256(row_id, 1), + sizeof(RowTableImpl::offset_type)); + // Prepare the lower/higher 4 64-bit end array offsets based on the lower/higher 4 + // 64-bit row offsets. + __m256i end_array_offset_lo = + _mm256_add_epi64(row_offset_lo, varbinary_end_array_offset); + __m256i end_array_offset_hi = + _mm256_add_epi64(row_offset_hi, varbinary_end_array_offset); + + __m256i field_offset_within_row_A = + _mm256_i64gather_epi64(row_ptr_base_i64, end_array_offset_lo, 1); + __m256i field_offset_within_row_B = + _mm256_i64gather_epi64(row_ptr_base_i64, end_array_offset_hi, 1); field_offset_within_row_A = _mm256_permutevar8x32_epi32( field_offset_within_row_A, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); field_offset_within_row_B = _mm256_permutevar8x32_epi32( @@ -110,8 +143,14 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu 0x4e); // Swapping low and high 128-bits field_length = _mm256_sub_epi32(field_length, field_offset_within_row); + field_offset_within_row_A = + _mm256_add_epi32(field_offset_within_row_A, alignment_padding); + field_offset_within_row_B = + _mm256_add_epi32(field_offset_within_row_B, alignment_padding); + process_8_values_fn(i * unroll, row_ptr_base, - _mm256_add_epi32(row_offset, field_offset_within_row), + _mm256_add_epi64(row_offset_lo, field_offset_within_row_A), + _mm256_add_epi64(row_offset_hi, field_offset_within_row_B), field_length); } } @@ -119,7 +158,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu if (is_fixed_length_column) { __m256i field_offset_within_row = - _mm256_set1_epi32(rows.metadata().encoded_field_offset( + _mm256_set1_epi64x(rows.metadata().encoded_field_offset( rows.metadata().pos_after_encoding(column_id))); __m256i field_length = _mm256_set1_epi32(rows.metadata().column_metadatas[column_id].fixed_length); @@ -130,24 +169,51 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu // const uint8_t* row_ptr_base = rows.data(1); for (int i = 0; i < num_rows / unroll; ++i) { + // Load 8 32-bit row ids. __m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); - __m256i row_offset = _mm256_mullo_epi32(row_id, field_length); - __m256i field_offset = _mm256_add_epi32(row_offset, field_offset_within_row); - process_8_values_fn(i * unroll, row_ptr_base, field_offset, field_length); + // Widen the 32-bit row ids to 64-bit and store the lower/higher 4 of them into 2 + // 256-bit registers. + __m256i row_id_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(row_id)); + __m256i row_id_hi = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(row_id, 1)); + // Calculate the lower/higher 4 64-bit row offsets based on the lower/higher 4 + // 64-bit row ids and the fixed field length. + __m256i row_offset_lo = _mm256_mul_epi32(row_id_lo, field_length); + __m256i row_offset_hi = _mm256_mul_epi32(row_id_hi, field_length); + // Calculate the lower/higher 4 64-bit field offsets based on the lower/higher 4 + // 64-bit row offsets and field offset within row. + __m256i field_offset_lo = + _mm256_add_epi64(row_offset_lo, field_offset_within_row); + __m256i field_offset_hi = + _mm256_add_epi64(row_offset_hi, field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, field_offset_lo, field_offset_hi, + field_length); } } else { // Case 4: This is a fixed length column in varying length row // const uint8_t* row_ptr_base = rows.data(2); - const uint32_t* row_offsets = rows.offsets(); + const RowTableImpl::offset_type* row_offsets = rows.offsets(); for (int i = 0; i < num_rows / unroll; ++i) { + // Load 8 32-bit row ids. __m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); - __m256i row_offset = _mm256_i32gather_epi32( - reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); - __m256i field_offset = _mm256_add_epi32(row_offset, field_offset_within_row); - process_8_values_fn(i * unroll, row_ptr_base, field_offset, field_length); + // Gather the lower/higher 4 64-bit row offsets based on the lower/higher 4 32-bit + // row ids. + __m256i row_offset_lo = + _mm256_i32gather_epi64(row_offsets, _mm256_castsi256_si128(row_id), + sizeof(RowTableImpl::offset_type)); + __m256i row_offset_hi = + _mm256_i32gather_epi64(row_offsets, _mm256_extracti128_si256(row_id, 1), + sizeof(RowTableImpl::offset_type)); + // Calculate the lower/higher 4 64-bit field offsets based on the lower/higher 4 + // 64-bit row offsets and field offset within row. + __m256i field_offset_lo = + _mm256_add_epi64(row_offset_lo, field_offset_within_row); + __m256i field_offset_hi = + _mm256_add_epi64(row_offset_hi, field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, field_offset_lo, field_offset_hi, + field_length); } } } diff --git a/cpp/src/arrow/compute/row/compare_internal.cc b/cpp/src/arrow/compute/row/compare_internal.cc index 98aea9011266c..5e1a87b795202 100644 --- a/cpp/src/arrow/compute/row/compare_internal.cc +++ b/cpp/src/arrow/compute/row/compare_internal.cc @@ -104,18 +104,21 @@ void KeyCompare::CompareBinaryColumnToRowHelper( const uint8_t* rows_right = rows.data(1); for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; - uint32_t irow_right = left_to_right_map[irow_left]; - uint32_t offset_right = irow_right * fixed_length + offset_within_row; + // irow_right is used to index into row data so promote to the row offset type. + RowTableImpl::offset_type irow_right = left_to_right_map[irow_left]; + RowTableImpl::offset_type offset_right = + irow_right * fixed_length + offset_within_row; match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, offset_right); } } else { const uint8_t* rows_left = col.data(1); - const uint32_t* offsets_right = rows.offsets(); + const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_right = rows.data(2); for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; - uint32_t offset_right = offsets_right[irow_right] + offset_within_row; + RowTableImpl::offset_type offset_right = + offsets_right[irow_right] + offset_within_row; match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, offset_right); } } @@ -145,7 +148,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [bit_offset](const uint8_t* left_base, const uint8_t* right_base, - uint32_t irow_left, uint32_t offset_right) { + uint32_t irow_left, RowTableImpl::offset_type offset_right) { uint8_t left = bit_util::GetBit(left_base, irow_left + bit_offset) ? 0xff : 0x00; uint8_t right = right_base[offset_right]; @@ -156,7 +159,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { uint8_t left = left_base[irow_left]; uint8_t right = right_base[offset_right]; return left == right ? 0xff : 0; @@ -166,7 +169,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { util::CheckAlignment(left_base); util::CheckAlignment(right_base + offset_right); uint16_t left = reinterpret_cast(left_base)[irow_left]; @@ -178,7 +181,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { util::CheckAlignment(left_base); util::CheckAlignment(right_base + offset_right); uint32_t left = reinterpret_cast(left_base)[irow_left]; @@ -190,7 +193,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { util::CheckAlignment(left_base); util::CheckAlignment(right_base + offset_right); uint64_t left = reinterpret_cast(left_base)[irow_left]; @@ -202,7 +205,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [&col](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { uint32_t length = col.metadata().fixed_length; // Non-zero length guarantees no underflow @@ -241,7 +244,7 @@ void KeyCompare::CompareVarBinaryColumnToRowHelper( const uint32_t* left_to_right_map, LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, uint8_t* match_bytevector) { const uint32_t* offsets_left = col.offsets(); - const uint32_t* offsets_right = rows.offsets(); + const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_left = col.data(2); const uint8_t* rows_right = rows.data(2); for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { @@ -249,7 +252,7 @@ void KeyCompare::CompareVarBinaryColumnToRowHelper( uint32_t irow_right = left_to_right_map[irow_left]; uint32_t begin_left = offsets_left[irow_left]; uint32_t length_left = offsets_left[irow_left + 1] - begin_left; - uint32_t begin_right = offsets_right[irow_right]; + RowTableImpl::offset_type begin_right = offsets_right[irow_right]; uint32_t length_right; uint32_t offset_within_row; if (!is_first_varbinary_col) { @@ -334,7 +337,13 @@ void KeyCompare::CompareColumnsToRows( const RowTableImpl& rows, bool are_cols_in_encoding_order, uint8_t* out_match_bitvector_maybe_null) { if (num_rows_to_compare == 0) { - *out_num_rows = 0; + if (out_match_bitvector_maybe_null) { + DCHECK_EQ(out_num_rows, nullptr); + DCHECK_EQ(out_sel_left_maybe_same, nullptr); + bit_util::ClearBitmap(out_match_bitvector_maybe_null, 0, num_rows_to_compare); + } else { + *out_num_rows = 0; + } return; } @@ -440,8 +449,8 @@ void KeyCompare::CompareColumnsToRows( match_bytevector_A, match_bitvector); if (out_match_bitvector_maybe_null) { - ARROW_DCHECK(out_num_rows == nullptr); - ARROW_DCHECK(out_sel_left_maybe_same == nullptr); + DCHECK_EQ(out_num_rows, nullptr); + DCHECK_EQ(out_sel_left_maybe_same, nullptr); memcpy(out_match_bitvector_maybe_null, match_bitvector, bit_util::BytesForBits(num_rows_to_compare)); } else { diff --git a/cpp/src/arrow/compute/row/compare_internal.h b/cpp/src/arrow/compute/row/compare_internal.h index a5a109b0b516a..29d7f859e59ee 100644 --- a/cpp/src/arrow/compute/row/compare_internal.h +++ b/cpp/src/arrow/compute/row/compare_internal.h @@ -42,9 +42,30 @@ class ARROW_EXPORT KeyCompare { /*extra=*/util::MiniBatch::kMiniBatchLength; } - // Returns a single 16-bit selection vector of rows that failed comparison. - // If there is input selection on the left, the resulting selection is a filtered image - // of input selection. + /// \brief Compare a batch of rows in columnar format to the specified rows in row + /// format. + /// + /// The comparison result is populated in either a 16-bit selection vector of rows that + /// failed comparison, or a match bitvector with 1 for matched rows and 0 otherwise. + /// + /// @param num_rows_to_compare The number of rows to compare. + /// @param sel_left_maybe_null Optional input selection vector on the left, the + /// comparison is only performed on the selected rows. Null if all rows in + /// `left_to_right_map` are to be compared. + /// @param left_to_right_map The mapping from the left to the right rows. Left row `i` + /// in `cols` is compared to right row `left_to_right_map[i]` in `row`. + /// @param ctx The light context needed for the comparison. + /// @param out_num_rows The number of rows that failed comparison. Must be null if + /// `out_match_bitvector_maybe_null` is not null. + /// @param out_sel_left_maybe_same The selection vector of rows that failed comparison. + /// Can be the same as `sel_left_maybe_null` for in-place update. Must be null if + /// `out_match_bitvector_maybe_null` is not null. + /// @param cols The left rows in columnar format to compare. + /// @param rows The right rows in row format to compare. + /// @param are_cols_in_encoding_order Whether the columns are in encoding order. + /// @param out_match_bitvector_maybe_null The optional output match bitvector, 1 for + /// matched rows and 0 otherwise. Won't be populated if `out_num_rows` and + /// `out_sel_left_maybe_same` are not null. static void CompareColumnsToRows( uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, LightContext* ctx, uint32_t* out_num_rows, diff --git a/cpp/src/arrow/compute/row/compare_internal_avx2.cc b/cpp/src/arrow/compute/row/compare_internal_avx2.cc index 23238a3691c8a..96eed6fc03a2a 100644 --- a/cpp/src/arrow/compute/row/compare_internal_avx2.cc +++ b/cpp/src/arrow/compute/row/compare_internal_avx2.cc @@ -180,40 +180,6 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( } } -namespace { - -// Intrinsics `_mm256_i32gather_epi32/64` treat the `vindex` as signed integer, and we -// are using `uint32_t` to represent the offset, in range of [0, 4G), within the row -// table. When the offset is larger than `0x80000000` (2GB), those intrinsics will treat -// it as negative offset and gather the data from undesired address. To avoid this issue, -// we normalize the addresses by translating `base` `0x80000000` higher, and `offset` -// `0x80000000` lower. This way, the offset is always in range of [-2G, 2G) and those -// intrinsics are safe. - -constexpr uint64_t kTwoGB = 0x80000000ull; - -template -inline __m256i UnsignedOffsetSafeGather32(int const* base, __m256i offset) { - int const* normalized_base = base + kTwoGB / sizeof(int); - __m256i normalized_offset = - _mm256_sub_epi32(offset, _mm256_set1_epi32(static_cast(kTwoGB / kScale))); - return _mm256_i32gather_epi32(normalized_base, normalized_offset, - static_cast(kScale)); -} - -template -inline __m256i UnsignedOffsetSafeGather64(arrow::util::int64_for_gather_t const* base, - __m128i offset) { - arrow::util::int64_for_gather_t const* normalized_base = - base + kTwoGB / sizeof(arrow::util::int64_for_gather_t); - __m128i normalized_offset = - _mm_sub_epi32(offset, _mm_set1_epi32(static_cast(kTwoGB / kScale))); - return _mm256_i32gather_epi64(normalized_base, normalized_offset, - static_cast(kScale)); -} - -} // namespace - template uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( uint32_t offset_within_row, uint32_t num_rows_to_compare, @@ -240,12 +206,26 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( _mm256_loadu_si256(reinterpret_cast(left_to_right_map) + i); } - __m256i offset_right = - _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(fixed_length)); - offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row)); - - reinterpret_cast(match_bytevector)[i] = - compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right); + // Widen the 32-bit row ids to 64-bit and store the first/last 4 of them into 2 + // 256-bit registers. + __m256i irow_right_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(irow_right)); + __m256i irow_right_hi = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(irow_right, 1)); + // Calculate the lower/higher 4 64-bit row offsets based on the lower/higher 4 + // 64-bit row ids and the fixed length. + __m256i offset_right_lo = + _mm256_mul_epi32(irow_right_lo, _mm256_set1_epi64x(fixed_length)); + __m256i offset_right_hi = + _mm256_mul_epi32(irow_right_hi, _mm256_set1_epi64x(fixed_length)); + // Calculate the lower/higher 4 64-bit field offsets based on the lower/higher 4 + // 64-bit row offsets and field offset within row. + offset_right_lo = + _mm256_add_epi64(offset_right_lo, _mm256_set1_epi64x(offset_within_row)); + offset_right_hi = + _mm256_add_epi64(offset_right_hi, _mm256_set1_epi64x(offset_within_row)); + + reinterpret_cast(match_bytevector)[i] = compare8_fn( + rows_left, rows_right, i * unroll, irow_left, offset_right_lo, offset_right_hi); if (!use_selection) { irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8)); @@ -254,7 +234,7 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( return num_rows_to_compare - (num_rows_to_compare % unroll); } else { const uint8_t* rows_left = col.data(1); - const uint32_t* offsets_right = rows.offsets(); + const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_right = rows.data(2); constexpr uint32_t unroll = 8; __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); @@ -270,12 +250,29 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( irow_right = _mm256_loadu_si256(reinterpret_cast(left_to_right_map) + i); } - __m256i offset_right = - UnsignedOffsetSafeGather32<4>((int const*)offsets_right, irow_right); - offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row)); - reinterpret_cast(match_bytevector)[i] = - compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right); + static_assert(sizeof(RowTableImpl::offset_type) == sizeof(int64_t), + "KeyCompare::CompareBinaryColumnToRowHelper_avx2 only supports " + "64-bit RowTableImpl::offset_type"); + auto offsets_right_i64 = + reinterpret_cast(offsets_right); + // Gather the lower/higher 4 64-bit row offsets based on the lower/higher 4 32-bit + // row ids. + __m256i offset_right_lo = + _mm256_i32gather_epi64(offsets_right_i64, _mm256_castsi256_si128(irow_right), + sizeof(RowTableImpl::offset_type)); + __m256i offset_right_hi = _mm256_i32gather_epi64( + offsets_right_i64, _mm256_extracti128_si256(irow_right, 1), + sizeof(RowTableImpl::offset_type)); + // Calculate the lower/higher 4 64-bit field offsets based on the lower/higher 4 + // 64-bit row offsets and field offset within row. + offset_right_lo = + _mm256_add_epi64(offset_right_lo, _mm256_set1_epi64x(offset_within_row)); + offset_right_hi = + _mm256_add_epi64(offset_right_hi, _mm256_set1_epi64x(offset_within_row)); + + reinterpret_cast(match_bytevector)[i] = compare8_fn( + rows_left, rows_right, i * unroll, irow_left, offset_right_lo, offset_right_hi); if (!use_selection) { irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8)); @@ -287,8 +284,8 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( template inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* right_base, - __m256i irow_left, __m256i offset_right, - int bit_offset = 0) { + __m256i irow_left, __m256i offset_right_lo, + __m256i offset_right_hi, int bit_offset = 0) { __m256i left; switch (column_width) { case 0: { @@ -315,7 +312,9 @@ inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* r ARROW_DCHECK(false); } - __m256i right = UnsignedOffsetSafeGather32<1>((int const*)right_base, offset_right); + __m128i right_lo = _mm256_i64gather_epi32((int const*)right_base, offset_right_lo, 1); + __m128i right_hi = _mm256_i64gather_epi32((int const*)right_base, offset_right_hi, 1); + __m256i right = _mm256_set_m128i(right_hi, right_lo); if (column_width != sizeof(uint32_t)) { constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff; right = _mm256_and_si256(right, _mm256_set1_epi32(mask)); @@ -333,8 +332,8 @@ inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* r template inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_base, - uint32_t irow_left_first, __m256i offset_right, - int bit_offset = 0) { + uint32_t irow_left_first, __m256i offset_right_lo, + __m256i offset_right_hi, int bit_offset = 0) { __m256i left; switch (column_width) { case 0: { @@ -364,7 +363,9 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_bas ARROW_DCHECK(false); } - __m256i right = UnsignedOffsetSafeGather32<1>((int const*)right_base, offset_right); + __m128i right_lo = _mm256_i64gather_epi32((int const*)right_base, offset_right_lo, 1); + __m128i right_hi = _mm256_i64gather_epi32((int const*)right_base, offset_right_hi, 1); + __m256i right = _mm256_set_m128i(right_hi, right_lo); if (column_width != sizeof(uint32_t)) { constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff; right = _mm256_and_si256(right, _mm256_set1_epi32(mask)); @@ -383,7 +384,7 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_bas template inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* right_base, __m256i irow_left, uint32_t irow_left_first, - __m256i offset_right) { + __m256i offset_right_lo, __m256i offset_right_hi) { auto left_base_i64 = reinterpret_cast(left_base); __m256i left_lo, left_hi; @@ -400,10 +401,8 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* rig } auto right_base_i64 = reinterpret_cast(right_base); - __m256i right_lo = - UnsignedOffsetSafeGather64<1>(right_base_i64, _mm256_castsi256_si128(offset_right)); - __m256i right_hi = UnsignedOffsetSafeGather64<1>( - right_base_i64, _mm256_extracti128_si256(offset_right, 1)); + __m256i right_lo = _mm256_i64gather_epi64(right_base_i64, offset_right_lo, 1); + __m256i right_hi = _mm256_i64gather_epi64(right_base_i64, offset_right_hi, 1); uint32_t result_lo = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_lo, right_lo)); uint32_t result_hi = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_hi, right_hi)); return result_lo | (static_cast(result_hi) << 32); @@ -412,13 +411,19 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* rig template inline uint64_t Compare8_Binary_avx2(uint32_t length, const uint8_t* left_base, const uint8_t* right_base, __m256i irow_left, - uint32_t irow_left_first, __m256i offset_right) { + uint32_t irow_left_first, __m256i offset_right_lo, + __m256i offset_right_hi) { uint32_t irow_left_array[8]; - uint32_t offset_right_array[8]; + RowTableImpl::offset_type offset_right_array[8]; if (use_selection) { _mm256_storeu_si256(reinterpret_cast<__m256i*>(irow_left_array), irow_left); } - _mm256_storeu_si256(reinterpret_cast<__m256i*>(offset_right_array), offset_right); + static_assert( + sizeof(RowTableImpl::offset_type) * 4 == sizeof(__m256i), + "Unexpected RowTableImpl::offset_type size in KeyCompare::Compare8_Binary_avx2"); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(offset_right_array), offset_right_lo); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&offset_right_array[4]), + offset_right_hi); // Non-zero length guarantees no underflow int32_t num_loops_less_one = (static_cast(length) + 31) / 32 - 1; @@ -463,13 +468,14 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [bit_offset](const uint8_t* left_base, const uint8_t* right_base, - uint32_t irow_left_base, __m256i irow_left, __m256i offset_right) { + uint32_t irow_left_base, __m256i irow_left, __m256i offset_right_lo, + __m256i offset_right_hi) { if (use_selection) { return CompareSelected8_avx2<0>(left_base, right_base, irow_left, - offset_right, bit_offset); + offset_right_lo, offset_right_hi, bit_offset); } else { - return Compare8_avx2<0>(left_base, right_base, irow_left_base, offset_right, - bit_offset); + return Compare8_avx2<0>(left_base, right_base, irow_left_base, + offset_right_lo, offset_right_hi, bit_offset); } }); } else if (col_width == 1) { @@ -477,12 +483,13 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base, - __m256i irow_left, __m256i offset_right) { + __m256i irow_left, __m256i offset_right_lo, __m256i offset_right_hi) { if (use_selection) { return CompareSelected8_avx2<1>(left_base, right_base, irow_left, - offset_right); + offset_right_lo, offset_right_hi); } else { - return Compare8_avx2<1>(left_base, right_base, irow_left_base, offset_right); + return Compare8_avx2<1>(left_base, right_base, irow_left_base, + offset_right_lo, offset_right_hi); } }); } else if (col_width == 2) { @@ -490,12 +497,13 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base, - __m256i irow_left, __m256i offset_right) { + __m256i irow_left, __m256i offset_right_lo, __m256i offset_right_hi) { if (use_selection) { return CompareSelected8_avx2<2>(left_base, right_base, irow_left, - offset_right); + offset_right_lo, offset_right_hi); } else { - return Compare8_avx2<2>(left_base, right_base, irow_left_base, offset_right); + return Compare8_avx2<2>(left_base, right_base, irow_left_base, + offset_right_lo, offset_right_hi); } }); } else if (col_width == 4) { @@ -503,12 +511,13 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base, - __m256i irow_left, __m256i offset_right) { + __m256i irow_left, __m256i offset_right_lo, __m256i offset_right_hi) { if (use_selection) { return CompareSelected8_avx2<4>(left_base, right_base, irow_left, - offset_right); + offset_right_lo, offset_right_hi); } else { - return Compare8_avx2<4>(left_base, right_base, irow_left_base, offset_right); + return Compare8_avx2<4>(left_base, right_base, irow_left_base, + offset_right_lo, offset_right_hi); } }); } else if (col_width == 8) { @@ -516,19 +525,22 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base, - __m256i irow_left, __m256i offset_right) { + __m256i irow_left, __m256i offset_right_lo, __m256i offset_right_hi) { return Compare8_64bit_avx2(left_base, right_base, irow_left, - irow_left_base, offset_right); + irow_left_base, offset_right_lo, + offset_right_hi); }); } else { return CompareBinaryColumnToRowHelper_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [&col](const uint8_t* left_base, const uint8_t* right_base, - uint32_t irow_left_base, __m256i irow_left, __m256i offset_right) { + uint32_t irow_left_base, __m256i irow_left, __m256i offset_right_lo, + __m256i offset_right_hi) { uint32_t length = col.metadata().fixed_length; - return Compare8_Binary_avx2( - length, left_base, right_base, irow_left, irow_left_base, offset_right); + return Compare8_Binary_avx2(length, left_base, right_base, + irow_left, irow_left_base, + offset_right_lo, offset_right_hi); }); } } @@ -541,7 +553,7 @@ void KeyCompare::CompareVarBinaryColumnToRowImp_avx2( LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, uint8_t* match_bytevector) { const uint32_t* offsets_left = col.offsets(); - const uint32_t* offsets_right = rows.offsets(); + const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_left = col.data(2); const uint8_t* rows_right = rows.data(2); for (uint32_t i = 0; i < num_rows_to_compare; ++i) { @@ -549,7 +561,7 @@ void KeyCompare::CompareVarBinaryColumnToRowImp_avx2( uint32_t irow_right = left_to_right_map[irow_left]; uint32_t begin_left = offsets_left[irow_left]; uint32_t length_left = offsets_left[irow_left + 1] - begin_left; - uint32_t begin_right = offsets_right[irow_right]; + RowTableImpl::offset_type begin_right = offsets_right[irow_right]; uint32_t length_right; uint32_t offset_within_row; if (!is_first_varbinary_col) { diff --git a/cpp/src/arrow/compute/row/compare_test.cc b/cpp/src/arrow/compute/row/compare_test.cc index 22af7e067d855..5e8ee7c58a782 100644 --- a/cpp/src/arrow/compute/row/compare_test.cc +++ b/cpp/src/arrow/compute/row/compare_test.cc @@ -27,7 +27,12 @@ namespace arrow { namespace compute { using arrow::bit_util::BytesForBits; +using arrow::bit_util::GetBit; +using arrow::gen::Constant; +using arrow::gen::Random; +using arrow::internal::CountSetBits; using arrow::internal::CpuInfo; +using arrow::random::kSeedMax; using arrow::random::RandomArrayGenerator; using arrow::util::MiniBatch; using arrow::util::TempVectorStack; @@ -106,7 +111,7 @@ TEST(KeyCompare, CompareColumnsToRowsCuriousFSB) { true, match_bitvector.data()); for (int i = 0; i < num_rows; ++i) { SCOPED_TRACE(i); - ASSERT_EQ(arrow::bit_util::GetBit(match_bitvector.data(), i), i != 6); + ASSERT_EQ(GetBit(match_bitvector.data(), i), i != 6); } } } @@ -166,9 +171,111 @@ TEST(KeyCompare, CompareColumnsToRowsTempStackUsage) { } } +namespace { + +Result MakeRowTableFromExecBatch(const ExecBatch& batch) { + RowTableImpl row_table; + + std::vector column_metadatas; + RETURN_NOT_OK(ColumnMetadatasFromExecBatch(batch, &column_metadatas)); + RowTableMetadata table_metadata; + table_metadata.FromColumnMetadataVector(column_metadatas, sizeof(uint64_t), + sizeof(uint64_t)); + RETURN_NOT_OK(row_table.Init(default_memory_pool(), table_metadata)); + std::vector row_ids(batch.length); + std::iota(row_ids.begin(), row_ids.end(), 0); + RowTableEncoder row_encoder; + row_encoder.Init(column_metadatas, sizeof(uint64_t), sizeof(uint64_t)); + std::vector column_arrays; + RETURN_NOT_OK(ColumnArraysFromExecBatch(batch, &column_arrays)); + row_encoder.PrepareEncodeSelected(0, batch.length, column_arrays); + RETURN_NOT_OK(row_encoder.EncodeSelected( + &row_table, static_cast(batch.length), row_ids.data())); + + return row_table; +} + +Result RepeatRowTableUntil(const RowTableImpl& seed, int64_t num_rows) { + RowTableImpl row_table; + + RETURN_NOT_OK(row_table.Init(default_memory_pool(), seed.metadata())); + // Append the seed row table repeatedly to grow the row table to big enough. + while (row_table.length() < num_rows) { + RETURN_NOT_OK(row_table.AppendSelectionFrom(seed, + static_cast(seed.length()), + /*source_row_ids=*/NULLPTR)); + } + + return row_table; +} + +void AssertCompareColumnsToRowsAllMatch(const std::vector& columns, + const RowTableImpl& row_table, + const std::vector& row_ids_to_compare) { + uint32_t num_rows_to_compare = static_cast(row_ids_to_compare.size()); + + TempVectorStack stack; + ASSERT_OK( + stack.Init(default_memory_pool(), + KeyCompare::CompareColumnsToRowsTempStackUsage(num_rows_to_compare))); + LightContext ctx{CpuInfo::GetInstance()->hardware_flags(), &stack}; + + { + // No selection, output no match row ids. + uint32_t num_rows_no_match; + std::vector row_ids_out(num_rows_to_compare); + KeyCompare::CompareColumnsToRows(num_rows_to_compare, /*sel_left_maybe_null=*/NULLPTR, + row_ids_to_compare.data(), &ctx, &num_rows_no_match, + row_ids_out.data(), columns, row_table, + /*are_cols_in_encoding_order=*/true, + /*out_match_bitvector_maybe_null=*/NULLPTR); + ASSERT_EQ(num_rows_no_match, 0); + } + + { + // No selection, output match bit vector. + std::vector match_bitvector(BytesForBits(num_rows_to_compare)); + KeyCompare::CompareColumnsToRows( + num_rows_to_compare, /*sel_left_maybe_null=*/NULLPTR, row_ids_to_compare.data(), + &ctx, + /*out_num_rows=*/NULLPTR, /*out_sel_left_maybe_same=*/NULLPTR, columns, row_table, + /*are_cols_in_encoding_order=*/true, match_bitvector.data()); + ASSERT_EQ(CountSetBits(match_bitvector.data(), 0, num_rows_to_compare), + num_rows_to_compare); + } + + std::vector selection_left(num_rows_to_compare); + std::iota(selection_left.begin(), selection_left.end(), 0); + + { + // With selection, output no match row ids. + uint32_t num_rows_no_match; + std::vector row_ids_out(num_rows_to_compare); + KeyCompare::CompareColumnsToRows(num_rows_to_compare, selection_left.data(), + row_ids_to_compare.data(), &ctx, &num_rows_no_match, + row_ids_out.data(), columns, row_table, + /*are_cols_in_encoding_order=*/true, + /*out_match_bitvector_maybe_null=*/NULLPTR); + ASSERT_EQ(num_rows_no_match, 0); + } + + { + // With selection, output match bit vector. + std::vector match_bitvector(BytesForBits(num_rows_to_compare)); + KeyCompare::CompareColumnsToRows( + num_rows_to_compare, selection_left.data(), row_ids_to_compare.data(), &ctx, + /*out_num_rows=*/NULLPTR, /*out_sel_left_maybe_same=*/NULLPTR, columns, row_table, + /*are_cols_in_encoding_order=*/true, match_bitvector.data()); + ASSERT_EQ(CountSetBits(match_bitvector.data(), 0, num_rows_to_compare), + num_rows_to_compare); + } +} + +} // namespace + // Compare columns to rows at offsets over 2GB within a row table. // Certain AVX2 instructions may behave unexpectedly causing troubles like GH-41813. -TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsLarge)) { +TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsOver2GB)) { if constexpr (sizeof(void*) == 4) { GTEST_SKIP() << "Test only works on 64-bit platforms"; } @@ -176,128 +283,194 @@ TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsLarge)) { // The idea of this case is to create a row table using several fixed length columns and // one var length column (so the row is hence var length and has offset buffer), with // the overall data size exceeding 2GB. Then compare each row with itself. - constexpr int64_t two_gb = 2ll * 1024ll * 1024ll * 1024ll; + constexpr int64_t k2GB = 2ll * 1024ll * 1024ll * 1024ll; // The compare function requires the row id of the left column to be uint16_t, hence the // number of rows. constexpr int64_t num_rows = std::numeric_limits::max() + 1; const std::vector> fixed_length_types{uint64(), uint32()}; // The var length column should be a little smaller than 2GB to workaround the capacity // limitation in the var length builder. - constexpr int32_t var_length = two_gb / num_rows - 1; + constexpr int32_t var_length = k2GB / num_rows - 1; auto row_size = std::accumulate(fixed_length_types.begin(), fixed_length_types.end(), static_cast(var_length), [](int64_t acc, const std::shared_ptr& type) { return acc + type->byte_width(); }); // The overall size should be larger than 2GB. - ASSERT_GT(row_size * num_rows, two_gb); - - MemoryPool* pool = default_memory_pool(); + ASSERT_GT(row_size * num_rows, k2GB); - // The left side columns. - std::vector columns_left; + // The left side batch. ExecBatch batch_left; { std::vector values; // Several fixed length arrays containing random content. for (const auto& type : fixed_length_types) { - ASSERT_OK_AND_ASSIGN(auto value, ::arrow::gen::Random(type)->Generate(num_rows)); + ASSERT_OK_AND_ASSIGN(auto value, Random(type)->Generate(num_rows)); values.push_back(std::move(value)); } // A var length array containing 'X' repeated var_length times. - ASSERT_OK_AND_ASSIGN(auto value_var_length, - ::arrow::gen::Constant( - std::make_shared(std::string(var_length, 'X'))) - ->Generate(num_rows)); + ASSERT_OK_AND_ASSIGN( + auto value_var_length, + Constant(std::make_shared(std::string(var_length, 'X'))) + ->Generate(num_rows)); values.push_back(std::move(value_var_length)); batch_left = ExecBatch(std::move(values), num_rows); - ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &columns_left)); } + // The left side columns. + std::vector columns_left; + ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &columns_left)); + // The right side row table. - RowTableImpl row_table_right; - { - // Encode the row table with the left columns. - std::vector column_metadatas; - ASSERT_OK(ColumnMetadatasFromExecBatch(batch_left, &column_metadatas)); - RowTableMetadata table_metadata; - table_metadata.FromColumnMetadataVector(column_metadatas, sizeof(uint64_t), - sizeof(uint64_t)); - ASSERT_OK(row_table_right.Init(pool, table_metadata)); - std::vector row_ids(num_rows); - std::iota(row_ids.begin(), row_ids.end(), 0); - RowTableEncoder row_encoder; - row_encoder.Init(column_metadatas, sizeof(uint64_t), sizeof(uint64_t)); - row_encoder.PrepareEncodeSelected(0, num_rows, columns_left); - ASSERT_OK(row_encoder.EncodeSelected( - &row_table_right, static_cast(num_rows), row_ids.data())); - - // The row table must contain an offset buffer. - ASSERT_NE(row_table_right.offsets(), NULLPTR); - // The whole point of this test. - ASSERT_GT(row_table_right.offsets()[num_rows - 1], two_gb); - } + ASSERT_OK_AND_ASSIGN(RowTableImpl row_table_right, + MakeRowTableFromExecBatch(batch_left)); + // The row table must contain an offset buffer. + ASSERT_NE(row_table_right.data(2), NULLPTR); + // The whole point of this test. + ASSERT_GT(row_table_right.offsets()[num_rows - 1], k2GB); // The rows to compare. std::vector row_ids_to_compare(num_rows); std::iota(row_ids_to_compare.begin(), row_ids_to_compare.end(), 0); - TempVectorStack stack; - ASSERT_OK(stack.Init(pool, KeyCompare::CompareColumnsToRowsTempStackUsage(num_rows))); - LightContext ctx{CpuInfo::GetInstance()->hardware_flags(), &stack}; + AssertCompareColumnsToRowsAllMatch(columns_left, row_table_right, row_ids_to_compare); +} - { - // No selection, output no match row ids. - uint32_t num_rows_no_match; - std::vector row_ids_out(num_rows); - KeyCompare::CompareColumnsToRows(num_rows, /*sel_left_maybe_null=*/NULLPTR, - row_ids_to_compare.data(), &ctx, &num_rows_no_match, - row_ids_out.data(), columns_left, row_table_right, - /*are_cols_in_encoding_order=*/true, - /*out_match_bitvector_maybe_null=*/NULLPTR); - ASSERT_EQ(num_rows_no_match, 0); +// GH-43495: Compare fixed length columns to rows over 4GB within a row table. +TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsOver4GBFixedLength)) { + if constexpr (sizeof(void*) == 4) { + GTEST_SKIP() << "Test only works on 64-bit platforms"; } + // The idea of this case is to create a row table using one fixed length column (so the + // row is hence fixed length), with more than 4GB data. Then compare the rows located at + // over 4GB. + + // A small batch to append to the row table repeatedly to grow the row table to big + // enough. + constexpr int64_t num_rows_batch = std::numeric_limits::max(); + constexpr int fixed_length = 256; + + // The size of the row table is one batch larger than 4GB, and we'll compare the last + // num_rows_batch rows. + constexpr int64_t k4GB = 4ll * 1024 * 1024 * 1024; + constexpr int64_t num_rows_row_table = + (k4GB / (fixed_length * num_rows_batch) + 1) * num_rows_batch; + static_assert(num_rows_row_table < std::numeric_limits::max(), + "row table length must be less than uint32 max"); + static_assert(num_rows_row_table * fixed_length > k4GB, + "row table size must be greater than 4GB"); + + // The left side batch with num_rows_batch rows. + ExecBatch batch_left; { - // No selection, output match bit vector. - std::vector match_bitvector(BytesForBits(num_rows)); - KeyCompare::CompareColumnsToRows( - num_rows, /*sel_left_maybe_null=*/NULLPTR, row_ids_to_compare.data(), &ctx, - /*out_num_rows=*/NULLPTR, /*out_sel_left_maybe_same=*/NULLPTR, columns_left, - row_table_right, - /*are_cols_in_encoding_order=*/true, match_bitvector.data()); - ASSERT_EQ(arrow::internal::CountSetBits(match_bitvector.data(), 0, num_rows), - num_rows); + std::vector values; + + // A fixed length array containing random values. + ASSERT_OK_AND_ASSIGN( + auto value_fixed_length, + Random(fixed_size_binary(fixed_length))->Generate(num_rows_batch)); + values.push_back(std::move(value_fixed_length)); + + batch_left = ExecBatch(std::move(values), num_rows_batch); } - std::vector selection_left(num_rows); - std::iota(selection_left.begin(), selection_left.end(), 0); + // The left side columns with num_rows_batch rows. + std::vector columns_left; + ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &columns_left)); + + // The right side row table with num_rows_row_table rows. + ASSERT_OK_AND_ASSIGN( + RowTableImpl row_table_right, + RepeatRowTableUntil(MakeRowTableFromExecBatch(batch_left).ValueUnsafe(), + num_rows_row_table)); + // The row table must not contain a third buffer. + ASSERT_EQ(row_table_right.data(2), NULLPTR); + // The row data must be greater than 4GB. + ASSERT_GT(row_table_right.buffer_size(1), k4GB); + + // The rows to compare: the last num_rows_batch rows in the row table VS. the whole + // batch. + std::vector row_ids_to_compare(num_rows_batch); + std::iota(row_ids_to_compare.begin(), row_ids_to_compare.end(), + static_cast(num_rows_row_table - num_rows_batch)); + + AssertCompareColumnsToRowsAllMatch(columns_left, row_table_right, row_ids_to_compare); +} - { - // With selection, output no match row ids. - uint32_t num_rows_no_match; - std::vector row_ids_out(num_rows); - KeyCompare::CompareColumnsToRows(num_rows, selection_left.data(), - row_ids_to_compare.data(), &ctx, &num_rows_no_match, - row_ids_out.data(), columns_left, row_table_right, - /*are_cols_in_encoding_order=*/true, - /*out_match_bitvector_maybe_null=*/NULLPTR); - ASSERT_EQ(num_rows_no_match, 0); +// GH-43495: Compare var length columns to rows at offset over 4GB within a row table. +TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsOver4GBVarLength)) { + if constexpr (sizeof(void*) == 4) { + GTEST_SKIP() << "Test only works on 64-bit platforms"; } + // The idea of this case is to create a row table using one fixed length column and one + // var length column (so the row is hence var length and has offset buffer), with more + // than 4GB data. Then compare the rows located at over 4GB. + + // A small batch to append to the row table repeatedly to grow the row table to big + // enough. + constexpr int64_t num_rows_batch = std::numeric_limits::max(); + constexpr int fixed_length = 128; + // Involve some small randomness in the var length column. + constexpr int var_length_min = 128; + constexpr int var_length_max = 129; + constexpr double null_probability = 0.01; + + // The size of the row table is one batch larger than 4GB, and we'll compare the last + // num_rows_batch rows. + constexpr int64_t k4GB = 4ll * 1024 * 1024 * 1024; + constexpr int64_t size_row_min = fixed_length + var_length_min; + constexpr int64_t num_rows_row_table = + (k4GB / (size_row_min * num_rows_batch) + 1) * num_rows_batch; + static_assert(num_rows_row_table < std::numeric_limits::max(), + "row table length must be less than uint32 max"); + static_assert(num_rows_row_table * size_row_min > k4GB, + "row table size must be greater than 4GB"); + + // The left side batch with num_rows_batch rows. + ExecBatch batch_left; { - // With selection, output match bit vector. - std::vector match_bitvector(BytesForBits(num_rows)); - KeyCompare::CompareColumnsToRows( - num_rows, selection_left.data(), row_ids_to_compare.data(), &ctx, - /*out_num_rows=*/NULLPTR, /*out_sel_left_maybe_same=*/NULLPTR, columns_left, - row_table_right, - /*are_cols_in_encoding_order=*/true, match_bitvector.data()); - ASSERT_EQ(arrow::internal::CountSetBits(match_bitvector.data(), 0, num_rows), - num_rows); + std::vector values; + + // A fixed length array containing random values. + ASSERT_OK_AND_ASSIGN( + auto value_fixed_length, + Random(fixed_size_binary(fixed_length))->Generate(num_rows_batch)); + values.push_back(std::move(value_fixed_length)); + + // A var length array containing random binary of 128 or 129 bytes with small portion + // of nulls. + auto value_var_length = RandomArrayGenerator(kSeedMax).String( + num_rows_batch, var_length_min, var_length_max, null_probability); + values.push_back(std::move(value_var_length)); + + batch_left = ExecBatch(std::move(values), num_rows_batch); } + + // The left side columns with num_rows_batch rows. + std::vector columns_left; + ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &columns_left)); + + // The right side row table with num_rows_row_table rows. + ASSERT_OK_AND_ASSIGN( + RowTableImpl row_table_right, + RepeatRowTableUntil(MakeRowTableFromExecBatch(batch_left).ValueUnsafe(), + num_rows_row_table)); + // The row table must contain an offset buffer. + ASSERT_NE(row_table_right.data(2), NULLPTR); + // At least the last row should be located at over 4GB. + ASSERT_GT(row_table_right.offsets()[num_rows_row_table - 1], k4GB); + + // The rows to compare: the last num_rows_batch rows in the row table VS. the whole + // batch. + std::vector row_ids_to_compare(num_rows_batch); + std::iota(row_ids_to_compare.begin(), row_ids_to_compare.end(), + static_cast(num_rows_row_table - num_rows_batch)); + + AssertCompareColumnsToRowsAllMatch(columns_left, row_table_right, row_ids_to_compare); } } // namespace compute diff --git a/cpp/src/arrow/compute/row/encode_internal.cc b/cpp/src/arrow/compute/row/encode_internal.cc index 658e0dffcac68..127d43021d639 100644 --- a/cpp/src/arrow/compute/row/encode_internal.cc +++ b/cpp/src/arrow/compute/row/encode_internal.cc @@ -17,7 +17,6 @@ #include "arrow/compute/row/encode_internal.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/int_util_overflow.h" namespace arrow { namespace compute { @@ -265,7 +264,8 @@ void EncoderInteger::Decode(uint32_t start_row, uint32_t num_rows, num_rows * row_size); } else if (rows.metadata().is_fixed_length) { uint32_t row_size = rows.metadata().fixed_length; - const uint8_t* row_base = rows.data(1) + start_row * row_size; + const uint8_t* row_base = + rows.data(1) + static_cast(start_row) * row_size; row_base += offset_within_row; uint8_t* col_base = col_prep.mutable_data(1); switch (col_prep.metadata().fixed_length) { @@ -296,7 +296,7 @@ void EncoderInteger::Decode(uint32_t start_row, uint32_t num_rows, DCHECK(false); } } else { - const uint32_t* row_offsets = rows.offsets() + start_row; + const RowTableImpl::offset_type* row_offsets = rows.offsets() + start_row; const uint8_t* row_base = rows.data(2); row_base += offset_within_row; uint8_t* col_base = col_prep.mutable_data(1); @@ -362,14 +362,14 @@ void EncoderBinary::EncodeSelectedImp(uint32_t offset_within_row, RowTableImpl* } else { const uint8_t* src_base = col.data(1); uint8_t* dst = rows->mutable_data(2) + offset_within_row; - const uint32_t* offsets = rows->offsets(); + const RowTableImpl::offset_type* offsets = rows->offsets(); for (uint32_t i = 0; i < num_selected; ++i) { copy_fn(dst + offsets[i], src_base, selection[i]); } if (col.data(0)) { const uint8_t* non_null_bits = col.data(0); uint8_t* dst = rows->mutable_data(2) + offset_within_row; - const uint32_t* offsets = rows->offsets(); + const RowTableImpl::offset_type* offsets = rows->offsets(); for (uint32_t i = 0; i < num_selected; ++i) { bool is_null = !bit_util::GetBit(non_null_bits, selection[i] + col.bit_offset(0)); if (is_null) { @@ -585,10 +585,12 @@ void EncoderBinaryPair::DecodeImp(uint32_t num_rows_to_skip, uint32_t start_row, uint8_t* dst_B = col2->mutable_data(1); uint32_t fixed_length = rows.metadata().fixed_length; - const uint32_t* offsets; + const RowTableImpl::offset_type* offsets; const uint8_t* src_base; if (is_row_fixed_length) { - src_base = rows.data(1) + fixed_length * start_row + offset_within_row; + src_base = rows.data(1) + + static_cast(start_row) * fixed_length + + offset_within_row; offsets = nullptr; } else { src_base = rows.data(2) + offset_within_row; @@ -640,7 +642,7 @@ void EncoderOffsets::Decode(uint32_t start_row, uint32_t num_rows, // The Nth element is the sum of all the lengths of varbinary columns data in // that row, up to and including Nth varbinary column. - const uint32_t* row_offsets = rows.offsets() + start_row; + const RowTableImpl::offset_type* row_offsets = rows.offsets() + start_row; // Set the base offset for each column for (size_t col = 0; col < varbinary_cols->size(); ++col) { @@ -658,8 +660,8 @@ void EncoderOffsets::Decode(uint32_t start_row, uint32_t num_rows, // Update the offset of each column uint32_t offset_within_row = rows.metadata().fixed_length; for (size_t col = 0; col < varbinary_cols->size(); ++col) { - offset_within_row += - RowTableMetadata::padding_for_alignment(offset_within_row, string_alignment); + offset_within_row += RowTableMetadata::padding_for_alignment_within_row( + offset_within_row, string_alignment); uint32_t length = varbinary_ends[col] - offset_within_row; offset_within_row = varbinary_ends[col]; uint32_t* col_offsets = (*varbinary_cols)[col].mutable_offsets(); @@ -676,7 +678,7 @@ Status EncoderOffsets::GetRowOffsetsSelected(RowTableImpl* rows, return Status::OK(); } - uint32_t* row_offsets = rows->mutable_offsets(); + RowTableImpl::offset_type* row_offsets = rows->mutable_offsets(); for (uint32_t i = 0; i < num_selected; ++i) { row_offsets[i] = rows->metadata().fixed_length; } @@ -688,7 +690,7 @@ Status EncoderOffsets::GetRowOffsetsSelected(RowTableImpl* rows, for (uint32_t i = 0; i < num_selected; ++i) { uint32_t irow = selection[i]; uint32_t length = col_offsets[irow + 1] - col_offsets[irow]; - row_offsets[i] += RowTableMetadata::padding_for_alignment( + row_offsets[i] += RowTableMetadata::padding_for_alignment_row( row_offsets[i], rows->metadata().string_alignment); row_offsets[i] += length; } @@ -708,20 +710,13 @@ Status EncoderOffsets::GetRowOffsetsSelected(RowTableImpl* rows, } } - uint32_t sum = 0; + int64_t sum = 0; int row_alignment = rows->metadata().row_alignment; for (uint32_t i = 0; i < num_selected; ++i) { - uint32_t length = row_offsets[i]; - length += RowTableMetadata::padding_for_alignment(length, row_alignment); + RowTableImpl::offset_type length = row_offsets[i]; + length += RowTableMetadata::padding_for_alignment_row(length, row_alignment); row_offsets[i] = sum; - uint32_t sum_maybe_overflow = 0; - if (ARROW_PREDICT_FALSE( - arrow::internal::AddWithOverflow(sum, length, &sum_maybe_overflow))) { - return Status::Invalid( - "Offset overflow detected in EncoderOffsets::GetRowOffsetsSelected for row ", i, - " of length ", length, " bytes, current length in total is ", sum, " bytes"); - } - sum = sum_maybe_overflow; + sum += length; } row_offsets[num_selected] = sum; @@ -732,7 +727,7 @@ template void EncoderOffsets::EncodeSelectedImp(uint32_t ivarbinary, RowTableImpl* rows, const std::vector& cols, uint32_t num_selected, const uint16_t* selection) { - const uint32_t* row_offsets = rows->offsets(); + const RowTableImpl::offset_type* row_offsets = rows->offsets(); uint8_t* row_base = rows->mutable_data(2) + rows->metadata().varbinary_end_array_offset + ivarbinary * sizeof(uint32_t); @@ -753,7 +748,7 @@ void EncoderOffsets::EncodeSelectedImp(uint32_t ivarbinary, RowTableImpl* rows, row[0] = rows->metadata().fixed_length + length; } else { row[0] = row[-1] + - RowTableMetadata::padding_for_alignment( + RowTableMetadata::padding_for_alignment_within_row( row[-1], rows->metadata().string_alignment) + length; } @@ -857,7 +852,7 @@ void EncoderNulls::Decode(uint32_t start_row, uint32_t num_rows, const RowTableI void EncoderVarBinary::EncodeSelected(uint32_t ivarbinary, RowTableImpl* rows, const KeyColumnArray& cols, uint32_t num_selected, const uint16_t* selection) { - const uint32_t* row_offsets = rows->offsets(); + const RowTableImpl::offset_type* row_offsets = rows->offsets(); uint8_t* row_base = rows->mutable_data(2); const uint32_t* col_offsets = cols.offsets(); const uint8_t* col_base = cols.data(2); diff --git a/cpp/src/arrow/compute/row/encode_internal.h b/cpp/src/arrow/compute/row/encode_internal.h index 0618ddd8e4b96..37538fcc4b835 100644 --- a/cpp/src/arrow/compute/row/encode_internal.h +++ b/cpp/src/arrow/compute/row/encode_internal.h @@ -173,7 +173,7 @@ class EncoderBinary { copy_fn(dst, src, col_width); } } else { - const uint32_t* row_offsets = rows_const->offsets(); + const RowTableImpl::offset_type* row_offsets = rows_const->offsets(); for (uint32_t i = 0; i < num_rows; ++i) { const uint8_t* src; uint8_t* dst; @@ -267,7 +267,8 @@ class EncoderVarBinary { ARROW_DCHECK(!rows_const->metadata().is_fixed_length && !col_const->metadata().is_fixed_length); - const uint32_t* row_offsets_for_batch = rows_const->offsets() + start_row; + const RowTableImpl::offset_type* row_offsets_for_batch = + rows_const->offsets() + start_row; const uint32_t* col_offsets = col_const->offsets(); uint32_t col_offset_next = col_offsets[0]; @@ -275,7 +276,7 @@ class EncoderVarBinary { uint32_t col_offset = col_offset_next; col_offset_next = col_offsets[i + 1]; - uint32_t row_offset = row_offsets_for_batch[i]; + RowTableImpl::offset_type row_offset = row_offsets_for_batch[i]; const uint8_t* row = rows_const->data(2) + row_offset; uint32_t offset_within_row; diff --git a/cpp/src/arrow/compute/row/encode_internal_avx2.cc b/cpp/src/arrow/compute/row/encode_internal_avx2.cc index 50969c7bd6034..26f8e3a63de0a 100644 --- a/cpp/src/arrow/compute/row/encode_internal_avx2.cc +++ b/cpp/src/arrow/compute/row/encode_internal_avx2.cc @@ -75,10 +75,12 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows uint8_t* col_vals_B = col2->mutable_data(1); uint32_t fixed_length = rows.metadata().fixed_length; - const uint32_t* offsets; + const RowTableImpl::offset_type* offsets; const uint8_t* src_base; if (is_row_fixed_length) { - src_base = rows.data(1) + fixed_length * start_row + offset_within_row; + src_base = rows.data(1) + + static_cast(fixed_length) * start_row + + offset_within_row; offsets = nullptr; } else { src_base = rows.data(2) + offset_within_row; @@ -99,7 +101,7 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows src2 = reinterpret_cast(src + fixed_length * 2); src3 = reinterpret_cast(src + fixed_length * 3); } else { - const uint32_t* row_offsets = offsets + i * unroll; + const RowTableImpl::offset_type* row_offsets = offsets + i * unroll; const uint8_t* src = src_base; src0 = reinterpret_cast(src + row_offsets[0]); src1 = reinterpret_cast(src + row_offsets[1]); @@ -140,7 +142,7 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows } } } else { - const uint32_t* row_offsets = offsets + i * unroll; + const RowTableImpl::offset_type* row_offsets = offsets + i * unroll; const uint8_t* src = src_base; for (int j = 0; j < unroll; ++j) { if (col_width == 1) { diff --git a/cpp/src/arrow/compute/row/row_internal.cc b/cpp/src/arrow/compute/row/row_internal.cc index 746ed950ffa07..aa7e62add45ff 100644 --- a/cpp/src/arrow/compute/row/row_internal.cc +++ b/cpp/src/arrow/compute/row/row_internal.cc @@ -18,7 +18,6 @@ #include "arrow/compute/row/row_internal.h" #include "arrow/compute/util.h" -#include "arrow/util/int_util_overflow.h" namespace arrow { namespace compute { @@ -128,8 +127,8 @@ void RowTableMetadata::FromColumnMetadataVector( const KeyColumnMetadata& col = cols[column_order[i]]; if (col.is_fixed_length && col.fixed_length != 0 && ARROW_POPCOUNT64(col.fixed_length) != 1) { - offset_within_row += RowTableMetadata::padding_for_alignment(offset_within_row, - string_alignment, col); + offset_within_row += RowTableMetadata::padding_for_alignment_within_row( + offset_within_row, string_alignment, col); } column_offsets[i] = offset_within_row; if (!col.is_fixed_length) { @@ -155,7 +154,7 @@ void RowTableMetadata::FromColumnMetadataVector( is_fixed_length = (num_varbinary_cols == 0); fixed_length = offset_within_row + - RowTableMetadata::padding_for_alignment( + RowTableMetadata::padding_for_alignment_within_row( offset_within_row, num_varbinary_cols == 0 ? row_alignment : string_alignment); // We set the number of bytes per row storing null masks of individual key columns @@ -191,7 +190,7 @@ Status RowTableImpl::Init(MemoryPool* pool, const RowTableMetadata& metadata) { auto offsets, AllocateResizableBuffer(size_offsets(kInitialRowsCapacity), pool_)); offsets_ = std::move(offsets); memset(offsets_->mutable_data(), 0, size_offsets(kInitialRowsCapacity)); - reinterpret_cast(offsets_->mutable_data())[0] = 0; + reinterpret_cast(offsets_->mutable_data())[0] = 0; ARROW_ASSIGN_OR_RAISE( auto rows, @@ -226,7 +225,7 @@ void RowTableImpl::Clean() { has_any_nulls_ = false; if (!metadata_.is_fixed_length) { - reinterpret_cast(offsets_->mutable_data())[0] = 0; + reinterpret_cast(offsets_->mutable_data())[0] = 0; } } @@ -235,7 +234,7 @@ int64_t RowTableImpl::size_null_masks(int64_t num_rows) const { } int64_t RowTableImpl::size_offsets(int64_t num_rows) const { - return (num_rows + 1) * sizeof(uint32_t) + kPaddingForVectors; + return (num_rows + 1) * sizeof(offset_type) + kPaddingForVectors; } int64_t RowTableImpl::size_rows_fixed_length(int64_t num_rows) const { @@ -326,23 +325,15 @@ Status RowTableImpl::AppendSelectionFrom(const RowTableImpl& from, if (!metadata_.is_fixed_length) { // Varying-length rows - auto from_offsets = reinterpret_cast(from.offsets_->data()); - auto to_offsets = reinterpret_cast(offsets_->mutable_data()); - uint32_t total_length = to_offsets[num_rows_]; - uint32_t total_length_to_append = 0; + auto from_offsets = reinterpret_cast(from.offsets_->data()); + auto to_offsets = reinterpret_cast(offsets_->mutable_data()); + offset_type total_length = to_offsets[num_rows_]; + int64_t total_length_to_append = 0; for (uint32_t i = 0; i < num_rows_to_append; ++i) { uint16_t row_id = source_row_ids ? source_row_ids[i] : i; - uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id]; + int64_t length = from_offsets[row_id + 1] - from_offsets[row_id]; total_length_to_append += length; - uint32_t to_offset_maybe_overflow = 0; - if (ARROW_PREDICT_FALSE(arrow::internal::AddWithOverflow( - total_length, total_length_to_append, &to_offset_maybe_overflow))) { - return Status::Invalid( - "Offset overflow detected in RowTableImpl::AppendSelectionFrom for row ", - num_rows_ + i, " of length ", length, " bytes, current length in total is ", - to_offsets[num_rows_ + i], " bytes"); - } - to_offsets[num_rows_ + i + 1] = to_offset_maybe_overflow; + to_offsets[num_rows_ + i + 1] = total_length + total_length_to_append; } RETURN_NOT_OK(ResizeOptionalVaryingLengthBuffer(total_length_to_append)); @@ -351,7 +342,8 @@ Status RowTableImpl::AppendSelectionFrom(const RowTableImpl& from, uint8_t* dst = rows_->mutable_data() + total_length; for (uint32_t i = 0; i < num_rows_to_append; ++i) { uint16_t row_id = source_row_ids ? source_row_ids[i] : i; - uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id]; + int64_t length = from_offsets[row_id + 1] - from_offsets[row_id]; + DCHECK_LE(length, std::numeric_limits::max()); auto src64 = reinterpret_cast(src + from_offsets[row_id]); auto dst64 = reinterpret_cast(dst); for (uint32_t j = 0; j < bit_util::CeilDiv(length, 8); ++j) { @@ -397,7 +389,7 @@ Status RowTableImpl::AppendSelectionFrom(const RowTableImpl& from, } Status RowTableImpl::AppendEmpty(uint32_t num_rows_to_append, - uint32_t num_extra_bytes_to_append) { + int64_t num_extra_bytes_to_append) { RETURN_NOT_OK(ResizeFixedLengthBuffers(num_rows_to_append)); if (!metadata_.is_fixed_length) { RETURN_NOT_OK(ResizeOptionalVaryingLengthBuffer(num_extra_bytes_to_append)); diff --git a/cpp/src/arrow/compute/row/row_internal.h b/cpp/src/arrow/compute/row/row_internal.h index 93818fb14d629..094a9c31efe0a 100644 --- a/cpp/src/arrow/compute/row/row_internal.h +++ b/cpp/src/arrow/compute/row/row_internal.h @@ -30,6 +30,8 @@ namespace compute { /// Description of the data stored in a RowTable struct ARROW_EXPORT RowTableMetadata { + using offset_type = int64_t; + /// \brief True if there are no variable length columns in the table bool is_fixed_length; @@ -78,26 +80,35 @@ struct ARROW_EXPORT RowTableMetadata { /// Offsets within a row to fields in their encoding order. std::vector column_offsets; - /// Rounding up offset to the nearest multiple of alignment value. + /// Rounding up offset within row to the nearest multiple of alignment value. /// Alignment must be a power of 2. - static inline uint32_t padding_for_alignment(uint32_t offset, int required_alignment) { + static inline uint32_t padding_for_alignment_within_row(uint32_t offset, + int required_alignment) { ARROW_DCHECK(ARROW_POPCOUNT64(required_alignment) == 1); return static_cast((-static_cast(offset)) & (required_alignment - 1)); } - /// Rounding up offset to the beginning of next column, + /// Rounding up offset within row to the beginning of next column, /// choosing required alignment based on the data type of that column. - static inline uint32_t padding_for_alignment(uint32_t offset, int string_alignment, - const KeyColumnMetadata& col_metadata) { + static inline uint32_t padding_for_alignment_within_row( + uint32_t offset, int string_alignment, const KeyColumnMetadata& col_metadata) { if (!col_metadata.is_fixed_length || ARROW_POPCOUNT64(col_metadata.fixed_length) <= 1) { return 0; } else { - return padding_for_alignment(offset, string_alignment); + return padding_for_alignment_within_row(offset, string_alignment); } } + /// Rounding up row offset to the nearest multiple of alignment value. + /// Alignment must be a power of 2. + static inline offset_type padding_for_alignment_row(offset_type row_offset, + int required_alignment) { + ARROW_DCHECK(ARROW_POPCOUNT64(required_alignment) == 1); + return (-row_offset) & (required_alignment - 1); + } + /// Returns an array of offsets within a row of ends of varbinary fields. inline const uint32_t* varbinary_end_array(const uint8_t* row) const { ARROW_DCHECK(!is_fixed_length); @@ -127,7 +138,7 @@ struct ARROW_EXPORT RowTableMetadata { ARROW_DCHECK(varbinary_id > 0); const uint32_t* varbinary_end = varbinary_end_array(row); uint32_t offset = varbinary_end[varbinary_id - 1]; - offset += padding_for_alignment(offset, string_alignment); + offset += padding_for_alignment_within_row(offset, string_alignment); *out_offset = offset; *out_length = varbinary_end[varbinary_id] - offset; } @@ -161,6 +172,8 @@ struct ARROW_EXPORT RowTableMetadata { /// The row table is not safe class ARROW_EXPORT RowTableImpl { public: + using offset_type = RowTableMetadata::offset_type; + RowTableImpl(); /// \brief Initialize a row array for use /// @@ -175,7 +188,7 @@ class ARROW_EXPORT RowTableImpl { /// \param num_extra_bytes_to_append For tables storing variable-length data this /// should be a guess of how many data bytes will be needed to populate the /// data. This is ignored if there are no variable-length columns - Status AppendEmpty(uint32_t num_rows_to_append, uint32_t num_extra_bytes_to_append); + Status AppendEmpty(uint32_t num_rows_to_append, int64_t num_extra_bytes_to_append); /// \brief Append rows from a source table /// \param from The table to append from /// \param num_rows_to_append The number of rows to append @@ -201,8 +214,12 @@ class ARROW_EXPORT RowTableImpl { } return NULLPTR; } - const uint32_t* offsets() const { return reinterpret_cast(data(1)); } - uint32_t* mutable_offsets() { return reinterpret_cast(mutable_data(1)); } + const offset_type* offsets() const { + return reinterpret_cast(data(1)); + } + offset_type* mutable_offsets() { + return reinterpret_cast(mutable_data(1)); + } const uint8_t* null_masks() const { return null_masks_->data(); } uint8_t* null_masks() { return null_masks_->mutable_data(); } diff --git a/cpp/src/arrow/compute/row/row_test.cc b/cpp/src/arrow/compute/row/row_test.cc index 75f981fb1281d..6aed9e4327812 100644 --- a/cpp/src/arrow/compute/row/row_test.cc +++ b/cpp/src/arrow/compute/row/row_test.cc @@ -123,7 +123,7 @@ TEST(RowTableMemoryConsumption, Encode) { ASSERT_GT(actual_null_mask_size * 2, row_table.buffer_size(0) - padding_for_vectors); - int64_t actual_offset_size = num_rows * sizeof(uint32_t); + int64_t actual_offset_size = num_rows * sizeof(RowTableImpl::offset_type); ASSERT_LE(actual_offset_size, row_table.buffer_size(1) - padding_for_vectors); ASSERT_GT(actual_offset_size * 2, row_table.buffer_size(1) - padding_for_vectors); @@ -134,15 +134,14 @@ TEST(RowTableMemoryConsumption, Encode) { } } -// GH-43202: Ensure that when offset overflow happens in encoding the row table, an -// explicit error is raised instead of a silent wrong result. -TEST(RowTableOffsetOverflow, LARGE_MEMORY_TEST(Encode)) { +// GH-43495: Ensure that we can build a row table with more than 4GB row data. +TEST(RowTableLarge, LARGE_MEMORY_TEST(Encode)) { if constexpr (sizeof(void*) == 4) { GTEST_SKIP() << "Test only works on 64-bit platforms"; } - // Use 8 512MB var-length rows (occupies 4GB+) to overflow the offset in the row table. - constexpr int64_t num_rows = 8; + // Use 9 512MB var-length rows to occupy more than 4GB memory. + constexpr int64_t num_rows = 9; constexpr int64_t length_per_binary = 512 * 1024 * 1024; constexpr int64_t row_alignment = sizeof(uint32_t); constexpr int64_t var_length_alignment = sizeof(uint32_t); @@ -174,39 +173,24 @@ TEST(RowTableOffsetOverflow, LARGE_MEMORY_TEST(Encode)) { // The rows to encode. std::vector row_ids(num_rows, 0); - // Encoding 7 rows should be fine. - { - row_encoder.PrepareEncodeSelected(0, num_rows - 1, columns); - ASSERT_OK(row_encoder.EncodeSelected(&row_table, static_cast(num_rows - 1), - row_ids.data())); - } + // Encode num_rows rows. + row_encoder.PrepareEncodeSelected(0, num_rows, columns); + ASSERT_OK(row_encoder.EncodeSelected(&row_table, static_cast(num_rows), + row_ids.data())); - // Encoding 8 rows should overflow. - { - int64_t length_per_row = table_metadata.fixed_length + length_per_binary; - std::stringstream expected_error_message; - expected_error_message << "Invalid: Offset overflow detected in " - "EncoderOffsets::GetRowOffsetsSelected for row " - << num_rows - 1 << " of length " << length_per_row - << " bytes, current length in total is " - << length_per_row * (num_rows - 1) << " bytes"; - row_encoder.PrepareEncodeSelected(0, num_rows, columns); - ASSERT_RAISES_WITH_MESSAGE( - Invalid, expected_error_message.str(), - row_encoder.EncodeSelected(&row_table, static_cast(num_rows), - row_ids.data())); - } + auto encoded_row_length = table_metadata.fixed_length + length_per_binary; + ASSERT_EQ(row_table.offsets()[num_rows - 1], encoded_row_length * (num_rows - 1)); + ASSERT_EQ(row_table.offsets()[num_rows], encoded_row_length * num_rows); } -// GH-43202: Ensure that when offset overflow happens in appending to the row table, an -// explicit error is raised instead of a silent wrong result. -TEST(RowTableOffsetOverflow, LARGE_MEMORY_TEST(AppendFrom)) { +// GH-43495: Ensure that we can build a row table with more than 4GB row data. +TEST(RowTableLarge, LARGE_MEMORY_TEST(AppendFrom)) { if constexpr (sizeof(void*) == 4) { GTEST_SKIP() << "Test only works on 64-bit platforms"; } - // Use 8 512MB var-length rows (occupies 4GB+) to overflow the offset in the row table. - constexpr int64_t num_rows = 8; + // Use 9 512MB var-length rows to occupy more than 4GB memory. + constexpr int64_t num_rows = 9; constexpr int64_t length_per_binary = 512 * 1024 * 1024; constexpr int64_t num_rows_seed = 1; constexpr int64_t row_alignment = sizeof(uint32_t); @@ -244,23 +228,15 @@ TEST(RowTableOffsetOverflow, LARGE_MEMORY_TEST(AppendFrom)) { RowTableImpl row_table; ASSERT_OK(row_table.Init(pool, table_metadata)); - // Appending the seed 7 times should be fine. - for (int i = 0; i < num_rows - 1; ++i) { + // Append seed num_rows times. + for (int i = 0; i < num_rows; ++i) { ASSERT_OK(row_table.AppendSelectionFrom(row_table_seed, num_rows_seed, /*source_row_ids=*/NULLPTR)); } - // Appending the seed the 8-th time should overflow. - int64_t length_per_row = table_metadata.fixed_length + length_per_binary; - std::stringstream expected_error_message; - expected_error_message - << "Invalid: Offset overflow detected in RowTableImpl::AppendSelectionFrom for row " - << num_rows - 1 << " of length " << length_per_row - << " bytes, current length in total is " << length_per_row * (num_rows - 1) - << " bytes"; - ASSERT_RAISES_WITH_MESSAGE(Invalid, expected_error_message.str(), - row_table.AppendSelectionFrom(row_table_seed, num_rows_seed, - /*source_row_ids=*/NULLPTR)); + auto encoded_row_length = table_metadata.fixed_length + length_per_binary; + ASSERT_EQ(row_table.offsets()[num_rows - 1], encoded_row_length * (num_rows - 1)); + ASSERT_EQ(row_table.offsets()[num_rows], encoded_row_length * num_rows); } } // namespace compute diff --git a/cpp/src/arrow/testing/random.cc b/cpp/src/arrow/testing/random.cc index c317fe7aef44c..59de09fff83c5 100644 --- a/cpp/src/arrow/testing/random.cc +++ b/cpp/src/arrow/testing/random.cc @@ -473,19 +473,16 @@ std::shared_ptr RandomArrayGenerator::StringWithRepeats( return result; } -std::shared_ptr RandomArrayGenerator::FixedSizeBinary(int64_t size, - int32_t byte_width, - double null_probability, - int64_t alignment, - MemoryPool* memory_pool) { +std::shared_ptr RandomArrayGenerator::FixedSizeBinary( + int64_t size, int32_t byte_width, double null_probability, uint8_t min_byte, + uint8_t max_byte, int64_t alignment, MemoryPool* memory_pool) { if (null_probability < 0 || null_probability > 1) { ABORT_NOT_OK(Status::Invalid("null_probability must be between 0 and 1")); } // Visual Studio does not implement uniform_int_distribution for char types. using GenOpt = GenerateOptions>; - GenOpt options(seed(), static_cast('A'), static_cast('z'), - null_probability); + GenOpt options(seed(), min_byte, max_byte, null_probability); int64_t null_count = 0; auto null_bitmap = *AllocateEmptyBitmap(size, alignment, memory_pool); @@ -1087,7 +1084,9 @@ std::shared_ptr RandomArrayGenerator::ArrayOf(const Field& field, int64_t case Type::type::FIXED_SIZE_BINARY: { auto byte_width = internal::checked_pointer_cast(field.type())->byte_width(); - return *FixedSizeBinary(length, byte_width, null_probability, alignment, + return *FixedSizeBinary(length, byte_width, null_probability, + /*min_byte=*/static_cast('A'), + /*min_byte=*/static_cast('z'), alignment, memory_pool) ->View(field.type()); } @@ -1143,7 +1142,9 @@ std::shared_ptr RandomArrayGenerator::ArrayOf(const Field& field, int64_t // type means it's not a (useful) composition of other generators GENERATE_INTEGRAL_CASE_VIEW(Int64Type, DayTimeIntervalType); case Type::type::INTERVAL_MONTH_DAY_NANO: { - return *FixedSizeBinary(length, /*byte_width=*/16, null_probability, alignment, + return *FixedSizeBinary(length, /*byte_width=*/16, null_probability, + /*min_byte=*/static_cast('A'), + /*min_byte=*/static_cast('z'), alignment, memory_pool) ->View(month_day_nano_interval()); } diff --git a/cpp/src/arrow/testing/random.h b/cpp/src/arrow/testing/random.h index 1d97a3ada724a..9c0c5baae0f7c 100644 --- a/cpp/src/arrow/testing/random.h +++ b/cpp/src/arrow/testing/random.h @@ -434,12 +434,18 @@ class ARROW_TESTING_EXPORT RandomArrayGenerator { /// \param[in] size the size of the array to generate /// \param[in] byte_width the byte width of fixed-size binary items /// \param[in] null_probability the probability of a value being null + /// \param[in] min_byte the lower bound of each byte in the binary determined by the + /// uniform distribution + /// \param[in] max_byte the upper bound of each byte in the binary determined by the + /// uniform distribution /// \param[in] alignment alignment for memory allocations (in bytes) /// \param[in] memory_pool memory pool to allocate memory from /// /// \return a generated Array std::shared_ptr FixedSizeBinary(int64_t size, int32_t byte_width, double null_probability = 0, + uint8_t min_byte = static_cast('A'), + uint8_t max_byte = static_cast('z'), int64_t alignment = kDefaultBufferAlignment, MemoryPool* memory_pool = default_memory_pool()); From c599fa0064a627d3b58d4eff821a34391120bcf6 Mon Sep 17 00:00:00 2001 From: Tom Scott-Coombes <62209801+tscottcoombes1@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:13:35 +0100 Subject: [PATCH 09/10] GH-43554: [Go] Handle excluded fields (#43555) ### Rationale for this change We want to be able to handle excluded fields. ### What changes are included in this PR? * we no longer use the value of the field when getting the element type of a list (as the values are invalid for excluded fields) * similarly for map, key value pairs, we don't use the value is there is none * add some tests ### Are these changes tested? yes ### Are there any user-facing changes? no * GitHub Issue: #43554 Lead-authored-by: Tom Scott-Coombes Co-authored-by: Tom Scott-Coombes <62209801+tscottcoombes1@users.noreply.github.com> Co-authored-by: Matt Topol Co-authored-by: tscottcoombes1 <62209801+tscottcoombes1@users.noreply.github.com> Co-authored-by: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Signed-off-by: Joel Lubinitsky --- go/arrow/util/messages/types.proto | 46 ++ go/arrow/util/protobuf_reflect.go | 31 +- go/arrow/util/protobuf_reflect_test.go | 421 +++++++++++----- go/arrow/util/util_message/types.pb.go | 654 +++++++++++++++++++++++-- 4 files changed, 996 insertions(+), 156 deletions(-) diff --git a/go/arrow/util/messages/types.proto b/go/arrow/util/messages/types.proto index c085273ca35e0..79b922a22a3be 100644 --- a/go/arrow/util/messages/types.proto +++ b/go/arrow/util/messages/types.proto @@ -54,3 +54,49 @@ message AllTheTypes { OPTION_1 = 1; } } + +message AllTheTypesNoAny { + string str = 1; + int32 int32 = 2; + int64 int64 = 3; + sint32 sint32 = 4; + sint64 sin64 = 5; + uint32 uint32 = 6; + uint64 uint64 = 7; + fixed32 fixed32 = 8; + fixed64 fixed64 = 9; + sfixed32 sfixed32 = 10; + bool bool = 11; + bytes bytes = 12; + double double = 13; + ExampleEnum enum = 14; + ExampleMessage message = 15; + oneof oneof { + string oneofstring = 16; + ExampleMessage oneofmessage = 17; + } + map simple_map = 19; + map complex_map = 20; + repeated string simple_list = 21; + repeated ExampleMessage complex_list = 22; + + enum ExampleEnum { + OPTION_0 = 0; + OPTION_1 = 1; + } +} + +message SimpleNested { + repeated ExampleMessage simple_a = 1; + repeated ExampleMessage simple_b = 2; +} + +message ComplexNested { + repeated AllTheTypesNoAny all_the_types_no_any_a = 1; + repeated AllTheTypesNoAny all_the_types_no_any_b = 2; +} + +message DeepNested { + ComplexNested complex_nested = 1; + SimpleNested simple_nested = 2; +} diff --git a/go/arrow/util/protobuf_reflect.go b/go/arrow/util/protobuf_reflect.go index 03153563b8cb5..c8cda96acf941 100644 --- a/go/arrow/util/protobuf_reflect.go +++ b/go/arrow/util/protobuf_reflect.go @@ -60,6 +60,7 @@ type ProtobufFieldReflection struct { rValue reflect.Value schemaOptions arrow.Field + isListItem bool } func (pfr *ProtobufFieldReflection) isNull() bool { @@ -170,7 +171,7 @@ func (pfr *ProtobufFieldReflection) isEnum() bool { } func (pfr *ProtobufFieldReflection) isStruct() bool { - return pfr.descriptor.Kind() == protoreflect.MessageKind && !pfr.descriptor.IsMap() && pfr.rValue.Kind() != reflect.Slice + return pfr.descriptor.Kind() == protoreflect.MessageKind && !pfr.descriptor.IsMap() && !pfr.isList() } func (pfr *ProtobufFieldReflection) isMap() bool { @@ -178,7 +179,7 @@ func (pfr *ProtobufFieldReflection) isMap() bool { } func (pfr *ProtobufFieldReflection) isList() bool { - return pfr.descriptor.IsList() && pfr.rValue.Kind() == reflect.Slice + return pfr.descriptor.IsList() && !pfr.isListItem } // ProtobufMessageReflection represents the metadata and values of a protobuf message @@ -218,11 +219,7 @@ func (psr ProtobufMessageReflection) getArrowFields() []arrow.Field { var fields []arrow.Field for pfr := range psr.generateStructFields() { - fields = append(fields, arrow.Field{ - Name: pfr.name(), - Type: pfr.getDataType(), - Nullable: true, - }) + fields = append(fields, pfr.arrowField()) } return fields @@ -237,12 +234,10 @@ func (pfr *ProtobufFieldReflection) asList() protobufListReflection { } func (plr protobufListReflection) getDataType() arrow.DataType { - for li := range plr.generateListItems() { - return arrow.ListOf(li.getDataType()) - } pfr := ProtobufFieldReflection{ descriptor: plr.descriptor, schemaOptions: plr.schemaOptions, + isListItem: true, } return arrow.ListOf(pfr.getDataType()) } @@ -401,6 +396,22 @@ func (pmr protobufMapReflection) generateKeyValuePairs() chan protobufMapKeyValu go func() { defer close(out) + if !pmr.rValue.IsValid() { + kvp := protobufMapKeyValuePairReflection{ + k: ProtobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + schemaOptions: pmr.schemaOptions, + }, + v: ProtobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + schemaOptions: pmr.schemaOptions, + }, + } + out <- kvp + return + } for _, k := range pmr.rValue.MapKeys() { kvp := protobufMapKeyValuePairReflection{ k: ProtobufFieldReflection{ diff --git a/go/arrow/util/protobuf_reflect_test.go b/go/arrow/util/protobuf_reflect_test.go index 220552df8d89e..7420aa726337d 100644 --- a/go/arrow/util/protobuf_reflect_test.go +++ b/go/arrow/util/protobuf_reflect_test.go @@ -17,9 +17,12 @@ package util import ( - "strings" + "encoding/json" + "fmt" "testing" + "google.golang.org/protobuf/proto" + "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/memory" @@ -30,14 +33,52 @@ import ( "google.golang.org/protobuf/types/known/anypb" ) -func SetupTest() util_message.AllTheTypes { - msg := util_message.ExampleMessage{ - Field1: "Example", +type Fixture struct { + msg proto.Message + schema string + jsonStr string +} + +type J map[string]any + +func AllTheTypesFixture() Fixture { + e := J{"field1": "Example"} + + m := J{ + "str": "Hello", + "int32": 10, + "int64": 100, + "sint32": -10, + "sin64": -100, + "uint32": 10, + "uint64": 100, + "fixed32": 10, + "fixed64": 1000, + "sfixed32": 10, + "bool": false, + "bytes": "SGVsbG8sIHdvcmxkIQ==", + "double": 1.1, + "enum": "OPTION_1", + "message": e, + "oneof": []any{0, "World"}, + "any": J{"field1": "Example"}, + "simple_map": []J{{"key": 99, "value": "Hello"}}, + "complex_map": []J{{"key": "complex", "value": e}}, + "simple_list": []any{"Hello", "World"}, + "complex_list": []J{e}, } + jm, err := json.Marshal(m) + if err != nil { + panic(err) + } + jsonString := string(jm) - anyMsg, _ := anypb.New(&msg) + exampleMsg := util_message.ExampleMessage{ + Field1: "Example", + } + anyMsg, _ := anypb.New(&exampleMsg) - return util_message.AllTheTypes{ + msg := util_message.AllTheTypes{ Str: "Hello", Int32: 10, Int64: 100, @@ -52,23 +93,80 @@ func SetupTest() util_message.AllTheTypes { Bytes: []byte("Hello, world!"), Double: 1.1, Enum: util_message.AllTheTypes_OPTION_1, - Message: &msg, + Message: &exampleMsg, Oneof: &util_message.AllTheTypes_Oneofstring{Oneofstring: "World"}, Any: anyMsg, //Breaks the test as the Golang maps have a non-deterministic order //SimpleMap: map[int32]string{99: "Hello", 100: "World", 98: "How", 101: "Are", 1: "You"}, SimpleMap: map[int32]string{99: "Hello"}, - ComplexMap: map[string]*util_message.ExampleMessage{"complex": &msg}, + ComplexMap: map[string]*util_message.ExampleMessage{"complex": &exampleMsg}, SimpleList: []string{"Hello", "World"}, - ComplexList: []*util_message.ExampleMessage{&msg}, + ComplexList: []*util_message.ExampleMessage{&exampleMsg}, + } + + schema := `schema: + fields: 22 + - str: type=utf8, nullable + - int32: type=int32, nullable + - int64: type=int64, nullable + - sint32: type=int32, nullable + - sin64: type=int64, nullable + - uint32: type=uint32, nullable + - uint64: type=uint64, nullable + - fixed32: type=uint32, nullable + - fixed64: type=uint64, nullable + - sfixed32: type=int32, nullable + - bool: type=bool, nullable + - bytes: type=binary, nullable + - double: type=float64, nullable + - enum: type=dictionary, nullable + - message: type=struct, nullable + - oneofstring: type=utf8, nullable + - oneofmessage: type=struct, nullable + - any: type=struct, nullable + - simple_map: type=map, nullable + - complex_map: type=map, items_nullable>, nullable + - simple_list: type=list, nullable + - complex_list: type=list, nullable>, nullable` + + return Fixture{ + msg: &msg, + schema: schema, + jsonStr: jsonString, } } -func TestGetSchema(t *testing.T) { - msg := SetupTest() +func AllTheTypesNoAnyFixture() Fixture { + exampleMsg := util_message.ExampleMessage{ + Field1: "Example", + } - got := NewProtobufMessageReflection(&msg).Schema().String() - want := `schema: + msg := util_message.AllTheTypesNoAny{ + Str: "Hello", + Int32: 10, + Int64: 100, + Sint32: -10, + Sin64: -100, + Uint32: 10, + Uint64: 100, + Fixed32: 10, + Fixed64: 1000, + Sfixed32: 10, + Bool: false, + Bytes: []byte("Hello, world!"), + Double: 1.1, + Enum: util_message.AllTheTypesNoAny_OPTION_1, + Message: &exampleMsg, + Oneof: &util_message.AllTheTypesNoAny_Oneofstring{Oneofstring: "World"}, + //Breaks the test as the Golang maps have a non-deterministic order + //SimpleMap: map[int32]string{99: "Hello", 100: "World", 98: "How", 101: "Are", 1: "You"}, + SimpleMap: map[int32]string{99: "Hello"}, + ComplexMap: map[string]*util_message.ExampleMessage{"complex": &exampleMsg}, + SimpleList: []string{"Hello", "World"}, + ComplexList: []*util_message.ExampleMessage{&exampleMsg}, + } + + schema := `schema: fields: 22 - str: type=utf8, nullable - int32: type=int32, nullable @@ -87,16 +185,62 @@ func TestGetSchema(t *testing.T) { - message: type=struct, nullable - oneofstring: type=utf8, nullable - oneofmessage: type=struct, nullable - - any: type=struct, nullable - simple_map: type=map, nullable - complex_map: type=map, items_nullable>, nullable - simple_list: type=list, nullable - complex_list: type=list, nullable>, nullable` - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + jsonStr := `{ + "str":"Hello", + "int32":10, + "int64":100, + "sint32":-10, + "sin64":-100, + "uint32":10, + "uint64":100, + "fixed32":10, + "fixed64":1000, + "sfixed32":10, + "bool":false, + "bytes":"SGVsbG8sIHdvcmxkIQ==", + "double":1.1, + "enum":"OPTION_1", + "message":{"field1":"Example"}, + "oneofmessage": { "field1": null }, + "oneofstring": "World", + "simple_map":[{"key":99,"value":"Hello"}], + "complex_map":[{"key":"complex","value":{"field1":"Example"}}], + "simple_list":["Hello","World"], + "complex_list":[{"field1":"Example"}] + }` + + return Fixture{ + msg: &msg, + schema: schema, + jsonStr: jsonStr, + } +} - got = NewProtobufMessageReflection(&msg, WithOneOfHandler(OneOfDenseUnion)).Schema().String() - want = `schema: +func CheckSchema(t *testing.T, pmr *ProtobufMessageReflection, want string) { + got := pmr.Schema().String() + require.Equal(t, got, want, "got: %s\nwant: %s", got, want) +} + +func CheckRecord(t *testing.T, pmr *ProtobufMessageReflection, jsonStr string) { + rec := pmr.Record(nil) + got, err := json.Marshal(rec) + assert.NoError(t, err) + assert.JSONEq(t, jsonStr, string(got), "got: %s\nwant: %s", got, jsonStr) +} + +func TestGetSchema(t *testing.T) { + f := AllTheTypesFixture() + + pmr := NewProtobufMessageReflection(f.msg) + CheckSchema(t, pmr, f.schema) + + pmr = NewProtobufMessageReflection(f.msg, WithOneOfHandler(OneOfDenseUnion)) + want := `schema: fields: 21 - str: type=utf8, nullable - int32: type=int32, nullable @@ -119,14 +263,13 @@ func TestGetSchema(t *testing.T) { - complex_map: type=map, items_nullable>, nullable - simple_list: type=list, nullable - complex_list: type=list, nullable>, nullable` - - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + CheckSchema(t, pmr, want) excludeComplex := func(pfr *ProtobufFieldReflection) bool { return pfr.isMap() || pfr.isList() || pfr.isStruct() } - got = NewProtobufMessageReflection(&msg, WithExclusionPolicy(excludeComplex)).Schema().String() + pmr = NewProtobufMessageReflection(f.msg, WithExclusionPolicy(excludeComplex)) want = `schema: fields: 15 - str: type=utf8, nullable @@ -144,14 +287,13 @@ func TestGetSchema(t *testing.T) { - double: type=float64, nullable - enum: type=dictionary, nullable - oneofstring: type=utf8, nullable` + CheckSchema(t, pmr, want) - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) - - got = NewProtobufMessageReflection( - &msg, + pmr = NewProtobufMessageReflection( + f.msg, WithExclusionPolicy(excludeComplex), WithFieldNameFormatter(xstrings.ToCamelCase), - ).Schema().String() + ) want = `schema: fields: 15 - Str: type=utf8, nullable @@ -169,123 +311,168 @@ func TestGetSchema(t *testing.T) { - Double: type=float64, nullable - Enum: type=dictionary, nullable - Oneofstring: type=utf8, nullable` - - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + CheckSchema(t, pmr, want) onlyEnum := func(pfr *ProtobufFieldReflection) bool { return !pfr.isEnum() } - got = NewProtobufMessageReflection( - &msg, + pmr = NewProtobufMessageReflection( + f.msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumNumber), - ).Schema().String() + ) want = `schema: fields: 1 - enum: type=int32, nullable` + CheckSchema(t, pmr, want) - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) - - got = NewProtobufMessageReflection( - &msg, + pmr = NewProtobufMessageReflection( + f.msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumValue), - ).Schema().String() + ) want = `schema: fields: 1 - enum: type=utf8, nullable` - - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + CheckSchema(t, pmr, want) } func TestRecordFromProtobuf(t *testing.T) { - msg := SetupTest() - - pmr := NewProtobufMessageReflection(&msg, WithOneOfHandler(OneOfDenseUnion)) - schema := pmr.Schema() - got := pmr.Record(nil) - jsonStr := `[ - { - "str":"Hello", - "int32":10, - "int64":100, - "sint32":-10, - "sin64":-100, - "uint32":10, - "uint64":100, - "fixed32":10, - "fixed64":1000, - "sfixed32":10, - "bool":false, - "bytes":"SGVsbG8sIHdvcmxkIQ==", - "double":1.1, - "enum":"OPTION_1", - "message":{"field1":"Example"}, - "oneof": [0, "World"], - "any":{"field1":"Example"}, - "simple_map":[{"key":99,"value":"Hello"}], - "complex_map":[{"key":"complex","value":{"field1":"Example"}}], - "simple_list":["Hello","World"], - "complex_list":[{"field1":"Example"}] - } - ]` - want, _, err := array.RecordFromJSON(memory.NewGoAllocator(), schema, strings.NewReader(jsonStr)) + f := AllTheTypesFixture() - require.NoError(t, err) - require.EqualExportedValues(t, got, want, "got: %s\nwant: %s", got, want) + pmr := NewProtobufMessageReflection(f.msg, WithOneOfHandler(OneOfDenseUnion)) + CheckRecord(t, pmr, fmt.Sprintf(`[%s]`, f.jsonStr)) onlyEnum := func(pfr *ProtobufFieldReflection) bool { return !pfr.isEnum() } - pmr = NewProtobufMessageReflection(&msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumValue)) - got = pmr.Record(nil) - jsonStr = `[ { "enum":"OPTION_1" } ]` - want, _, err = array.RecordFromJSON(memory.NewGoAllocator(), pmr.Schema(), strings.NewReader(jsonStr)) - require.NoError(t, err) - require.True(t, array.RecordEqual(got, want), "got: %s\nwant: %s", got, want) - - pmr = NewProtobufMessageReflection(&msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumNumber)) - got = pmr.Record(nil) - jsonStr = `[ { "enum":"1" } ]` - want, _, err = array.RecordFromJSON(memory.NewGoAllocator(), pmr.Schema(), strings.NewReader(jsonStr)) - require.NoError(t, err) - require.True(t, array.RecordEqual(got, want), "got: %s\nwant: %s", got, want) + pmr = NewProtobufMessageReflection(f.msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumValue)) + jsonStr := `[ { "enum":"OPTION_1" } ]` + CheckRecord(t, pmr, jsonStr) + + pmr = NewProtobufMessageReflection(f.msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumNumber)) + jsonStr = `[ { "enum":1 } ]` + CheckRecord(t, pmr, jsonStr) } func TestNullRecordFromProtobuf(t *testing.T) { pmr := NewProtobufMessageReflection(&util_message.AllTheTypes{}) - schema := pmr.Schema() - got := pmr.Record(nil) - _, _ = got.MarshalJSON() - jsonStr := `[ - { - "str":"", - "int32":0, - "int64":0, - "sint32":0, - "sin64":0, - "uint32":0, - "uint64":0, - "fixed32":0, - "fixed64":0, - "sfixed32":0, - "bool":false, - "bytes":"", - "double":0, - "enum":"OPTION_0", - "message":null, - "oneofmessage":{"field1":""}, - "oneofstring":"", - "any":null, - "simple_map":[], - "complex_map":[], - "simple_list":[], - "complex_list":[] - } - ]` - - want, _, err := array.RecordFromJSON(memory.NewGoAllocator(), schema, strings.NewReader(jsonStr)) - - require.NoError(t, err) - require.EqualExportedValues(t, got, want, "got: %s\nwant: %s", got, want) + CheckRecord(t, pmr, `[{ + "str":"", + "int32":0, + "int64":0, + "sint32":0, + "sin64":0, + "uint32":0, + "uint64":0, + "fixed32":0, + "fixed64":0, + "sfixed32":0, + "bool":false, + "bytes":null, + "double":0, + "enum":"OPTION_0", + "message":null, + "oneofmessage":{"field1":""}, + "oneofstring":"", + "any": null, + "simple_map":[], + "complex_map":[], + "simple_list":[], + "complex_list":[] + }]`) +} + +func TestExcludedNested(t *testing.T) { + msg := util_message.ExampleMessage{ + Field1: "Example", + } + schema := `schema: + fields: 2 + - simple_a: type=list, nullable>, nullable + - simple_b: type=list, nullable>, nullable` + + simpleNested := util_message.SimpleNested{ + SimpleA: []*util_message.ExampleMessage{&msg}, + SimpleB: []*util_message.ExampleMessage{&msg}, + } + pmr := NewProtobufMessageReflection(&simpleNested) + jsonStr := `[{ "simple_a":[{"field1":"Example"}], "simple_b":[{"field1":"Example"}] }]` + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + //exclude one value + simpleNested = util_message.SimpleNested{ + SimpleA: []*util_message.ExampleMessage{&msg}, + } + jsonStr = `[{ "simple_a":[{"field1":"Example"}], "simple_b":[]}]` + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + ////exclude both values + simpleNested = util_message.SimpleNested{} + jsonStr = `[{ "simple_a":[], "simple_b":[] }]` + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + f := AllTheTypesNoAnyFixture() + schema = `schema: + fields: 2 + - all_the_types_no_any_a: type=list, message: struct, oneofstring: utf8, oneofmessage: struct, simple_map: map, complex_map: map, items_nullable>, simple_list: list, complex_list: list, nullable>>, nullable>, nullable + - all_the_types_no_any_b: type=list, message: struct, oneofstring: utf8, oneofmessage: struct, simple_map: map, complex_map: map, items_nullable>, simple_list: list, complex_list: list, nullable>>, nullable>, nullable` + + complexNested := util_message.ComplexNested{ + AllTheTypesNoAnyA: []*util_message.AllTheTypesNoAny{f.msg.(*util_message.AllTheTypesNoAny)}, + AllTheTypesNoAnyB: []*util_message.AllTheTypesNoAny{f.msg.(*util_message.AllTheTypesNoAny)}, + } + jsonStr = fmt.Sprintf(`[{ "all_the_types_no_any_a": [%s], "all_the_types_no_any_b": [%s] }]`, f.jsonStr, f.jsonStr) + pmr = NewProtobufMessageReflection(&complexNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + // exclude one value + complexNested = util_message.ComplexNested{ + AllTheTypesNoAnyB: []*util_message.AllTheTypesNoAny{f.msg.(*util_message.AllTheTypesNoAny)}, + } + jsonStr = fmt.Sprintf(`[{ "all_the_types_no_any_a": [], "all_the_types_no_any_b": [%s] }]`, f.jsonStr) + pmr = NewProtobufMessageReflection(&complexNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + // exclude both values + complexNested = util_message.ComplexNested{} + jsonStr = `[{ "all_the_types_no_any_a": [], "all_the_types_no_any_b": [] }]` + pmr = NewProtobufMessageReflection(&complexNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + schema = `schema: + fields: 2 + - complex_nested: type=struct, message: struct, oneofstring: utf8, oneofmessage: struct, simple_map: map, complex_map: map, items_nullable>, simple_list: list, complex_list: list, nullable>>, nullable>, all_the_types_no_any_b: list, message: struct, oneofstring: utf8, oneofmessage: struct, simple_map: map, complex_map: map, items_nullable>, simple_list: list, complex_list: list, nullable>>, nullable>>, nullable + - simple_nested: type=struct, nullable>, simple_b: list, nullable>>, nullable` + + deepNested := util_message.DeepNested{ + ComplexNested: &complexNested, + SimpleNested: &simpleNested, + } + jsonStr = `[{ "simple_nested": {"simple_a":[], "simple_b":[]}, "complex_nested": {"all_the_types_no_any_a": [], "all_the_types_no_any_b": []} }]` + pmr = NewProtobufMessageReflection(&deepNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + // exclude one value + deepNested = util_message.DeepNested{ + ComplexNested: &complexNested, + } + jsonStr = `[{ "simple_nested": null, "complex_nested": {"all_the_types_no_any_a": [], "all_the_types_no_any_b": []} }]` + pmr = NewProtobufMessageReflection(&deepNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + // exclude both values + deepNested = util_message.DeepNested{} + pmr = NewProtobufMessageReflection(&deepNested) + jsonStr = `[{ "simple_nested": null, "complex_nested": null }]` + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) } type testProtobufReflection struct { diff --git a/go/arrow/util/util_message/types.pb.go b/go/arrow/util/util_message/types.pb.go index 80e18847c1970..6486b2cc87a09 100644 --- a/go/arrow/util/util_message/types.pb.go +++ b/go/arrow/util/util_message/types.pb.go @@ -23,12 +23,11 @@ package util_message import ( - reflect "reflect" - sync "sync" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" anypb "google.golang.org/protobuf/types/known/anypb" + reflect "reflect" + sync "sync" ) const ( @@ -84,6 +83,52 @@ func (AllTheTypes_ExampleEnum) EnumDescriptor() ([]byte, []int) { return file_messages_types_proto_rawDescGZIP(), []int{1, 0} } +type AllTheTypesNoAny_ExampleEnum int32 + +const ( + AllTheTypesNoAny_OPTION_0 AllTheTypesNoAny_ExampleEnum = 0 + AllTheTypesNoAny_OPTION_1 AllTheTypesNoAny_ExampleEnum = 1 +) + +// Enum value maps for AllTheTypesNoAny_ExampleEnum. +var ( + AllTheTypesNoAny_ExampleEnum_name = map[int32]string{ + 0: "OPTION_0", + 1: "OPTION_1", + } + AllTheTypesNoAny_ExampleEnum_value = map[string]int32{ + "OPTION_0": 0, + "OPTION_1": 1, + } +) + +func (x AllTheTypesNoAny_ExampleEnum) Enum() *AllTheTypesNoAny_ExampleEnum { + p := new(AllTheTypesNoAny_ExampleEnum) + *p = x + return p +} + +func (x AllTheTypesNoAny_ExampleEnum) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (AllTheTypesNoAny_ExampleEnum) Descriptor() protoreflect.EnumDescriptor { + return file_messages_types_proto_enumTypes[1].Descriptor() +} + +func (AllTheTypesNoAny_ExampleEnum) Type() protoreflect.EnumType { + return &file_messages_types_proto_enumTypes[1] +} + +func (x AllTheTypesNoAny_ExampleEnum) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use AllTheTypesNoAny_ExampleEnum.Descriptor instead. +func (AllTheTypesNoAny_ExampleEnum) EnumDescriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{2, 0} +} + type ExampleMessage struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -372,6 +417,404 @@ func (*AllTheTypes_Oneofstring) isAllTheTypes_Oneof() {} func (*AllTheTypes_Oneofmessage) isAllTheTypes_Oneof() {} +type AllTheTypesNoAny struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Str string `protobuf:"bytes,1,opt,name=str,proto3" json:"str,omitempty"` + Int32 int32 `protobuf:"varint,2,opt,name=int32,proto3" json:"int32,omitempty"` + Int64 int64 `protobuf:"varint,3,opt,name=int64,proto3" json:"int64,omitempty"` + Sint32 int32 `protobuf:"zigzag32,4,opt,name=sint32,proto3" json:"sint32,omitempty"` + Sin64 int64 `protobuf:"zigzag64,5,opt,name=sin64,proto3" json:"sin64,omitempty"` + Uint32 uint32 `protobuf:"varint,6,opt,name=uint32,proto3" json:"uint32,omitempty"` + Uint64 uint64 `protobuf:"varint,7,opt,name=uint64,proto3" json:"uint64,omitempty"` + Fixed32 uint32 `protobuf:"fixed32,8,opt,name=fixed32,proto3" json:"fixed32,omitempty"` + Fixed64 uint64 `protobuf:"fixed64,9,opt,name=fixed64,proto3" json:"fixed64,omitempty"` + Sfixed32 int32 `protobuf:"fixed32,10,opt,name=sfixed32,proto3" json:"sfixed32,omitempty"` + Bool bool `protobuf:"varint,11,opt,name=bool,proto3" json:"bool,omitempty"` + Bytes []byte `protobuf:"bytes,12,opt,name=bytes,proto3" json:"bytes,omitempty"` + Double float64 `protobuf:"fixed64,13,opt,name=double,proto3" json:"double,omitempty"` + Enum AllTheTypesNoAny_ExampleEnum `protobuf:"varint,14,opt,name=enum,proto3,enum=AllTheTypesNoAny_ExampleEnum" json:"enum,omitempty"` + Message *ExampleMessage `protobuf:"bytes,15,opt,name=message,proto3" json:"message,omitempty"` + // Types that are assignable to Oneof: + // + // *AllTheTypesNoAny_Oneofstring + // *AllTheTypesNoAny_Oneofmessage + Oneof isAllTheTypesNoAny_Oneof `protobuf_oneof:"oneof"` + SimpleMap map[int32]string `protobuf:"bytes,19,rep,name=simple_map,json=simpleMap,proto3" json:"simple_map,omitempty" protobuf_key:"varint,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + ComplexMap map[string]*ExampleMessage `protobuf:"bytes,20,rep,name=complex_map,json=complexMap,proto3" json:"complex_map,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + SimpleList []string `protobuf:"bytes,21,rep,name=simple_list,json=simpleList,proto3" json:"simple_list,omitempty"` + ComplexList []*ExampleMessage `protobuf:"bytes,22,rep,name=complex_list,json=complexList,proto3" json:"complex_list,omitempty"` +} + +func (x *AllTheTypesNoAny) Reset() { + *x = AllTheTypesNoAny{} + if protoimpl.UnsafeEnabled { + mi := &file_messages_types_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AllTheTypesNoAny) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AllTheTypesNoAny) ProtoMessage() {} + +func (x *AllTheTypesNoAny) ProtoReflect() protoreflect.Message { + mi := &file_messages_types_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AllTheTypesNoAny.ProtoReflect.Descriptor instead. +func (*AllTheTypesNoAny) Descriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{2} +} + +func (x *AllTheTypesNoAny) GetStr() string { + if x != nil { + return x.Str + } + return "" +} + +func (x *AllTheTypesNoAny) GetInt32() int32 { + if x != nil { + return x.Int32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetInt64() int64 { + if x != nil { + return x.Int64 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetSint32() int32 { + if x != nil { + return x.Sint32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetSin64() int64 { + if x != nil { + return x.Sin64 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetUint32() uint32 { + if x != nil { + return x.Uint32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetUint64() uint64 { + if x != nil { + return x.Uint64 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetFixed32() uint32 { + if x != nil { + return x.Fixed32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetFixed64() uint64 { + if x != nil { + return x.Fixed64 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetSfixed32() int32 { + if x != nil { + return x.Sfixed32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetBool() bool { + if x != nil { + return x.Bool + } + return false +} + +func (x *AllTheTypesNoAny) GetBytes() []byte { + if x != nil { + return x.Bytes + } + return nil +} + +func (x *AllTheTypesNoAny) GetDouble() float64 { + if x != nil { + return x.Double + } + return 0 +} + +func (x *AllTheTypesNoAny) GetEnum() AllTheTypesNoAny_ExampleEnum { + if x != nil { + return x.Enum + } + return AllTheTypesNoAny_OPTION_0 +} + +func (x *AllTheTypesNoAny) GetMessage() *ExampleMessage { + if x != nil { + return x.Message + } + return nil +} + +func (m *AllTheTypesNoAny) GetOneof() isAllTheTypesNoAny_Oneof { + if m != nil { + return m.Oneof + } + return nil +} + +func (x *AllTheTypesNoAny) GetOneofstring() string { + if x, ok := x.GetOneof().(*AllTheTypesNoAny_Oneofstring); ok { + return x.Oneofstring + } + return "" +} + +func (x *AllTheTypesNoAny) GetOneofmessage() *ExampleMessage { + if x, ok := x.GetOneof().(*AllTheTypesNoAny_Oneofmessage); ok { + return x.Oneofmessage + } + return nil +} + +func (x *AllTheTypesNoAny) GetSimpleMap() map[int32]string { + if x != nil { + return x.SimpleMap + } + return nil +} + +func (x *AllTheTypesNoAny) GetComplexMap() map[string]*ExampleMessage { + if x != nil { + return x.ComplexMap + } + return nil +} + +func (x *AllTheTypesNoAny) GetSimpleList() []string { + if x != nil { + return x.SimpleList + } + return nil +} + +func (x *AllTheTypesNoAny) GetComplexList() []*ExampleMessage { + if x != nil { + return x.ComplexList + } + return nil +} + +type isAllTheTypesNoAny_Oneof interface { + isAllTheTypesNoAny_Oneof() +} + +type AllTheTypesNoAny_Oneofstring struct { + Oneofstring string `protobuf:"bytes,16,opt,name=oneofstring,proto3,oneof"` +} + +type AllTheTypesNoAny_Oneofmessage struct { + Oneofmessage *ExampleMessage `protobuf:"bytes,17,opt,name=oneofmessage,proto3,oneof"` +} + +func (*AllTheTypesNoAny_Oneofstring) isAllTheTypesNoAny_Oneof() {} + +func (*AllTheTypesNoAny_Oneofmessage) isAllTheTypesNoAny_Oneof() {} + +type SimpleNested struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SimpleA []*ExampleMessage `protobuf:"bytes,1,rep,name=simple_a,json=simpleA,proto3" json:"simple_a,omitempty"` + SimpleB []*ExampleMessage `protobuf:"bytes,2,rep,name=simple_b,json=simpleB,proto3" json:"simple_b,omitempty"` +} + +func (x *SimpleNested) Reset() { + *x = SimpleNested{} + if protoimpl.UnsafeEnabled { + mi := &file_messages_types_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SimpleNested) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SimpleNested) ProtoMessage() {} + +func (x *SimpleNested) ProtoReflect() protoreflect.Message { + mi := &file_messages_types_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SimpleNested.ProtoReflect.Descriptor instead. +func (*SimpleNested) Descriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{3} +} + +func (x *SimpleNested) GetSimpleA() []*ExampleMessage { + if x != nil { + return x.SimpleA + } + return nil +} + +func (x *SimpleNested) GetSimpleB() []*ExampleMessage { + if x != nil { + return x.SimpleB + } + return nil +} + +type ComplexNested struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AllTheTypesNoAnyA []*AllTheTypesNoAny `protobuf:"bytes,1,rep,name=all_the_types_no_any_a,json=allTheTypesNoAnyA,proto3" json:"all_the_types_no_any_a,omitempty"` + AllTheTypesNoAnyB []*AllTheTypesNoAny `protobuf:"bytes,2,rep,name=all_the_types_no_any_b,json=allTheTypesNoAnyB,proto3" json:"all_the_types_no_any_b,omitempty"` +} + +func (x *ComplexNested) Reset() { + *x = ComplexNested{} + if protoimpl.UnsafeEnabled { + mi := &file_messages_types_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ComplexNested) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ComplexNested) ProtoMessage() {} + +func (x *ComplexNested) ProtoReflect() protoreflect.Message { + mi := &file_messages_types_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ComplexNested.ProtoReflect.Descriptor instead. +func (*ComplexNested) Descriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{4} +} + +func (x *ComplexNested) GetAllTheTypesNoAnyA() []*AllTheTypesNoAny { + if x != nil { + return x.AllTheTypesNoAnyA + } + return nil +} + +func (x *ComplexNested) GetAllTheTypesNoAnyB() []*AllTheTypesNoAny { + if x != nil { + return x.AllTheTypesNoAnyB + } + return nil +} + +type DeepNested struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ComplexNested *ComplexNested `protobuf:"bytes,1,opt,name=complex_nested,json=complexNested,proto3" json:"complex_nested,omitempty"` + SimpleNested *SimpleNested `protobuf:"bytes,2,opt,name=simple_nested,json=simpleNested,proto3" json:"simple_nested,omitempty"` +} + +func (x *DeepNested) Reset() { + *x = DeepNested{} + if protoimpl.UnsafeEnabled { + mi := &file_messages_types_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DeepNested) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeepNested) ProtoMessage() {} + +func (x *DeepNested) ProtoReflect() protoreflect.Message { + mi := &file_messages_types_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeepNested.ProtoReflect.Descriptor instead. +func (*DeepNested) Descriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{5} +} + +func (x *DeepNested) GetComplexNested() *ComplexNested { + if x != nil { + return x.ComplexNested + } + return nil +} + +func (x *DeepNested) GetSimpleNested() *SimpleNested { + if x != nil { + return x.SimpleNested + } + return nil +} + var File_messages_types_proto protoreflect.FileDescriptor var file_messages_types_proto_rawDesc = []byte{ @@ -439,9 +882,90 @@ var file_messages_types_proto_rawDesc = []byte{ 0x02, 0x38, 0x01, 0x22, 0x29, 0x0a, 0x0b, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x45, 0x6e, 0x75, 0x6d, 0x12, 0x0c, 0x0a, 0x08, 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x30, 0x10, 0x00, 0x12, 0x0c, 0x0a, 0x08, 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x31, 0x10, 0x01, 0x42, 0x07, - 0x0a, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x42, 0x11, 0x5a, 0x0f, 0x2e, 0x2e, 0x2f, 0x75, 0x74, - 0x69, 0x6c, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x0a, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x22, 0x95, 0x07, 0x0a, 0x10, 0x41, 0x6c, 0x6c, 0x54, + 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x12, 0x10, 0x0a, 0x03, + 0x73, 0x74, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x74, 0x72, 0x12, 0x14, + 0x0a, 0x05, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x69, + 0x6e, 0x74, 0x33, 0x32, 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x05, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x69, + 0x6e, 0x74, 0x33, 0x32, 0x18, 0x04, 0x20, 0x01, 0x28, 0x11, 0x52, 0x06, 0x73, 0x69, 0x6e, 0x74, + 0x33, 0x32, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x69, 0x6e, 0x36, 0x34, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x12, 0x52, 0x05, 0x73, 0x69, 0x6e, 0x36, 0x34, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x69, 0x6e, 0x74, + 0x33, 0x32, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x75, 0x69, 0x6e, 0x74, 0x33, 0x32, + 0x12, 0x16, 0x0a, 0x06, 0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x18, 0x07, 0x20, 0x01, 0x28, 0x04, + 0x52, 0x06, 0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x12, 0x18, 0x0a, 0x07, 0x66, 0x69, 0x78, 0x65, + 0x64, 0x33, 0x32, 0x18, 0x08, 0x20, 0x01, 0x28, 0x07, 0x52, 0x07, 0x66, 0x69, 0x78, 0x65, 0x64, + 0x33, 0x32, 0x12, 0x18, 0x0a, 0x07, 0x66, 0x69, 0x78, 0x65, 0x64, 0x36, 0x34, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x06, 0x52, 0x07, 0x66, 0x69, 0x78, 0x65, 0x64, 0x36, 0x34, 0x12, 0x1a, 0x0a, 0x08, + 0x73, 0x66, 0x69, 0x78, 0x65, 0x64, 0x33, 0x32, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0f, 0x52, 0x08, + 0x73, 0x66, 0x69, 0x78, 0x65, 0x64, 0x33, 0x32, 0x12, 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x6f, 0x6c, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x62, 0x6f, 0x6f, 0x6c, 0x12, 0x14, 0x0a, 0x05, + 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x62, 0x79, 0x74, + 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x18, 0x0d, 0x20, 0x01, + 0x28, 0x01, 0x52, 0x06, 0x64, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x12, 0x31, 0x0a, 0x04, 0x65, 0x6e, + 0x75, 0x6d, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x41, 0x6c, 0x6c, 0x54, 0x68, + 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x2e, 0x45, 0x78, 0x61, 0x6d, + 0x70, 0x6c, 0x65, 0x45, 0x6e, 0x75, 0x6d, 0x52, 0x04, 0x65, 0x6e, 0x75, 0x6d, 0x12, 0x29, 0x0a, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, + 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x22, 0x0a, 0x0b, 0x6f, 0x6e, 0x65, 0x6f, + 0x66, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, + 0x0b, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x12, 0x35, 0x0a, 0x0c, + 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x11, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x48, 0x00, 0x52, 0x0c, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x6d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x12, 0x3f, 0x0a, 0x0a, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x6d, 0x61, + 0x70, 0x18, 0x13, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x41, 0x6c, 0x6c, 0x54, 0x68, 0x65, + 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, + 0x65, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x09, 0x73, 0x69, 0x6d, 0x70, 0x6c, + 0x65, 0x4d, 0x61, 0x70, 0x12, 0x42, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x5f, + 0x6d, 0x61, 0x70, 0x18, 0x14, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x41, 0x6c, 0x6c, 0x54, + 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x2e, 0x43, 0x6f, 0x6d, + 0x70, 0x6c, 0x65, 0x78, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0a, 0x63, 0x6f, + 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x4d, 0x61, 0x70, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x69, 0x6d, 0x70, + 0x6c, 0x65, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x15, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x73, + 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x32, 0x0a, 0x0c, 0x63, 0x6f, 0x6d, + 0x70, 0x6c, 0x65, 0x78, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x16, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x0f, 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x52, 0x0b, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x4c, 0x69, 0x73, 0x74, 0x1a, 0x3c, 0x0a, + 0x0e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, + 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6b, 0x65, + 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x1a, 0x4e, 0x0a, 0x0f, 0x43, + 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x25, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x0f, 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x29, 0x0a, 0x0b, 0x45, + 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x45, 0x6e, 0x75, 0x6d, 0x12, 0x0c, 0x0a, 0x08, 0x4f, 0x50, + 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x30, 0x10, 0x00, 0x12, 0x0c, 0x0a, 0x08, 0x4f, 0x50, 0x54, 0x49, + 0x4f, 0x4e, 0x5f, 0x31, 0x10, 0x01, 0x42, 0x07, 0x0a, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x22, + 0x66, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x12, + 0x2a, 0x0a, 0x08, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x61, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x0f, 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x52, 0x07, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x41, 0x12, 0x2a, 0x0a, 0x08, 0x73, + 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x62, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0f, 0x2e, + 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, 0x07, + 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x42, 0x22, 0x9b, 0x01, 0x0a, 0x0d, 0x43, 0x6f, 0x6d, 0x70, + 0x6c, 0x65, 0x78, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x12, 0x44, 0x0a, 0x16, 0x61, 0x6c, 0x6c, + 0x5f, 0x74, 0x68, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x5f, 0x6e, 0x6f, 0x5f, 0x61, 0x6e, + 0x79, 0x5f, 0x61, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x41, 0x6c, 0x6c, 0x54, + 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x52, 0x11, 0x61, 0x6c, + 0x6c, 0x54, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x41, 0x12, + 0x44, 0x0a, 0x16, 0x61, 0x6c, 0x6c, 0x5f, 0x74, 0x68, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x73, + 0x5f, 0x6e, 0x6f, 0x5f, 0x61, 0x6e, 0x79, 0x5f, 0x62, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x11, 0x2e, 0x41, 0x6c, 0x6c, 0x54, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, + 0x6e, 0x79, 0x52, 0x11, 0x61, 0x6c, 0x6c, 0x54, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, + 0x6f, 0x41, 0x6e, 0x79, 0x42, 0x22, 0x77, 0x0a, 0x0a, 0x44, 0x65, 0x65, 0x70, 0x4e, 0x65, 0x73, + 0x74, 0x65, 0x64, 0x12, 0x35, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x5f, 0x6e, + 0x65, 0x73, 0x74, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x43, 0x6f, + 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x52, 0x0d, 0x63, 0x6f, 0x6d, + 0x70, 0x6c, 0x65, 0x78, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x12, 0x32, 0x0a, 0x0d, 0x73, 0x69, + 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x6e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x0d, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, + 0x52, 0x0c, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x42, 0x11, + 0x5a, 0x0f, 0x2e, 0x2e, 0x2f, 0x75, 0x74, 0x69, 0x6c, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -456,30 +980,50 @@ func file_messages_types_proto_rawDescGZIP() []byte { return file_messages_types_proto_rawDescData } -var file_messages_types_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_messages_types_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_messages_types_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_messages_types_proto_msgTypes = make([]protoimpl.MessageInfo, 10) var file_messages_types_proto_goTypes = []interface{}{ - (AllTheTypes_ExampleEnum)(0), // 0: AllTheTypes.ExampleEnum - (*ExampleMessage)(nil), // 1: ExampleMessage - (*AllTheTypes)(nil), // 2: AllTheTypes - nil, // 3: AllTheTypes.SimpleMapEntry - nil, // 4: AllTheTypes.ComplexMapEntry - (*anypb.Any)(nil), // 5: google.protobuf.Any + (AllTheTypes_ExampleEnum)(0), // 0: AllTheTypes.ExampleEnum + (AllTheTypesNoAny_ExampleEnum)(0), // 1: AllTheTypesNoAny.ExampleEnum + (*ExampleMessage)(nil), // 2: ExampleMessage + (*AllTheTypes)(nil), // 3: AllTheTypes + (*AllTheTypesNoAny)(nil), // 4: AllTheTypesNoAny + (*SimpleNested)(nil), // 5: SimpleNested + (*ComplexNested)(nil), // 6: ComplexNested + (*DeepNested)(nil), // 7: DeepNested + nil, // 8: AllTheTypes.SimpleMapEntry + nil, // 9: AllTheTypes.ComplexMapEntry + nil, // 10: AllTheTypesNoAny.SimpleMapEntry + nil, // 11: AllTheTypesNoAny.ComplexMapEntry + (*anypb.Any)(nil), // 12: google.protobuf.Any } var file_messages_types_proto_depIdxs = []int32{ - 0, // 0: AllTheTypes.enum:type_name -> AllTheTypes.ExampleEnum - 1, // 1: AllTheTypes.message:type_name -> ExampleMessage - 1, // 2: AllTheTypes.oneofmessage:type_name -> ExampleMessage - 5, // 3: AllTheTypes.any:type_name -> google.protobuf.Any - 3, // 4: AllTheTypes.simple_map:type_name -> AllTheTypes.SimpleMapEntry - 4, // 5: AllTheTypes.complex_map:type_name -> AllTheTypes.ComplexMapEntry - 1, // 6: AllTheTypes.complex_list:type_name -> ExampleMessage - 1, // 7: AllTheTypes.ComplexMapEntry.value:type_name -> ExampleMessage - 8, // [8:8] is the sub-list for method output_type - 8, // [8:8] is the sub-list for method input_type - 8, // [8:8] is the sub-list for extension type_name - 8, // [8:8] is the sub-list for extension extendee - 0, // [0:8] is the sub-list for field type_name + 0, // 0: AllTheTypes.enum:type_name -> AllTheTypes.ExampleEnum + 2, // 1: AllTheTypes.message:type_name -> ExampleMessage + 2, // 2: AllTheTypes.oneofmessage:type_name -> ExampleMessage + 12, // 3: AllTheTypes.any:type_name -> google.protobuf.Any + 8, // 4: AllTheTypes.simple_map:type_name -> AllTheTypes.SimpleMapEntry + 9, // 5: AllTheTypes.complex_map:type_name -> AllTheTypes.ComplexMapEntry + 2, // 6: AllTheTypes.complex_list:type_name -> ExampleMessage + 1, // 7: AllTheTypesNoAny.enum:type_name -> AllTheTypesNoAny.ExampleEnum + 2, // 8: AllTheTypesNoAny.message:type_name -> ExampleMessage + 2, // 9: AllTheTypesNoAny.oneofmessage:type_name -> ExampleMessage + 10, // 10: AllTheTypesNoAny.simple_map:type_name -> AllTheTypesNoAny.SimpleMapEntry + 11, // 11: AllTheTypesNoAny.complex_map:type_name -> AllTheTypesNoAny.ComplexMapEntry + 2, // 12: AllTheTypesNoAny.complex_list:type_name -> ExampleMessage + 2, // 13: SimpleNested.simple_a:type_name -> ExampleMessage + 2, // 14: SimpleNested.simple_b:type_name -> ExampleMessage + 4, // 15: ComplexNested.all_the_types_no_any_a:type_name -> AllTheTypesNoAny + 4, // 16: ComplexNested.all_the_types_no_any_b:type_name -> AllTheTypesNoAny + 6, // 17: DeepNested.complex_nested:type_name -> ComplexNested + 5, // 18: DeepNested.simple_nested:type_name -> SimpleNested + 2, // 19: AllTheTypes.ComplexMapEntry.value:type_name -> ExampleMessage + 2, // 20: AllTheTypesNoAny.ComplexMapEntry.value:type_name -> ExampleMessage + 21, // [21:21] is the sub-list for method output_type + 21, // [21:21] is the sub-list for method input_type + 21, // [21:21] is the sub-list for extension type_name + 21, // [21:21] is the sub-list for extension extendee + 0, // [0:21] is the sub-list for field type_name } func init() { file_messages_types_proto_init() } @@ -512,18 +1056,70 @@ func file_messages_types_proto_init() { return nil } } + file_messages_types_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AllTheTypesNoAny); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_messages_types_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SimpleNested); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_messages_types_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ComplexNested); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_messages_types_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DeepNested); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_messages_types_proto_msgTypes[1].OneofWrappers = []interface{}{ (*AllTheTypes_Oneofstring)(nil), (*AllTheTypes_Oneofmessage)(nil), } + file_messages_types_proto_msgTypes[2].OneofWrappers = []interface{}{ + (*AllTheTypesNoAny_Oneofstring)(nil), + (*AllTheTypesNoAny_Oneofmessage)(nil), + } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_messages_types_proto_rawDesc, - NumEnums: 1, - NumMessages: 4, + NumEnums: 2, + NumMessages: 10, NumExtensions: 0, NumServices: 0, }, From a380d695a6672f6981d1fe36cd1acc8d68ee9c3e Mon Sep 17 00:00:00 2001 From: mwish Date: Tue, 20 Aug 2024 01:27:45 +0800 Subject: [PATCH 10/10] GH-43733: [C++] Fix Scalar boolean handling in row encoder (#43734) ### Rationale for this change See https://github.com/apache/arrow/issues/43733 ### What changes are included in this PR? Separate Null and Valid handling when BooleanKeyEncoder::Encode meets a Null This patch also does a migration: * row_encoder.cc -> row_encoder_internal.cc * move row_encoder_internal{.cc|.h} from `compute/kernel` to `compute/row` ### Are these changes tested? Yes ### Are there any user-facing changes? No * GitHub Issue: #43733 Authored-by: mwish Signed-off-by: Antoine Pitrou --- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/acero/asof_join_node_test.cc | 2 +- cpp/src/arrow/acero/hash_join.cc | 2 +- cpp/src/arrow/acero/hash_join_benchmark.cc | 2 +- cpp/src/arrow/acero/hash_join_dict.h | 2 +- cpp/src/arrow/acero/hash_join_node_test.cc | 2 +- cpp/src/arrow/acero/swiss_join.cc | 2 +- cpp/src/arrow/acero/swiss_join_internal.h | 2 +- cpp/src/arrow/acero/tpch_node_test.cc | 2 +- cpp/src/arrow/compute/CMakeLists.txt | 1 + .../arrow/compute/kernels/hash_aggregate.cc | 2 +- cpp/src/arrow/compute/row/grouper.cc | 2 +- .../row_encoder_internal.cc} | 41 ++++++----- .../{kernels => row}/row_encoder_internal.h | 14 ++-- .../compute/row/row_encoder_internal_test.cc | 68 +++++++++++++++++++ cpp/src/arrow/compute/row/row_test.cc | 2 +- 16 files changed, 111 insertions(+), 37 deletions(-) rename cpp/src/arrow/compute/{kernels/row_encoder.cc => row/row_encoder_internal.cc} (93%) rename cpp/src/arrow/compute/{kernels => row}/row_encoder_internal.h (96%) create mode 100644 cpp/src/arrow/compute/row/row_encoder_internal_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 67d2c19f98a2d..fb785e1e9571b 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -723,7 +723,6 @@ set(ARROW_COMPUTE_SRCS compute/ordering.cc compute/registry.cc compute/kernels/codegen_internal.cc - compute/kernels/row_encoder.cc compute/kernels/ree_util_internal.cc compute/kernels/scalar_cast_boolean.cc compute/kernels/scalar_cast_dictionary.cc @@ -742,6 +741,7 @@ set(ARROW_COMPUTE_SRCS compute/row/encode_internal.cc compute/row/compare_internal.cc compute/row/grouper.cc + compute/row/row_encoder_internal.cc compute/row/row_internal.cc compute/util.cc compute/util_internal.cc) diff --git a/cpp/src/arrow/acero/asof_join_node_test.cc b/cpp/src/arrow/acero/asof_join_node_test.cc index 051e280a4c53c..555f580028fac 100644 --- a/cpp/src/arrow/acero/asof_join_node_test.cc +++ b/cpp/src/arrow/acero/asof_join_node_test.cc @@ -41,8 +41,8 @@ #include "arrow/acero/util.h" #include "arrow/api.h" #include "arrow/compute/api_scalar.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/acero/hash_join.cc b/cpp/src/arrow/acero/hash_join.cc index 5aa70a23f7c9e..ddcd2a0995701 100644 --- a/cpp/src/arrow/acero/hash_join.cc +++ b/cpp/src/arrow/acero/hash_join.cc @@ -27,8 +27,8 @@ #include "arrow/acero/hash_join_dict.h" #include "arrow/acero/task_util.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/row/encode_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/util/tracing_internal.h" namespace arrow { diff --git a/cpp/src/arrow/acero/hash_join_benchmark.cc b/cpp/src/arrow/acero/hash_join_benchmark.cc index 1f8e02e9f0fcf..470960b1c5062 100644 --- a/cpp/src/arrow/acero/hash_join_benchmark.cc +++ b/cpp/src/arrow/acero/hash_join_benchmark.cc @@ -23,7 +23,7 @@ #include "arrow/acero/test_util_internal.h" #include "arrow/acero/util.h" #include "arrow/api.h" -#include "arrow/compute/kernels/row_encoder_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/testing/random.h" #include "arrow/util/thread_pool.h" diff --git a/cpp/src/arrow/acero/hash_join_dict.h b/cpp/src/arrow/acero/hash_join_dict.h index c7d8d785d079e..02454a7146278 100644 --- a/cpp/src/arrow/acero/hash_join_dict.h +++ b/cpp/src/arrow/acero/hash_join_dict.h @@ -22,7 +22,7 @@ #include "arrow/acero/schema_util.h" #include "arrow/compute/exec.h" -#include "arrow/compute/kernels/row_encoder_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type.h" diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index 88f9a9e71b768..9065e286a2228 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -26,9 +26,9 @@ #include "arrow/acero/test_util_internal.h" #include "arrow/acero/util.h" #include "arrow/api.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/compute/light_array_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/testing/extension_type.h" #include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index 40a4b5886e4bb..4d0c8187ac6e2 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -24,10 +24,10 @@ #include "arrow/acero/swiss_join_internal.h" #include "arrow/acero/util.h" #include "arrow/array/util.h" // MakeArrayFromScalar -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/key_hash_internal.h" #include "arrow/compute/row/compare_internal.h" #include "arrow/compute/row/encode_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/tracing_internal.h" diff --git a/cpp/src/arrow/acero/swiss_join_internal.h b/cpp/src/arrow/acero/swiss_join_internal.h index dceb74abe4f1b..4d749c1c529ae 100644 --- a/cpp/src/arrow/acero/swiss_join_internal.h +++ b/cpp/src/arrow/acero/swiss_join_internal.h @@ -22,10 +22,10 @@ #include "arrow/acero/partition_util.h" #include "arrow/acero/schema_util.h" #include "arrow/acero/task_util.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/key_map_internal.h" #include "arrow/compute/light_array_internal.h" #include "arrow/compute/row/encode_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" namespace arrow { diff --git a/cpp/src/arrow/acero/tpch_node_test.cc b/cpp/src/arrow/acero/tpch_node_test.cc index 076bcf634a6ba..17fb43452bc58 100644 --- a/cpp/src/arrow/acero/tpch_node_test.cc +++ b/cpp/src/arrow/acero/tpch_node_test.cc @@ -27,8 +27,8 @@ #include "arrow/acero/test_util_internal.h" #include "arrow/acero/tpch_node.h" #include "arrow/acero/util.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/CMakeLists.txt b/cpp/src/arrow/compute/CMakeLists.txt index e20b45897db95..aa2a2d4e9af0b 100644 --- a/cpp/src/arrow/compute/CMakeLists.txt +++ b/cpp/src/arrow/compute/CMakeLists.txt @@ -92,6 +92,7 @@ add_arrow_test(internals_test key_hash_test.cc row/compare_test.cc row/grouper_test.cc + row/row_encoder_internal_test.cc row/row_test.cc util_internal_test.cc) diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 54cd695421a93..4bf6a6106dfe5 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -33,9 +33,9 @@ #include "arrow/compute/kernels/aggregate_internal.h" #include "arrow/compute/kernels/aggregate_var_std_internal.h" #include "arrow/compute/kernels/common_internal.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/compute/row/grouper.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/record_batch.h" #include "arrow/stl_allocator.h" #include "arrow/type_traits.h" diff --git a/cpp/src/arrow/compute/row/grouper.cc b/cpp/src/arrow/compute/row/grouper.cc index 45b9ad5971e80..5889f94d96c79 100644 --- a/cpp/src/arrow/compute/row/grouper.cc +++ b/cpp/src/arrow/compute/row/grouper.cc @@ -25,12 +25,12 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/function.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/key_hash_internal.h" #include "arrow/compute/light_array_internal.h" #include "arrow/compute/registry.h" #include "arrow/compute/row/compare_internal.h" #include "arrow/compute/row/grouper_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/bitmap_ops.h" diff --git a/cpp/src/arrow/compute/kernels/row_encoder.cc b/cpp/src/arrow/compute/row/row_encoder_internal.cc similarity index 93% rename from cpp/src/arrow/compute/kernels/row_encoder.cc rename to cpp/src/arrow/compute/row/row_encoder_internal.cc index 8224eaa6d6315..414cc6793a5a3 100644 --- a/cpp/src/arrow/compute/kernels/row_encoder.cc +++ b/cpp/src/arrow/compute/row/row_encoder_internal.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/kernels/row_encoder_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/util/bitmap_writer.h" #include "arrow/util/logging.h" @@ -75,26 +75,31 @@ void BooleanKeyEncoder::AddLengthNull(int32_t* length) { Status BooleanKeyEncoder::Encode(const ExecValue& data, int64_t batch_length, uint8_t** encoded_bytes) { + auto handle_next_valid_value = [&encoded_bytes](bool value) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kValidByte; + *encoded_ptr++ = value; + }; + auto handle_next_null_value = [&encoded_bytes]() { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kNullByte; + *encoded_ptr++ = 0; + }; + if (data.is_array()) { - VisitArraySpanInline( - data.array, - [&](bool value) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - *encoded_ptr++ = value; - }, - [&] { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kNullByte; - *encoded_ptr++ = 0; - }); + VisitArraySpanInline(data.array, handle_next_valid_value, + handle_next_null_value); } else { const auto& scalar = data.scalar_as(); - bool value = scalar.is_valid && scalar.value; - for (int64_t i = 0; i < batch_length; i++) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - *encoded_ptr++ = value; + if (!scalar.is_valid) { + for (int64_t i = 0; i < batch_length; i++) { + handle_next_null_value(); + } + } else { + const bool value = scalar.value; + for (int64_t i = 0; i < batch_length; i++) { + handle_next_valid_value(value); + } } } return Status::OK(); diff --git a/cpp/src/arrow/compute/kernels/row_encoder_internal.h b/cpp/src/arrow/compute/row/row_encoder_internal.h similarity index 96% rename from cpp/src/arrow/compute/kernels/row_encoder_internal.h rename to cpp/src/arrow/compute/row/row_encoder_internal.h index 9bf7c1d1c4fed..60eb14af504f7 100644 --- a/cpp/src/arrow/compute/kernels/row_encoder_internal.h +++ b/cpp/src/arrow/compute/row/row_encoder_internal.h @@ -29,7 +29,7 @@ using internal::checked_cast; namespace compute { namespace internal { -struct KeyEncoder { +struct ARROW_EXPORT KeyEncoder { // the first byte of an encoded key is used to indicate nullity static constexpr bool kExtraByteForNull = true; @@ -60,7 +60,7 @@ struct KeyEncoder { } }; -struct BooleanKeyEncoder : KeyEncoder { +struct ARROW_EXPORT BooleanKeyEncoder : KeyEncoder { static constexpr int kByteWidth = 1; void AddLength(const ExecValue& data, int64_t batch_length, int32_t* lengths) override; @@ -76,7 +76,7 @@ struct BooleanKeyEncoder : KeyEncoder { MemoryPool* pool) override; }; -struct FixedWidthKeyEncoder : KeyEncoder { +struct ARROW_EXPORT FixedWidthKeyEncoder : KeyEncoder { explicit FixedWidthKeyEncoder(std::shared_ptr type) : type_(std::move(type)), byte_width_(checked_cast(*type_).bit_width() / 8) {} @@ -97,7 +97,7 @@ struct FixedWidthKeyEncoder : KeyEncoder { int byte_width_; }; -struct DictionaryKeyEncoder : FixedWidthKeyEncoder { +struct ARROW_EXPORT DictionaryKeyEncoder : FixedWidthKeyEncoder { DictionaryKeyEncoder(std::shared_ptr type, MemoryPool* pool) : FixedWidthKeyEncoder(std::move(type)), pool_(pool) {} @@ -112,7 +112,7 @@ struct DictionaryKeyEncoder : FixedWidthKeyEncoder { }; template -struct VarLengthKeyEncoder : KeyEncoder { +struct ARROW_EXPORT VarLengthKeyEncoder : KeyEncoder { using Offset = typename T::offset_type; void AddLength(const ExecValue& data, int64_t batch_length, int32_t* lengths) override { @@ -232,7 +232,7 @@ struct VarLengthKeyEncoder : KeyEncoder { std::shared_ptr type_; }; -struct NullKeyEncoder : KeyEncoder { +struct ARROW_EXPORT NullKeyEncoder : KeyEncoder { void AddLength(const ExecValue&, int64_t batch_length, int32_t* lengths) override {} void AddLengthNull(int32_t* length) override {} @@ -274,7 +274,7 @@ class ARROW_EXPORT RowEncoder { } private: - ExecContext* ctx_; + ExecContext* ctx_{nullptr}; std::vector> encoders_; std::vector offsets_; std::vector bytes_; diff --git a/cpp/src/arrow/compute/row/row_encoder_internal_test.cc b/cpp/src/arrow/compute/row/row_encoder_internal_test.cc new file mode 100644 index 0000000000000..78839d1ead557 --- /dev/null +++ b/cpp/src/arrow/compute/row/row_encoder_internal_test.cc @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/compute/row/row_encoder_internal.h" + +#include "arrow/array/validate.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" + +namespace arrow::compute::internal { + +// GH-43733: Test that the key encoder can handle boolean scalar values well. +TEST(TestKeyEncoder, BooleanScalar) { + for (auto scalar : {BooleanScalar{}, BooleanScalar{true}, BooleanScalar{false}}) { + BooleanKeyEncoder key_encoder; + SCOPED_TRACE("scalar " + scalar.ToString()); + constexpr int64_t kBatchLength = 10; + std::array lengths{}; + key_encoder.AddLength(ExecValue{&scalar}, kBatchLength, lengths.data()); + // Check that the lengths are all 2. + constexpr int32_t kPayloadWidth = + BooleanKeyEncoder::kByteWidth + BooleanKeyEncoder::kExtraByteForNull; + for (int i = 0; i < kBatchLength; ++i) { + ASSERT_EQ(kPayloadWidth, lengths[i]); + } + std::array, kBatchLength> payloads{}; + std::array payload_ptrs{}; + // Reset the payload pointers to point to the beginning of each payload. + // This is necessary because the key encoder may have modified the pointers. + auto reset_payload_ptrs = [&payload_ptrs, &payloads]() { + std::transform(payloads.begin(), payloads.end(), payload_ptrs.begin(), + [](auto& payload) -> uint8_t* { return payload.data(); }); + }; + reset_payload_ptrs(); + ASSERT_OK(key_encoder.Encode(ExecValue{&scalar}, kBatchLength, payload_ptrs.data())); + reset_payload_ptrs(); + ASSERT_OK_AND_ASSIGN(auto array_data, + key_encoder.Decode(payload_ptrs.data(), kBatchLength, + ::arrow::default_memory_pool())); + ASSERT_EQ(kBatchLength, array_data->length); + auto boolean_array = std::make_shared(array_data); + ASSERT_OK(arrow::internal::ValidateArrayFull(*array_data)); + ASSERT_OK_AND_ASSIGN( + auto expected_array, + MakeArrayFromScalar(scalar, kBatchLength, ::arrow::default_memory_pool())); + AssertArraysEqual(*expected_array, *boolean_array); + } +} + +} // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/row/row_test.cc b/cpp/src/arrow/compute/row/row_test.cc index 6aed9e4327812..5057ce91b5bea 100644 --- a/cpp/src/arrow/compute/row/row_test.cc +++ b/cpp/src/arrow/compute/row/row_test.cc @@ -155,7 +155,7 @@ TEST(RowTableLarge, LARGE_MEMORY_TEST(Encode)) { auto value, ::arrow::gen::Constant( std::make_shared(std::string(length_per_binary, 'X'))) ->Generate(1)); - values.push_back(std::move(value)); + values.emplace_back(std::move(value)); ExecBatch batch = ExecBatch(std::move(values), 1); ASSERT_OK(ColumnArraysFromExecBatch(batch, &columns));