diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index b8e765f08587a..ad7344b09dd4e 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1812,7 +1812,9 @@ if(ARROW_WITH_PROTOBUF) else() set(ARROW_PROTOBUF_REQUIRED_VERSION "2.6.1") endif() - if(ARROW_ORC OR ARROW_WITH_OPENTELEMETRY) + if(ARROW_ORC + OR ARROW_SUBSTRAIT + OR ARROW_WITH_OPENTELEMETRY) set(ARROW_PROTOBUF_ARROW_CMAKE_PACKAGE_NAME "Arrow") set(ARROW_PROTOBUF_ARROW_PC_PACKAGE_NAME "arrow") elseif(ARROW_FLIGHT) diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index 2ba64ee22f54f..640888e1c4fa5 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -99,12 +99,21 @@ #define ARROW_S3_HAS_CRT #endif +#if ARROW_AWS_SDK_VERSION_CHECK(1, 10, 0) +#define ARROW_S3_HAS_S3CLIENT_CONFIGURATION +#endif + #ifdef ARROW_S3_HAS_CRT #include #include #include #endif +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION +#include +#include +#endif + #include "arrow/util/windows_fixup.h" #include "arrow/buffer.h" @@ -128,19 +137,17 @@ #include "arrow/util/task_group.h" #include "arrow/util/thread_pool.h" -namespace arrow { - -using internal::TaskGroup; -using internal::ToChars; -using io::internal::SubmitIO; -using util::Uri; - -namespace fs { +namespace arrow::fs { using ::Aws::Client::AWSError; using ::Aws::S3::S3Errors; namespace S3Model = Aws::S3::Model; +using ::arrow::internal::TaskGroup; +using ::arrow::internal::ToChars; +using ::arrow::io::internal::SubmitIO; +using ::arrow::util::Uri; + using internal::ConnectRetryStrategy; using internal::DetectS3Backend; using internal::ErrorToStatus; @@ -913,6 +920,134 @@ Result> GetClientHolder( // ----------------------------------------------------------------------- // S3 client factory: build S3Client from S3Options +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION + +// GH-40279: standard initialization of S3Client creates a new `S3EndpointProvider` +// every time. Its construction takes 1ms, which makes instantiating every S3Client +// very costly (see upstream bug report +// at https://github.com/aws/aws-sdk-cpp/issues/2880). +// To work around this, we build and cache `S3EndpointProvider` instances +// for each distinct endpoint configuration, and reuse them whenever possible. +// Since most applications tend to use a single endpoint configuration, this +// makes the 1ms setup cost a once-per-process overhead, making it much more +// bearable - if not ideal. + +struct EndpointConfigKey { + explicit EndpointConfigKey(const Aws::S3::S3ClientConfiguration& config) + : region(config.region), + scheme(config.scheme), + endpoint_override(config.endpointOverride), + use_virtual_addressing(config.useVirtualAddressing) {} + + Aws::String region; + Aws::Http::Scheme scheme; + Aws::String endpoint_override; + bool use_virtual_addressing; + + bool operator==(const EndpointConfigKey& other) const noexcept { + return region == other.region && scheme == other.scheme && + endpoint_override == other.endpoint_override && + use_virtual_addressing == other.use_virtual_addressing; + } +}; + +} // namespace +} // namespace arrow::fs + +template <> +struct std::hash { + std::size_t operator()(const arrow::fs::EndpointConfigKey& key) const noexcept { + // A crude hash is sufficient since we expect the cache to remain very small. 
+ auto h = std::hash{}; + return h(key.region) ^ h(key.endpoint_override); + } +}; + +namespace arrow::fs { +namespace { + +// EndpointProvider configuration happens in a non-thread-safe way, even +// when the updates are idempotent. This is a problem when trying to reuse +// a single EndpointProvider from several clients. +// To work around this, this class ensures reconfiguration of an existing +// EndpointProvider is a no-op. +class InitOnceEndpointProvider : public Aws::S3::S3EndpointProviderBase { + public: + explicit InitOnceEndpointProvider( + std::shared_ptr wrapped) + : wrapped_(std::move(wrapped)) {} + + void InitBuiltInParameters(const Aws::S3::S3ClientConfiguration& config) override {} + + void OverrideEndpoint(const Aws::String& endpoint) override { + ARROW_LOG(ERROR) << "unexpected call to InitOnceEndpointProvider::OverrideEndpoint"; + } + Aws::S3::Endpoint::S3ClientContextParameters& AccessClientContextParameters() override { + ARROW_LOG(ERROR) + << "unexpected call to InitOnceEndpointProvider::AccessClientContextParameters"; + // Need to return a reference to something... + return wrapped_->AccessClientContextParameters(); + } + + const Aws::S3::Endpoint::S3ClientContextParameters& GetClientContextParameters() + const override { + return wrapped_->GetClientContextParameters(); + } + Aws::Endpoint::ResolveEndpointOutcome ResolveEndpoint( + const Aws::Endpoint::EndpointParameters& params) const override { + return wrapped_->ResolveEndpoint(params); + } + + protected: + std::shared_ptr wrapped_; +}; + +// A class that instantiates a single EndpointProvider per distinct endpoint +// configuration and initializes it in a thread-safe way. See earlier comments +// for rationale. +class EndpointProviderCache { + public: + std::shared_ptr Lookup( + const Aws::S3::S3ClientConfiguration& config) { + auto key = EndpointConfigKey(config); + CacheValue* value; + { + std::unique_lock lock(mutex_); + value = &cache_[std::move(key)]; + } + std::call_once(value->once, [&]() { + auto endpoint_provider = std::make_shared(); + endpoint_provider->InitBuiltInParameters(config); + value->endpoint_provider = + std::make_shared(std::move(endpoint_provider)); + }); + return value->endpoint_provider; + } + + void Reset() { + std::unique_lock lock(mutex_); + cache_.clear(); + } + + static EndpointProviderCache* Instance() { + static EndpointProviderCache instance; + return &instance; + } + + private: + EndpointProviderCache() = default; + + struct CacheValue { + std::once_flag once; + std::shared_ptr endpoint_provider; + }; + + std::mutex mutex_; + std::unordered_map cache_; +}; + +#endif // ARROW_S3_HAS_S3CLIENT_CONFIGURATION + class ClientBuilder { public: explicit ClientBuilder(S3Options options) : options_(std::move(options)) {} @@ -958,9 +1093,6 @@ class ClientBuilder { client_config_.caPath = ToAwsString(internal::global_options.tls_ca_dir_path); } - const bool use_virtual_addressing = - options_.endpoint_override.empty() || options_.force_virtual_addressing; - // Set proxy options if provided if (!options_.proxy_options.scheme.empty()) { if (options_.proxy_options.scheme == "http") { @@ -990,10 +1122,20 @@ class ClientBuilder { client_config_.maxConnections = std::max(io_context->executor()->GetCapacity(), 25); } + const bool use_virtual_addressing = + options_.endpoint_override.empty() || options_.force_virtual_addressing; + +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION + client_config_.useVirtualAddressing = use_virtual_addressing; + auto endpoint_provider = 
EndpointProviderCache::Instance()->Lookup(client_config_); + auto client = std::make_shared(credentials_provider_, endpoint_provider, + client_config_); +#else auto client = std::make_shared( credentials_provider_, client_config_, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, use_virtual_addressing); +#endif client->s3_retry_strategy_ = options_.retry_strategy; return GetClientHolder(std::move(client)); } @@ -1002,7 +1144,11 @@ class ClientBuilder { protected: S3Options options_; +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION + Aws::S3::S3ClientConfiguration client_config_; +#else Aws::Client::ClientConfiguration client_config_; +#endif std::shared_ptr credentials_provider_; }; @@ -2949,6 +3095,9 @@ struct AwsInstance { "This could lead to a segmentation fault at exit"; } GetClientFinalizer()->Finalize(); +#ifdef ARROW_S3_HAS_S3CLIENT_CONFIGURATION + EndpointProviderCache::Instance()->Reset(); +#endif Aws::ShutdownAPI(aws_options_); } } @@ -3090,5 +3239,4 @@ Result ResolveS3BucketRegion(const std::string& bucket) { return resolver->ResolveRegion(bucket); } -} // namespace fs -} // namespace arrow +} // namespace arrow::fs diff --git a/cpp/src/arrow/memory_pool.cc b/cpp/src/arrow/memory_pool.cc index d58c203d2ae27..2f8ce3a6fa8c7 100644 --- a/cpp/src/arrow/memory_pool.cc +++ b/cpp/src/arrow/memory_pool.cc @@ -472,7 +472,7 @@ class BaseMemoryPoolImpl : public MemoryPool { } #endif - stats_.UpdateAllocatedBytes(size); + stats_.DidAllocateBytes(size); return Status::OK(); } @@ -494,7 +494,7 @@ class BaseMemoryPoolImpl : public MemoryPool { } #endif - stats_.UpdateAllocatedBytes(new_size - old_size); + stats_.DidReallocateBytes(old_size, new_size); return Status::OK(); } @@ -509,7 +509,7 @@ class BaseMemoryPoolImpl : public MemoryPool { #endif Allocator::DeallocateAligned(buffer, size, alignment); - stats_.UpdateAllocatedBytes(-size, /*is_free*/ true); + stats_.DidFreeBytes(size); } void ReleaseUnused() override { Allocator::ReleaseUnused(); } @@ -761,20 +761,20 @@ class ProxyMemoryPool::ProxyMemoryPoolImpl { Status Allocate(int64_t size, int64_t alignment, uint8_t** out) { RETURN_NOT_OK(pool_->Allocate(size, alignment, out)); - stats_.UpdateAllocatedBytes(size); + stats_.DidAllocateBytes(size); return Status::OK(); } Status Reallocate(int64_t old_size, int64_t new_size, int64_t alignment, uint8_t** ptr) { RETURN_NOT_OK(pool_->Reallocate(old_size, new_size, alignment, ptr)); - stats_.UpdateAllocatedBytes(new_size - old_size); + stats_.DidReallocateBytes(old_size, new_size); return Status::OK(); } void Free(uint8_t* buffer, int64_t size, int64_t alignment) { pool_->Free(buffer, size, alignment); - stats_.UpdateAllocatedBytes(-size, /*is_free=*/true); + stats_.DidFreeBytes(size); } int64_t bytes_allocated() const { return stats_.bytes_allocated(); } diff --git a/cpp/src/arrow/memory_pool.h b/cpp/src/arrow/memory_pool.h index 712a828041c76..98c6dc3e211b8 100644 --- a/cpp/src/arrow/memory_pool.h +++ b/cpp/src/arrow/memory_pool.h @@ -35,44 +35,68 @@ namespace internal { /////////////////////////////////////////////////////////////////////// // Helper tracking memory statistics -class MemoryPoolStats { - public: - MemoryPoolStats() : bytes_allocated_(0), max_memory_(0) {} - - int64_t max_memory() const { return max_memory_.load(); } - - int64_t bytes_allocated() const { return bytes_allocated_.load(); } +/// \brief Memory pool statistics +/// +/// 64-byte aligned so that all atomic values are on the same cache line. 
+class alignas(64) MemoryPoolStats { + private: + // All atomics are updated according to Acquire-Release ordering. + // https://en.cppreference.com/w/cpp/atomic/memory_order#Release-Acquire_ordering + // + // max_memory_, total_allocated_bytes_, and num_allocs_ only go up (they are + // monotonically increasing), which allows some optimizations. + std::atomic<int64_t> max_memory_{0}; + std::atomic<int64_t> bytes_allocated_{0}; + std::atomic<int64_t> total_allocated_bytes_{0}; + std::atomic<int64_t> num_allocs_{0}; - int64_t total_bytes_allocated() const { return total_allocated_bytes_.load(); } + public: + int64_t max_memory() const { return max_memory_.load(std::memory_order_acquire); } - int64_t num_allocations() const { return num_allocs_.load(); } + int64_t bytes_allocated() const { + return bytes_allocated_.load(std::memory_order_acquire); + } - inline void UpdateAllocatedBytes(int64_t diff, bool is_free = false) { - auto allocated = bytes_allocated_.fetch_add(diff) + diff; - // "maximum" allocated memory is ill-defined in multi-threaded code, - // so don't try to be too rigorous here - if (diff > 0 && allocated > max_memory_) { - max_memory_ = allocated; - } + int64_t total_bytes_allocated() const { + return total_allocated_bytes_.load(std::memory_order_acquire); + } - // Reallocations might just expand/contract the allocation in place or might - // copy to a new location. We can't really know, so we just represent the - // optimistic case. - if (diff > 0) { - total_allocated_bytes_ += diff; + int64_t num_allocations() const { return num_allocs_.load(std::memory_order_acquire); } + + inline void DidAllocateBytes(int64_t size) { + // Issue the load before everything else. max_memory_ is monotonically increasing, + // so we can use a relaxed load before the read-modify-write. + auto max_memory = max_memory_.load(std::memory_order_relaxed); + const auto old_bytes_allocated = + bytes_allocated_.fetch_add(size, std::memory_order_acq_rel); + // Issue store operations on values that we don't depend on to proceed + // with execution. When done, max_memory and old_bytes_allocated have + // a higher chance of being available on CPU registers. This also has the + // nice side-effect of putting 3 atomic stores close to each other in the + // instruction stream. + total_allocated_bytes_.fetch_add(size, std::memory_order_acq_rel); + num_allocs_.fetch_add(1, std::memory_order_acq_rel); + + // If other threads are updating max_memory_ concurrently, we leave the loop + // without updating, knowing that it already reached a value even higher than ours. + const auto allocated = old_bytes_allocated + size; + while (max_memory < allocated && !max_memory_.compare_exchange_weak( + /*expected=*/max_memory, /*desired=*/allocated, + std::memory_order_acq_rel)) { } + } - // We count any reallocation as a allocation.
- if (!is_free) { - num_allocs_ += 1; + inline void DidReallocateBytes(int64_t old_size, int64_t new_size) { + if (new_size > old_size) { + DidAllocateBytes(new_size - old_size); + } else { + DidFreeBytes(old_size - new_size); } } - protected: - std::atomic bytes_allocated_ = 0; - std::atomic max_memory_ = 0; - std::atomic total_allocated_bytes_ = 0; - std::atomic num_allocs_ = 0; + inline void DidFreeBytes(int64_t size) { + bytes_allocated_.fetch_sub(size, std::memory_order_acq_rel); + } }; } // namespace internal diff --git a/cpp/src/arrow/memory_pool_benchmark.cc b/cpp/src/arrow/memory_pool_benchmark.cc index fe7a3dd2f8ee0..c2e55314b56f9 100644 --- a/cpp/src/arrow/memory_pool_benchmark.cc +++ b/cpp/src/arrow/memory_pool_benchmark.cc @@ -114,8 +114,12 @@ static void AllocateTouchDeallocate( state.SetBytesProcessed(state.iterations() * nbytes); } -#define BENCHMARK_ALLOCATE_ARGS \ - ->RangeMultiplier(16)->Range(4096, 16 * 1024 * 1024)->ArgName("size")->UseRealTime() +#define BENCHMARK_ALLOCATE_ARGS \ + ->RangeMultiplier(16) \ + ->Range(4096, 16 * 1024 * 1024) \ + ->ArgName("size") \ + ->UseRealTime() \ + ->ThreadRange(1, 32) #define BENCHMARK_ALLOCATE(benchmark_func, template_param) \ BENCHMARK_TEMPLATE(benchmark_func, template_param) BENCHMARK_ALLOCATE_ARGS diff --git a/cpp/src/arrow/memory_pool_test.cc b/cpp/src/arrow/memory_pool_test.cc index 81d9d69ba346d..3f0a852876718 100644 --- a/cpp/src/arrow/memory_pool_test.cc +++ b/cpp/src/arrow/memory_pool_test.cc @@ -106,11 +106,6 @@ TEST(DefaultMemoryPool, Identity) { specific_pools.end()); } -// Death tests and valgrind are known to not play well 100% of the time. See -// googletest documentation -#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER)) - -// TODO: is this still a death test? TEST(DefaultMemoryPoolDeathTest, Statistics) { MemoryPool* pool = default_memory_pool(); uint8_t* data1; @@ -137,18 +132,16 @@ TEST(DefaultMemoryPoolDeathTest, Statistics) { ASSERT_EQ(150, pool->max_memory()); ASSERT_EQ(200, pool->total_bytes_allocated()); ASSERT_EQ(50, pool->bytes_allocated()); - ASSERT_EQ(4, pool->num_allocations()); + ASSERT_EQ(3, pool->num_allocations()); pool->Free(data1, 50); ASSERT_EQ(150, pool->max_memory()); ASSERT_EQ(200, pool->total_bytes_allocated()); ASSERT_EQ(0, pool->bytes_allocated()); - ASSERT_EQ(4, pool->num_allocations()); + ASSERT_EQ(3, pool->num_allocations()); } -#endif // ARROW_VALGRIND - TEST(LoggingMemoryPool, Logging) { auto pool = MemoryPool::CreateDefault(); diff --git a/cpp/src/arrow/memory_pool_test.h b/cpp/src/arrow/memory_pool_test.h index e4a07099f830f..32f1cc5d1d310 100644 --- a/cpp/src/arrow/memory_pool_test.h +++ b/cpp/src/arrow/memory_pool_test.h @@ -38,19 +38,20 @@ class TestMemoryPoolBase : public ::testing::Test { auto pool = memory_pool(); uint8_t* data; + const auto old_bytes_allocated = pool->bytes_allocated(); ASSERT_OK(pool->Allocate(100, &data)); EXPECT_EQ(static_cast(0), reinterpret_cast(data) % 64); - ASSERT_EQ(100, pool->bytes_allocated()); + ASSERT_EQ(old_bytes_allocated + 100, pool->bytes_allocated()); uint8_t* data2; ASSERT_OK(pool->Allocate(27, &data2)); EXPECT_EQ(static_cast(0), reinterpret_cast(data2) % 64); - ASSERT_EQ(127, pool->bytes_allocated()); + ASSERT_EQ(old_bytes_allocated + 127, pool->bytes_allocated()); pool->Free(data, 100); - ASSERT_EQ(27, pool->bytes_allocated()); + ASSERT_EQ(old_bytes_allocated + 27, pool->bytes_allocated()); pool->Free(data2, 27); - ASSERT_EQ(0, pool->bytes_allocated()); + ASSERT_EQ(old_bytes_allocated, pool->bytes_allocated()); } void TestOOM() 
{ diff --git a/cpp/src/arrow/stl_allocator.h b/cpp/src/arrow/stl_allocator.h index a1f4ae9feb82b..82e6aaa8772b9 100644 --- a/cpp/src/arrow/stl_allocator.h +++ b/cpp/src/arrow/stl_allocator.h @@ -110,7 +110,7 @@ class STLMemoryPool : public MemoryPool { } catch (std::bad_alloc& e) { return Status::OutOfMemory(e.what()); } - stats_.UpdateAllocatedBytes(size); + stats_.DidAllocateBytes(size); return Status::OK(); } @@ -124,13 +124,13 @@ class STLMemoryPool : public MemoryPool { } memcpy(*ptr, old_ptr, std::min(old_size, new_size)); alloc_.deallocate(old_ptr, old_size); - stats_.UpdateAllocatedBytes(new_size - old_size); + stats_.DidReallocateBytes(old_size, new_size); return Status::OK(); } void Free(uint8_t* buffer, int64_t size, int64_t /*alignment*/) override { alloc_.deallocate(buffer, size); - stats_.UpdateAllocatedBytes(-size, /*is_free=*/true); + stats_.DidFreeBytes(size); } int64_t bytes_allocated() const override { return stats_.bytes_allocated(); } diff --git a/cpp/src/arrow/util/bit_util_benchmark.cc b/cpp/src/arrow/util/bit_util_benchmark.cc index 3bcb4ceea6303..0bf2c26f12486 100644 --- a/cpp/src/arrow/util/bit_util_benchmark.cc +++ b/cpp/src/arrow/util/bit_util_benchmark.cc @@ -449,7 +449,7 @@ static void CopyBitmap(benchmark::State& state) { // NOLINT non-const reference const uint8_t* src = buffer->data(); const int64_t length = bits_size - OffsetSrc; - auto copy = *AllocateEmptyBitmap(length); + auto copy = *AllocateEmptyBitmap(length + OffsetDest); for (auto _ : state) { internal::CopyBitmap(src, OffsetSrc, length, copy->mutable_data(), OffsetDest); diff --git a/cpp/src/gandiva/decimal_type_util.cc b/cpp/src/gandiva/decimal_type_util.cc index 2abc5a21eaa88..cce4292f3bf15 100644 --- a/cpp/src/gandiva/decimal_type_util.cc +++ b/cpp/src/gandiva/decimal_type_util.cc @@ -30,7 +30,8 @@ constexpr int32_t DecimalTypeUtil::kMinAdjustedScale; // Implementation of decimal rules. Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& in_types, - Decimal128TypePtr* out_type) { + Decimal128TypePtr* out_type, + bool use_compute_rules) { DCHECK_EQ(in_types.size(), 2); *out_type = nullptr; @@ -59,7 +60,9 @@ Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& in_type break; case kOpDivide: - result_scale = std::max(kMinAdjustedScale, s1 + p2 + 1); + result_scale = use_compute_rules + ? 
std::max(kMinComputeAdjustedScale, s1 + p2 - s2 + 1) + : std::max(kMinAdjustedScale, s1 + p2 + 1); result_precision = p1 - s1 + s2 + result_scale; break; @@ -68,7 +71,17 @@ Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& in_type result_precision = std::min(p1 - s1, p2 - s2) + result_scale; break; } - *out_type = MakeAdjustedType(result_precision, result_scale); + + if (use_compute_rules) { + if (result_precision < kMinPrecision || result_precision > kMaxPrecision) { + return Status::Invalid("Decimal precision out of range [", int32_t(kMinPrecision), + ", ", int32_t(kMaxPrecision), "]: ", result_precision); + } + *out_type = MakeType(result_precision, result_scale); + } else { + *out_type = MakeAdjustedType(result_precision, result_scale); + } + return Status::OK(); } diff --git a/cpp/src/gandiva/decimal_type_util.h b/cpp/src/gandiva/decimal_type_util.h index 2b496f6cbf5bd..16ce544717e46 100644 --- a/cpp/src/gandiva/decimal_type_util.h +++ b/cpp/src/gandiva/decimal_type_util.h @@ -45,6 +45,9 @@ class GANDIVA_EXPORT DecimalTypeUtil { /// The maximum precision representable by a 8-byte decimal static constexpr int32_t kMaxDecimal64Precision = 18; + /// The minimum precision representable by a 16-byte decimal + static constexpr int32_t kMinPrecision = 1; + /// The maximum precision representable by a 16-byte decimal static constexpr int32_t kMaxPrecision = 38; @@ -57,10 +60,19 @@ class GANDIVA_EXPORT DecimalTypeUtil { // * There is no strong reason for 6, but both SQLServer and Impala use 6 too. static constexpr int32_t kMinAdjustedScale = 6; + // Serves the same purpose as kMinAdjustedScale, but matches the compute + // module's decimal promotion rules. + static constexpr int32_t kMinComputeAdjustedScale = 4; + // For specified operation and input scale/precision, determine the output // scale/precision.
+ // + // The 'use_compute_rules' flag is for compatibility with the compute module's + // decimal promotion rules: + // https://arrow.apache.org/docs/cpp/compute.html#arithmetic-functions static Status GetResultType(Op op, const Decimal128TypeVector& in_types, - Decimal128TypePtr* out_type); + Decimal128TypePtr* out_type, + bool use_compute_rules = false); static Decimal128TypePtr MakeType(int32_t precision, int32_t scale) { return std::dynamic_pointer_cast<arrow::Decimal128Type>( diff --git a/cpp/src/gandiva/tests/decimal_single_test.cc b/cpp/src/gandiva/tests/decimal_single_test.cc index 666ee4a68d5de..57c281a4551ef 100644 --- a/cpp/src/gandiva/tests/decimal_single_test.cc +++ b/cpp/src/gandiva/tests/decimal_single_test.cc @@ -49,7 +49,8 @@ class TestDecimalOps : public ::testing::Test { ArrayPtr MakeDecimalVector(const DecimalScalar128& in); void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x, - const DecimalScalar128& y, const DecimalScalar128& expected); + const DecimalScalar128& y, const DecimalScalar128& expected, + bool use_compute_rules = false, bool verify_failed = false); void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, const DecimalScalar128& expected) { @@ -67,8 +68,10 @@ class TestDecimalOps : public ::testing::Test { } void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, - const DecimalScalar128& expected) { - Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected); + const DecimalScalar128& expected, bool use_compute_rules = false, + bool verify_failed = false) { + Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected, use_compute_rules, + verify_failed); } void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, @@ -91,7 +94,8 @@ ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) { void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function, const DecimalScalar128& x, const DecimalScalar128& y, - const DecimalScalar128& expected) { + const DecimalScalar128& expected, bool use_compute_rules, + bool verify_failed) { auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale()); auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale()); auto field_x = field("x", x_type); auto field_y = field("y", y_type); auto schema = arrow::schema({field_x, field_y}); Decimal128TypePtr output_type; - auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type); - ARROW_EXPECT_OK(status); + auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type, + use_compute_rules); + if (verify_failed) { + ASSERT_NOT_OK(status); + return; + } else { + ARROW_EXPECT_OK(status); + } // output fields auto res = field("res", output_type); @@ -283,13 +293,31 @@ TEST_F(TestDecimalOps, TestMultiply) { } TEST_F(TestDecimalOps, TestDivide) { + // fast-path + // + // original Gandiva rules DivideAndVerify(decimal_literal("201", 10, 3), // x decimal_literal("301", 10, 2), // y decimal_literal("6677740863787", 23, 14)); // expected + // compute module's rules + DivideAndVerify(decimal_literal("201", 10, 3), // x + decimal_literal("301", 10, 2), // y + decimal_literal("66777408638", 21, 12), // expected + /*use_compute_rules=*/true); + + // max precision beyond 38 + // + // valid under the original Gandiva rules DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x DecimalScalar128(std::string(35, '9'), 38, 20), // x DecimalScalar128("1000000000", 38, 6)); + + // invalid under compute module's rules
+ DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x + DecimalScalar128(std::string(35, '9'), 38, 20), // y + DecimalScalar128(std::string(35, '9'), 0, 0), // expected (unused since the call fails) + /*use_compute_rules=*/true, /*verify_failed=*/true); } TEST_F(TestDecimalOps, TestMod) { diff --git a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs index d21fb25f5c946..7400ec15e54d6 100644 --- a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs +++ b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs @@ -45,12 +45,12 @@ private protected FlightRecordBatchStreamReader(IAsyncStreamReader<FlightData> - public ValueTask<Schema> Schema => _arrowReaderImplementation.ReadSchema(); + public ValueTask<Schema> Schema => _arrowReaderImplementation.GetSchemaAsync(); internal ValueTask<FlightDescriptor> GetFlightDescriptor() { return _arrowReaderImplementation.ReadFlightDescriptor(); - } + } /// <summary> /// Get the application metadata from the latest received record batch diff --git a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs index be844ea58e404..99876bf769dc7 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs @@ -48,19 +48,33 @@ public async ValueTask<FlightDescriptor> ReadFlightDescriptor() { if (!HasReadSchema) { - await ReadSchema().ConfigureAwait(false); + await ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false); } return _flightDescriptor; } - public async ValueTask<Schema> ReadSchema() + public async ValueTask<Schema> GetSchemaAsync() + { + if (!HasReadSchema) + { + await ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false); + } + return _schema; + } + + public override void ReadSchema() + { + ReadSchemaAsync(CancellationToken.None).AsTask().Wait(); + } + + public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken) { if (HasReadSchema) { - return Schema; + return; } - var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false); + var moveNextResult = await _flightDataStream.MoveNext(cancellationToken).ConfigureAwait(false); if (!moveNextResult) { @@ -87,12 +101,11 @@ public async ValueTask ReadSchema() switch (message.HeaderType) { case MessageHeader.Schema: - Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer); + _schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer); break; default: throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}"); } - return Schema; } public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken) @@ -101,7 +114,7 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati if (!HasReadSchema) { - await ReadSchema().ConfigureAwait(false); + await ReadSchemaAsync(cancellationToken).ConfigureAwait(false); } var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false); if (moveNextResult) diff --git a/csharp/src/Apache.Arrow/Arrays/MapArray.cs b/csharp/src/Apache.Arrow/Arrays/MapArray.cs index a6676b134e34a..dad50981ea54d 100644 --- a/csharp/src/Apache.Arrow/Arrays/MapArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/MapArray.cs @@ -135,6 +135,19 @@ private MapArray(ArrayData data, IArrowArray structs) : base(data, structs, Arro { } + public override void Accept(IArrowArrayVisitor visitor) + { + switch (visitor) + { + case IArrowArrayVisitor<MapArray>
typedVisitor: + typedVisitor.Visit(this); + break; + default: + base.Accept(visitor); + break; + } + } + public IEnumerable> GetTuples(int index, Func getKey, Func getValue) where TKeyArray : Array where TValueArray : Array { diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index 02f36b079349b..4b7c5f914c402 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -52,7 +52,7 @@ public async ValueTask RecordBatchCountAsync(CancellationToken cancellation return _footer.RecordBatchCount; } - protected override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) + public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) { if (HasReadSchema) { @@ -85,7 +85,7 @@ protected override async ValueTask ReadSchemaAsync(CancellationToken cancellatio } } - protected override void ReadSchema() + public override void ReadSchema() { if (HasReadSchema) { @@ -139,7 +139,7 @@ private void ReadSchema(Memory buffer) // Deserialize the footer from the footer flatbuffer _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), ref _dictionaryMemo); - Schema = _footer.Schema; + _schema = _footer.Schema; } public async ValueTask ReadRecordBatchAsync(int index, CancellationToken cancellationToken) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index 6e2336a591bf1..842c56823d07f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -33,6 +33,13 @@ public ArrowMemoryReaderImplementation(ReadOnlyMemory buffer, ICompression _buffer = buffer; } + public override ValueTask ReadSchemaAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + ReadSchema(); + return default; + } + public override ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -93,7 +100,7 @@ public override RecordBatch ReadNextRecordBatch() return batch; } - private void ReadSchema() + public override void ReadSchema() { if (HasReadSchema) { @@ -117,7 +124,7 @@ private void ReadSchema() } ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition)); - Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), ref _dictionaryMemo); + _schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), ref _dictionaryMemo); _bufferPosition += schemaMessageLength; } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index eb7349a570786..4e273dbde5690 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -30,13 +30,25 @@ namespace Apache.Arrow.Ipc { internal abstract class ArrowReaderImplementation : IDisposable { - public Schema Schema { get; protected set; } - protected bool HasReadSchema => Schema != null; + public Schema Schema + { + get + { + if (!HasReadSchema) + { + ReadSchema(); + } + return _schema; + } + } + + protected internal bool HasReadSchema => _schema != null; private protected DictionaryMemo _dictionaryMemo; private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); 
private protected readonly MemoryAllocator _allocator; private readonly ICompressionCodecFactory _compressionCodecFactory; + private protected Schema _schema; private protected ArrowReaderImplementation() : this(null, null) { } @@ -57,6 +69,9 @@ protected virtual void Dispose(bool disposing) { } + public abstract ValueTask ReadSchemaAsync(CancellationToken cancellationToken); + public abstract void ReadSchema(); + public abstract ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken); public abstract RecordBatch ReadNextRecordBatch(); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs index cdcfe7875da22..e129da399d59a 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs @@ -28,6 +28,9 @@ public class ArrowStreamReader : IArrowReader, IArrowArrayStream, IDisposable { private protected readonly ArrowReaderImplementation _implementation; + /// <summary> + /// May block if the schema hasn't yet been read. To avoid blocking, use GetSchema. + /// </summary> public Schema Schema => _implementation.Schema; public ArrowStreamReader(Stream stream) @@ -97,6 +100,15 @@ protected virtual void Dispose(bool disposing) } } + public async ValueTask<Schema> GetSchema(CancellationToken cancellationToken = default) + { + if (!_implementation.HasReadSchema) + { + await _implementation.ReadSchemaAsync(cancellationToken); + } + return _implementation.Schema; + } + public ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) { return _implementation.ReadNextRecordBatchAsync(cancellationToken); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 5428c88c27bbc..5583a58487bf5 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -146,7 +146,7 @@ protected ReadResult ReadMessage() return new ReadResult(messageLength, result); } - protected virtual async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) + public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) { if (HasReadSchema) { @@ -164,11 +164,11 @@ protected virtual async ValueTask ReadSchemaAsync(CancellationToken cancellation EnsureFullRead(buff, bytesRead); Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); + _schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); } } - protected virtual void ReadSchema() + public override void ReadSchema() { if (HasReadSchema) { @@ -184,7 +184,7 @@ protected virtual void ReadSchema() EnsureFullRead(buff, bytesRead); Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); + _schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); } } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index 10315ff287c0b..2e7488092c2cf 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -38,6 +38,9 @@ public static void VerifyReader(ArrowStreamReader reader, RecordBatch originalBa public static async Task VerifyReaderAsync(ArrowStreamReader reader, RecordBatch
originalBatch) { + Schema schema = await reader.GetSchema(); + Assert.NotNull(schema); + RecordBatch readBatch = await reader.ReadNextRecordBatchAsync(); CompareBatches(originalBatch, readBatch); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs index ed030cc6ace11..b9e4664fdcd45 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs @@ -94,6 +94,8 @@ public async Task ReadRecordBatch_Memory(bool writeEnd) { await TestReaderFromMemory((reader, originalBatch) => { + Assert.NotNull(reader.Schema); + ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; }, writeEnd); diff --git a/csharp/test/Apache.Arrow.Tests/MapArrayTests.cs b/csharp/test/Apache.Arrow.Tests/MapArrayTests.cs index 7f35f104267dc..21decdacc0588 100644 --- a/csharp/test/Apache.Arrow.Tests/MapArrayTests.cs +++ b/csharp/test/Apache.Arrow.Tests/MapArrayTests.cs @@ -85,8 +85,118 @@ public void MapArray_Should_GetKeyValuePairs() Assert.Equal(new KeyValuePair[] { kv1, kv2 }, array.GetKeyValuePairs(2, GetKey, GetValue).ToArray()); } + [Fact] + public void MapArray_Should_AcceptMapVisitor() + { + var mapArray = BuildMapArray(); + var visitor = new MapOnlyVisitor(); + mapArray.Accept(visitor); + + Assert.True(visitor.MapVisited); + Assert.False(visitor.BaseVisited); + } + + [Fact] + public void MapArray_Should_AcceptListVisitor() + { + var mapArray = BuildMapArray(); + var visitor = new ListOnlyVisitor(); + mapArray.Accept(visitor); + + Assert.True(visitor.ListVisited); + Assert.False(visitor.BaseVisited); + } + + [Fact] + public void MapArray_Should_AcceptListAndMapVisitor() + { + var mapArray = BuildMapArray(); + var visitor = new MapAndListVisitor(); + mapArray.Accept(visitor); + + Assert.True(visitor.MapVisited); + Assert.False(visitor.ListVisited); + Assert.False(visitor.BaseVisited); + } + + private static MapArray BuildMapArray() + { + MapType type = new MapType(StringType.Default, Int64Type.Default); + MapArray.Builder builder = new MapArray.Builder(type); + var keyBuilder = builder.KeyBuilder as StringArray.Builder; + var valueBuilder = builder.ValueBuilder as Int64Array.Builder; + + builder.Append(); + keyBuilder.Append("test"); + valueBuilder.Append(1); + + builder.AppendNull(); + + builder.Append(); + keyBuilder.Append("other"); + valueBuilder.Append(123); + keyBuilder.Append("kv"); + valueBuilder.AppendNull(); + + return builder.Build(); + } + private static string GetKey(StringArray array, int index) => array.GetString(index); private static int? GetValue(Int32Array array, int index) => array.GetValue(index); private static long? 
GetValue(Int64Array array, int index) => array.GetValue(index); + + private sealed class MapOnlyVisitor : IArrowArrayVisitor + { + public bool MapVisited = false; + public bool BaseVisited = false; + + public void Visit(MapArray array) + { + MapVisited = true; + } + + public void Visit(IArrowArray array) + { + BaseVisited = true; + } + } + + private sealed class ListOnlyVisitor : IArrowArrayVisitor + { + public bool ListVisited = false; + public bool BaseVisited = false; + + public void Visit(ListArray array) + { + ListVisited = true; + } + + public void Visit(IArrowArray array) + { + BaseVisited = true; + } + } + + private sealed class MapAndListVisitor : IArrowArrayVisitor, IArrowArrayVisitor + { + public bool MapVisited = false; + public bool ListVisited = false; + public bool BaseVisited = false; + + public void Visit(MapArray array) + { + MapVisited = true; + } + + public void Visit(ListArray array) + { + ListVisited = true; + } + + public void Visit(IArrowArray array) + { + BaseVisited = true; + } + } } } diff --git a/docs/source/conf.py b/docs/source/conf.py index 7915e2c2c485a..ad8fa798d6aac 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -208,8 +208,14 @@ # source_suffix = { - '.md': 'markdown', + # We need to keep "'.rst': 'restructuredtext'" as the first item. + # This is a workaround of + # https://github.com/sphinx-doc/sphinx/issues/12147 . + # + # We can sort these items in alphabetical order with Sphinx 7.3.0 + # or later that will include the fix of this problem. '.rst': 'restructuredtext', + '.md': 'markdown', } autosummary_generate = True diff --git a/docs/source/format/FlightSql.rst b/docs/source/format/FlightSql.rst index 6bb917271366c..5573c0040761f 100644 --- a/docs/source/format/FlightSql.rst +++ b/docs/source/format/FlightSql.rst @@ -141,6 +141,21 @@ the ``type`` should be ``ClosePreparedStatement``). Execute a previously created prepared statement and get the results. When used with DoPut: binds parameter values to the prepared statement. + The server may optionally provide an updated handle in the response. + Updating the handle allows the client to supply all state required to + execute the query in an ActionPreparedStatementExecute message. + For example, stateless servers can encode the bound parameter values into + the new handle, and the client will send that new handle with parameters + back to the server. + + Note that a handle returned from a DoPut call with + CommandPreparedStatementQuery can itself be passed to a subsequent DoPut + call with CommandPreparedStatementQuery to bind a new set of parameters. + The subsequent call itself may return an updated handle which again should + be used for subsequent requests. + + The server is responsible for detecting the case where the client does not + use the updated handle and should return an error. When used with GetFlightInfo: execute the prepared statement. The prepared statement can be reused after fetching results. 
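The handle-rotation contract described in the FlightSql.rst text above can be made concrete with a small client-side sketch. This is a hypothetical illustration, not the Flight SQL client API touched by this diff: `PreparedStatement`, `Bind`, and the `doPut` callback are stand-in names for the DoPut(CommandPreparedStatementQuery) round trip.

```go
package example

// PreparedStatement tracks the server-issued opaque prepared-statement handle.
type PreparedStatement struct {
	handle []byte // latest handle; any previously returned handle is invalid
}

// Bind sends bound parameter values via DoPut and adopts the updated handle
// when the server returns one. Stateless servers can encode the bound values
// into the new handle; legacy servers return nil, so the old handle stays in
// effect. doPut is a stand-in for the real RPC round trip.
func (ps *PreparedStatement) Bind(doPut func(handle []byte) (updated []byte, err error)) error {
	updated, err := doPut(ps.handle)
	if err != nil {
		return err
	}
	if updated != nil {
		ps.handle = updated // every later request must use this handle
	}
	return nil
}
```

Note that the update must be applied after every DoPut call, since each call may rotate the handle again.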
diff --git a/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd b/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd index cb50522eb5a32..cbd1eb6014bca 100644 --- a/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd +++ b/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd @@ -28,6 +28,8 @@ Server->>Client: ActionCreatePreparedStatementResult{handle} loop for each invocation of the prepared statement Client->>Server: DoPut(CommandPreparedStatementQuery) Client->>Server: stream of FlightData +Server-->>Client: DoPutPreparedStatementResult{handle} +Note over Client,Server: optional response with updated handle Client->>Server: GetFlightInfo(CommandPreparedStatementQuery) Server->>Client: FlightInfo{endpoints: [FlightEndpoint{…}, …]} loop for each endpoint in FlightInfo.endpoints diff --git a/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd.svg b/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd.svg index 96a5bc3688297..cbf6a78e9a5ce 100644 --- a/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd.svg +++ b/docs/source/format/FlightSql/CommandPreparedStatementQuery.mmd.svg @@ -1 +1 @@ -ClientServerDoAction(ActionCreatePreparedStatementRequest)1ActionCreatePreparedStatementResult{handle}2DoPut(CommandPreparedStatementQuery)3stream of FlightData4GetFlightInfo(CommandPreparedStatementQuery)5FlightInfo{endpoints: [FlightEndpoint{…}, …]}6DoGet(endpoint.ticket)7stream of FlightData8loop[for each endpoint in FlightInfo.endpoints]loop[for each invocation of the prepared statement]DoAction(ActionClosePreparedStatementRequest)9ActionClosePreparedStatementRequest{}10ClientServer \ No newline at end of file +ServerClientServerClientoptional response with updated handleloop[for each endpoint in FlightInfo.endpoints]loop[for each invocation of the prepared statement]DoAction(ActionCreatePreparedStatementRequest)1ActionCreatePreparedStatementResult{handle}2DoPut(CommandPreparedStatementQuery)3stream of FlightData4DoPutPreparedStatementResult{handle}5GetFlightInfo(CommandPreparedStatementQuery)6FlightInfo{endpoints: [FlightEndpoint{…}, …]}7DoGet(endpoint.ticket)8stream of FlightData9DoAction(ActionClosePreparedStatementRequest)10ActionClosePreparedStatementRequest{}11 \ No newline at end of file diff --git a/format/FlightSql.proto b/format/FlightSql.proto index 581cf1f76d57c..3282ee4f47304 100644 --- a/format/FlightSql.proto +++ b/format/FlightSql.proto @@ -1797,6 +1797,26 @@ message DoPutUpdateResult { int64 record_count = 1; } +/* An *optional* response returned when `DoPut` is called with `CommandPreparedStatementQuery`. + * + * *Note on legacy behavior*: previous versions of the protocol did not return any result for + * this command, and that behavior should still be supported by clients. In that case, the client + * can continue as though the fields in this message were not provided or set to sensible default values. + */ +message DoPutPreparedStatementResult { + option (experimental) = true; + + // Represents a (potentially updated) opaque handle for the prepared statement on the server. + // Because the handle could potentially be updated, any previous handles for this prepared + // statement should be considered invalid, and all subsequent requests for this prepared + // statement must use this new handle. + // The updated handle allows implementing query parameters with stateless services. 
+ // + // When an updated handle is not provided by the server, clients should continue + // using the previous handle provided by `ActionCreatePreparedStatementResponse`. + optional bytes prepared_statement_handle = 1; +} + /* * Request message for the "CancelQuery" action. * diff --git a/go/arrow/array/decimal128.go b/go/arrow/array/decimal128.go index 0dca320cda959..dc5f5d761618e 100644 --- a/go/arrow/array/decimal128.go +++ b/go/arrow/array/decimal128.go @@ -19,7 +19,6 @@ package array import ( "bytes" "fmt" - "math" "math/big" "reflect" "strings" @@ -86,15 +85,19 @@ func (a *Decimal128) setData(data *Data) { a.values = a.values[beg:end] } } - func (a *Decimal128) GetOneForMarshal(i int) interface{} { if a.IsNull(i) { return nil } - typ := a.DataType().(*arrow.Decimal128Type) - f := (&big.Float{}).SetInt(a.Value(i).BigInt()) - f.Quo(f, big.NewFloat(math.Pow10(int(typ.Scale)))) + n := a.Value(i) + scale := typ.Scale + f := (&big.Float{}).SetInt(n.BigInt()) + if scale < 0 { + f.SetPrec(128).Mul(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(-scale)).BigInt())) + } else { + f.SetPrec(128).Quo(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(scale)).BigInt())) + } return f.Text('g', int(typ.Precision)) } diff --git a/go/arrow/array/decimal128_test.go b/go/arrow/array/decimal128_test.go index 836a6987df69f..31c6a6f8cadd6 100644 --- a/go/arrow/array/decimal128_test.go +++ b/go/arrow/array/decimal128_test.go @@ -204,7 +204,17 @@ func TestDecimal128StringRoundTrip(t *testing.T) { decimal128.FromI64(9), decimal128.FromI64(10), } - valid := []bool{true, true, true, false, true, true, false, true, true, true} + val1, err := decimal128.FromString("0.99", dt.Precision, dt.Scale) + if err != nil { + t.Fatal(err) + } + val2, err := decimal128.FromString("1234567890.12345", dt.Precision, dt.Scale) + if err != nil { + t.Fatal(err) + } + values = append(values, val1, val2) + + valid := []bool{true, true, true, false, true, true, false, true, true, true, true, true} b.AppendValues(values, valid) @@ -224,3 +234,50 @@ func TestDecimal128StringRoundTrip(t *testing.T) { assert.True(t, array.Equal(arr, arr1)) } + +func TestDecimal128GetOneForMarshal(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + dtype := &arrow.Decimal128Type{Precision: 38, Scale: 20} + + b := array.NewDecimal128Builder(mem, dtype) + defer b.Release() + + cases := []struct { + give any + want any + }{ + {"1", "1"}, + {"1.25", "1.25"}, + {"0.99", "0.99"}, + {"1234567890.123456789", "1234567890.123456789"}, + {nil, nil}, + {"-0.99", "-0.99"}, + {"-1234567890.123456789", "-1234567890.123456789"}, + {"0.0000000000000000001", "1e-19"}, + } + for _, v := range cases { + if v.give == nil { + b.AppendNull() + continue + } + + dt, err := decimal128.FromString(v.give.(string), dtype.Precision, dtype.Scale) + if err != nil { + t.Fatal(err) + } + b.Append(dt) + } + + arr := b.NewDecimal128Array() + defer arr.Release() + + if got, want := arr.Len(), len(cases); got != want { + t.Fatalf("invalid array length: got=%d, want=%d", got, want) + } + + for i := range cases { + assert.Equalf(t, cases[i].want, arr.GetOneForMarshal(i), "unexpected value at index %d", i) + } +} diff --git a/go/arrow/array/decimal256.go b/go/arrow/array/decimal256.go index 452ac96625bc8..f9c666300fa61 100644 --- a/go/arrow/array/decimal256.go +++ b/go/arrow/array/decimal256.go @@ -19,7 +19,6 @@ package array import ( "bytes" "fmt" - "math" "math/big" "reflect" "strings" @@ -91,10 +90,15 @@ func (a
*Decimal256) GetOneForMarshal(i int) interface{} { if a.IsNull(i) { return nil } - typ := a.DataType().(*arrow.Decimal256Type) - f := (&big.Float{}).SetInt(a.Value(i).BigInt()) - f.Quo(f, big.NewFloat(math.Pow10(int(typ.Scale)))) + n := a.Value(i) + scale := typ.Scale + f := (&big.Float{}).SetInt(n.BigInt()) + if scale < 0 { + f.SetPrec(256).Mul(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(-scale)).BigInt())) + } else { + f.SetPrec(256).Quo(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(scale)).BigInt())) + } return f.Text('g', int(typ.Precision)) } diff --git a/go/arrow/array/decimal256_test.go b/go/arrow/array/decimal256_test.go index 4f0c441210643..c78bd5243a66a 100644 --- a/go/arrow/array/decimal256_test.go +++ b/go/arrow/array/decimal256_test.go @@ -205,7 +205,17 @@ func TestDecimal256StringRoundTrip(t *testing.T) { decimal256.FromI64(9), decimal256.FromI64(10), } - valid := []bool{true, true, true, false, true, true, false, true, true, true} + val1, err := decimal256.FromString("0.99", dt.Precision, dt.Scale) + if err != nil { + t.Fatal(err) + } + val2, err := decimal256.FromString("1234567890.123456789", dt.Precision, dt.Scale) + if err != nil { + t.Fatal(err) + } + values = append(values, val1, val2) + + valid := []bool{true, true, true, false, true, true, false, true, true, true, true, true} b.AppendValues(values, valid) @@ -217,11 +227,67 @@ func TestDecimal256StringRoundTrip(t *testing.T) { defer b1.Release() for i := 0; i < arr.Len(); i++ { - assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) + v := arr.ValueStr(i) + assert.NoError(t, b1.AppendValueFromString(v)) } arr1 := b1.NewArray().(*array.Decimal256) defer arr1.Release() + for i := 0; i < arr.Len(); i++ { + if arr.IsNull(i) && arr1.IsNull(i) { + continue + } + if arr.Value(i) != arr1.Value(i) { + t.Fatalf("unexpected value at index %d: got=%v, want=%v", i, arr1.Value(i), arr.Value(i)) + } + } assert.True(t, array.Equal(arr, arr1)) } + +func TestDecimal256GetOneForMarshal(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + dtype := &arrow.Decimal256Type{Precision: 38, Scale: 20} + + b := array.NewDecimal256Builder(mem, dtype) + defer b.Release() + + cases := []struct { + give any + want any + }{ + {"1", "1"}, + {"1.25", "1.25"}, + {"0.99", "0.99"}, + {"1234567890.123456789", "1234567890.123456789"}, + {nil, nil}, + {"-0.99", "-0.99"}, + {"-1234567890.123456789", "-1234567890.123456789"}, + {"0.0000000000000000001", "1e-19"}, + } + for _, v := range cases { + if v.give == nil { + b.AppendNull() + continue + } + + dt, err := decimal256.FromString(v.give.(string), dtype.Precision, dtype.Scale) + if err != nil { + t.Fatal(err) + } + b.Append(dt) + } + + arr := b.NewDecimal256Array() + defer arr.Release() + + if got, want := arr.Len(), len(cases); got != want { + t.Fatalf("invalid array length: got=%d, want=%d", got, want) + } + + for i := range cases { + assert.Equalf(t, cases[i].want, arr.GetOneForMarshal(i), "unexpected value at index %d", i) + } +} diff --git a/go/parquet/file/file_writer.go b/go/parquet/file/file_writer.go index a2cf397cbc80b..57344b25cf05c 100644 --- a/go/parquet/file/file_writer.go +++ b/go/parquet/file/file_writer.go @@ -32,6 +32,7 @@ import ( type Writer struct { sink utils.WriteCloserTell open bool + footerFlushed bool props *parquet.WriterProperties rowGroups int nrows int @@ -125,6 +126,7 @@ func (fw *Writer) appendRowGroup(buffered bool) *rowGroupWriter { fw.rowGroupWriter.Close() } fw.rowGroups++ + 
fw.footerFlushed = false rgMeta := fw.metadata.AppendRowGroup() fw.rowGroupWriter = newRowGroupWriter(fw.sink, rgMeta, int16(fw.rowGroups)-1, fw.props, buffered, fw.fileEncryptor) return fw.rowGroupWriter @@ -172,12 +174,9 @@ func (fw *Writer) Close() (err error) { // if any functions here panic, we set open to be false so // that this doesn't get called again fw.open = false - if fw.rowGroupWriter != nil { - fw.nrows += fw.rowGroupWriter.nrows - fw.rowGroupWriter.Close() - } - fw.rowGroupWriter = nil + defer func() { + fw.closeEncryptor() ierr := fw.sink.Close() if err != nil { if ierr != nil { @@ -189,30 +188,48 @@ func (fw *Writer) Close() (err error) { err = ierr }() + err = fw.FlushWithFooter() + fw.metadata.Clear() + } + return nil +} + +// FlushWithFooter closes any open row group writer and writes the file footer, leaving +// the writer open for additional row groups. Additional footers written by later +// calls to FlushWithFooter or Close will be cumulative, so that only the last footer +// written need ever be read by a reader. +func (fw *Writer) FlushWithFooter() error { + if !fw.footerFlushed { + if fw.rowGroupWriter != nil { + fw.nrows += fw.rowGroupWriter.nrows + fw.rowGroupWriter.Close() + } + fw.rowGroupWriter = nil + + fileMetadata, err := fw.metadata.Snapshot() + if err != nil { + return err + } + fileEncryptProps := fw.props.FileEncryptionProperties() if fileEncryptProps == nil { // non encrypted file - fileMetadata, err := fw.metadata.Finish() - if err != nil { + if _, err = writeFileMetadata(fileMetadata, fw.sink); err != nil { + return err + } + } else { + if err := fw.flushEncryptedFile(fileMetadata, fileEncryptProps); err != nil { return err } - - _, err = writeFileMetadata(fileMetadata, fw.sink) - return err } - return fw.closeEncryptedFile(fileEncryptProps) + fw.footerFlushed = true } return nil } -func (fw *Writer) closeEncryptedFile(props *parquet.FileEncryptionProperties) error { +func (fw *Writer) flushEncryptedFile(fileMetadata *metadata.FileMetaData, props *parquet.FileEncryptionProperties) error { // encrypted file with encrypted footer if props.EncryptedFooter() { - fileMetadata, err := fw.metadata.Finish() - if err != nil { - return err - } - footerLen := int64(0) cryptoMetadata := fw.metadata.GetFileCryptoMetaData() @@ -236,19 +253,18 @@ func (fw *Writer) closeEncryptedFile(props *parquet.FileEncryptionProperties) er return err } } else { - fileMetadata, err := fw.metadata.Finish() - if err != nil { - return err - } footerSigningEncryptor := fw.fileEncryptor.GetFooterSigningEncryptor() - if _, err = writeEncryptedFileMetadata(fileMetadata, fw.sink, footerSigningEncryptor, false); err != nil { + if _, err := writeEncryptedFileMetadata(fileMetadata, fw.sink, footerSigningEncryptor, false); err != nil { return err } } + return nil +} + +func (fw *Writer) closeEncryptor() { if fw.fileEncryptor != nil { fw.fileEncryptor.WipeOutEncryptionKeys() } - return nil } func writeFileMetadata(fileMetadata *metadata.FileMetaData, w io.Writer) (n int64, err error) { diff --git a/go/parquet/file/file_writer_test.go b/go/parquet/file/file_writer_test.go index 434c9852c5823..3687fc8778202 100644 --- a/go/parquet/file/file_writer_test.go +++ b/go/parquet/file/file_writer_test.go @@ -64,6 +64,20 @@ func (t *SerializeTestSuite) fileSerializeTest(codec compress.Compression, expec writer := file.NewParquetWriter(sink, t.Schema.Root(), file.WithWriterProps(props)) t.GenerateData(int64(t.rowsPerRG)) + + t.serializeGeneratedData(writer) + writer.FlushWithFooter() + + 
t.validateSerializedData(writer, sink, expected) + + t.serializeGeneratedData(writer) + writer.Close() + + t.numRowGroups *= 2 + t.validateSerializedData(writer, sink, expected) +} + +func (t *SerializeTestSuite) serializeGeneratedData(writer *file.Writer) { for rg := 0; rg < t.numRowGroups/2; rg++ { rgw := writer.AppendRowGroup() for col := 0; col < t.numCols; col++ { @@ -94,8 +108,9 @@ func (t *SerializeTestSuite) fileSerializeTest(codec compress.Compression, expec } rgw.Close() } - writer.Close() +} +func (t *SerializeTestSuite) validateSerializedData(writer *file.Writer, sink *encoding.BufferWriter, expected compress.Compression) { nrows := t.numRowGroups * t.rowsPerRG t.EqualValues(nrows, writer.NumRows()) diff --git a/go/parquet/metadata/file.go b/go/parquet/metadata/file.go index f40081f172a75..fc376383165b1 100644 --- a/go/parquet/metadata/file.go +++ b/go/parquet/metadata/file.go @@ -104,6 +104,15 @@ func (f *FileMetaDataBuilder) AppendKeyValueMetadata(key string, value string) e // version etc. This will clear out this filemetadatabuilder so it can // be re-used func (f *FileMetaDataBuilder) Finish() (*FileMetaData, error) { + out, err := f.Snapshot() + f.Clear() + return out, err +} + +// Snapshot returns finalized metadata of the number of rows, row groups, version etc. +// The snapshot must be used (e.g., serialized) before any additional (meta)data is +// written, as it refers to builder data structures that will continue to mutate. +func (f *FileMetaDataBuilder) Snapshot() (*FileMetaData, error) { totalRows := int64(0) for _, rg := range f.rowGroups { totalRows += rg.NumRows @@ -161,9 +170,13 @@ func (f *FileMetaDataBuilder) Finish() (*FileMetaData, error) { } out.initColumnOrders() + return out, nil +} + +// Clear resets this FileMetaDataBuilder so it can be re-used +func (f *FileMetaDataBuilder) Clear() { f.metadata = format.NewFileMetaData() f.rowGroups = nil - return out, nil } // KeyValueMetadata is an alias for a slice of thrift keyvalue pairs.
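To see how the pieces above fit together, here is a minimal usage sketch of the new `FlushWithFooter` API, following the pattern exercised by the test suite. The import path's major version and the elided column writes are assumptions; adjust them to your checkout.

```go
package example

import "github.com/apache/arrow/go/v16/parquet/file"

// writeTwoRowGroups writes a footer after the first row group, at which point
// the sink already holds a complete, readable Parquet file; Close then writes
// a cumulative footer covering both row groups, so readers only ever need the
// last footer. Column writes are elided for brevity.
func writeTwoRowGroups(w *file.Writer) error {
	rg := w.AppendRowGroup()
	// ... write the first row group's columns ...
	rg.Close()
	if err := w.FlushWithFooter(); err != nil { // first footer
		return err
	}

	rg = w.AppendRowGroup()
	// ... write the second row group's columns ...
	rg.Close()
	return w.Close() // final footer supersedes the first
}
```

Internally, each footer write serializes a `Snapshot()` of the metadata builder; `Close` clears the builder afterwards, and `Finish()` is now simply `Snapshot()` followed by `Clear()`.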
diff --git a/java/dataset/src/main/cpp/jni_util.cc b/java/dataset/src/main/cpp/jni_util.cc index f1b5a7f7c650e..8e899527f6a99 100644 --- a/java/dataset/src/main/cpp/jni_util.cc +++ b/java/dataset/src/main/cpp/jni_util.cc @@ -97,7 +97,11 @@ class ReservationListenableMemoryPool::Impl { int64_t Reserve(int64_t diff) { std::lock_guard lock(mutex_); - stats_.UpdateAllocatedBytes(diff); + if (diff > 0) { + stats_.DidAllocateBytes(diff); + } else if (diff < 0) { + stats_.DidFreeBytes(-diff); + } int64_t new_block_count; int64_t bytes_reserved = stats_.bytes_allocated(); if (bytes_reserved == 0) { diff --git a/js/src/util/bn.ts b/js/src/util/bn.ts index af546be5436a2..b4db9cf2b4afe 100644 --- a/js/src/util/bn.ts +++ b/js/src/util/bn.ts @@ -36,7 +36,7 @@ function BigNum(this: any, x: any, ...xs: any) { BigNum.prototype[isArrowBigNumSymbol] = true; BigNum.prototype.toJSON = function >(this: T) { return `"${bigNumToString(this)}"`; }; -BigNum.prototype.valueOf = function >(this: T) { return bigNumToNumber(this); }; +BigNum.prototype.valueOf = function >(this: T, scale?: number) { return bigNumToNumber(this, scale); }; BigNum.prototype.toString = function >(this: T) { return bigNumToString(this); }; BigNum.prototype[Symbol.toPrimitive] = function >(this: T, hint: 'string' | 'number' | 'default' = 'default') { switch (hint) { @@ -68,24 +68,36 @@ Object.assign(SignedBigNum.prototype, BigNum.prototype, { 'constructor': SignedB Object.assign(UnsignedBigNum.prototype, BigNum.prototype, { 'constructor': UnsignedBigNum, 'signed': false, 'TypedArray': Uint32Array, 'BigIntArray': BigUint64Array }); Object.assign(DecimalBigNum.prototype, BigNum.prototype, { 'constructor': DecimalBigNum, 'signed': true, 'TypedArray': Uint32Array, 'BigIntArray': BigUint64Array }); +//FOR ES2020 COMPATIBILITY +const TWO_TO_THE_64 = BigInt(4294967296) * BigInt(4294967296); // 2^64 = 0x10000000000000000n +const TWO_TO_THE_64_MINUS_1 = TWO_TO_THE_64 - BigInt(1); // (2^32 * 2^32) - 1 = 0xFFFFFFFFFFFFFFFFn + /** @ignore */ -function bigNumToNumber>(bn: T) { - const { buffer, byteOffset, length, 'signed': signed } = bn; - const words = new BigUint64Array(buffer, byteOffset, length); +export function bigNumToNumber>(bn: T, scale?: number) { + const { buffer, byteOffset, byteLength, 'signed': signed } = bn; + const words = new BigUint64Array(buffer, byteOffset, byteLength / 8); const negative = signed && words.at(-1)! & (BigInt(1) << BigInt(63)); - let number = negative ? BigInt(1) : BigInt(0); - let i = BigInt(0); + let number = BigInt(0); + let i = 0; if (!negative) { for (const word of words) { - number += word * (BigInt(1) << (BigInt(32) * i++)); + number |= word * (BigInt(1) << BigInt(64 * i++)); } } else { for (const word of words) { - number += ~word * (BigInt(1) << (BigInt(32) * i++)); + number |= (word ^ TWO_TO_THE_64_MINUS_1) * (BigInt(1) << BigInt(64 * i++)); } number *= BigInt(-1); + number -= BigInt(1); + } + if (typeof scale === 'number') { + const denominator = BigInt(Math.pow(10, scale)); + const quotient = number / denominator; + const remainder = number % denominator; + const n = Number(quotient) + (Number(remainder) / Number(denominator)); + return n; } - return number; + return Number(number); } /** @ignore */ @@ -217,7 +229,7 @@ export interface BN extends TypedArrayLike { * arithmetic operators, like `+`. Easy (and unsafe) way to convert BN to * number via `+bn_inst` */ - valueOf(): number; + valueOf(scale?: number): number; /** * Return the JSON representation of the bytes. 
Must be wrapped in double-quotes, * so it's compatible with JSON.stringify(). diff --git a/js/test/unit/bn-tests.ts b/js/test/unit/bn-tests.ts index c9606baf85942..dbda02198ea2e 100644 --- a/js/test/unit/bn-tests.ts +++ b/js/test/unit/bn-tests.ts @@ -83,4 +83,19 @@ describe(`BN`, () => { const d4 = toDecimal(new Uint32Array([0x9D91E773, 0x4BB90CED, 0xAB2354CC, 0x54278E9B])); expect(d4.toString()).toBe('111860543658909349380118287427608635251'); }); + + test(`valueOf for decimal numbers`, () => { + const n1 = new BN(new Uint32Array([0x00000001, 0x00000000, 0x00000000, 0x00000000]), false); + expect(n1.valueOf()).toBe(1); + const n2 = new BN(new Uint32Array([0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF]), true); + expect(n2.valueOf()).toBe(-2); + const n3 = new BN(new Uint32Array([0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF]), true); + expect(n3.valueOf()).toBe(-1); + const n4 = new BN(new Uint32Array([0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF]), true); + expect(n4.valueOf(1)).toBe(-0.1); + const n5 = new BN(new Uint32Array([0x00000000, 0x00000000, 0x00000000, 0x80000000]), false); + expect(n5.valueOf()).toBe(1.7014118346046923e+38); + // const n6 = new BN(new Uint32Array([0x00000000, 0x00000000, 0x00000000, 0x80000000]), false); + // expect(n6.valueOf(1)).toBe(1.7014118346046923e+37); + }); }); diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 81b5d79258255..6062a8c4f4689 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -4,10 +4,10 @@ Version: 15.0.2.9000 Authors@R: c( person("Neal", "Richardson", email = "neal.p.richardson@gmail.com", role = c("aut")), person("Ian", "Cook", email = "ianmcook@gmail.com", role = c("aut")), - person("Nic", "Crane", email = "thisisnic@gmail.com", role = c("aut", "cre")), + person("Nic", "Crane", email = "thisisnic@gmail.com", role = c("aut")), person("Dewey", "Dunnington", role = c("aut"), email = "dewey@fishandwhistle.net", comment = c(ORCID = "0000-0002-9415-4582")), person("Romain", "Fran\u00e7ois", role = c("aut"), comment = c(ORCID = "0000-0002-2444-4226")), - person("Jonathan", "Keane", email = "jkeane@gmail.com", role = c("aut")), + person("Jonathan", "Keane", email = "jkeane@gmail.com", role = c("aut", "cre")), person("Drago\u0219", "Moldovan-Gr\u00fcnfeld", email = "dragos.mold@gmail.com", role = c("aut")), person("Jeroen", "Ooms", email = "jeroen@berkeley.edu", role = c("aut")), person("Jacob", "Wujciak-Jens", email = "jacob@wujciak.de", role = c("aut")),
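Returning to the `bigNumToNumber` change in js/src/util/bn.ts above: the quotient/remainder split is what keeps `valueOf(scale)` accurate, since converting the raw 128-bit integer to a float first and dividing afterwards would round twice. A rough equivalent of that logic, with a helper name of my choosing, looks like this in Go:

```go
package example

import "math/big"

// scaledToFloat mirrors the bn.ts valueOf(scale) logic: split the unscaled
// integer into quotient and remainder by 10^scale (truncated division, as
// with JavaScript BigInt), then combine the two parts in float space.
func scaledToFloat(unscaled *big.Int, scale int) float64 {
	denom := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(scale)), nil)
	quo, rem := new(big.Int).QuoRem(unscaled, denom, new(big.Int))
	qf, _ := new(big.Float).SetInt(quo).Float64()
	rf, _ := new(big.Float).SetInt(rem).Float64()
	df, _ := new(big.Float).SetInt(denom).Float64()
	return qf + rf/df
}
```

For example, `scaledToFloat(big.NewInt(-1), 1)` yields -0.1, matching the `n4.valueOf(1)` expectation in the bn tests above.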