diff --git a/cpp/include/cudf/io/types.hpp b/cpp/include/cudf/io/types.hpp
index a34881942ce..9e171a62f78 100644
--- a/cpp/include/cudf/io/types.hpp
+++ b/cpp/include/cudf/io/types.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -277,13 +277,24 @@ struct column_name_info {
 struct table_metadata {
   std::vector<column_name_info>
     schema_info;  //!< Detailed name information for the entire output hierarchy
-  std::vector<size_t> num_rows_per_source;  //!< Number of rows read from each data source.
+  std::vector<size_t> num_rows_per_source;  //!< Number of rows read from each data source
                                             //!< Currently only computed for Parquet readers if no
-                                            //!< AST filters being used. Empty vector otherwise.
+                                            //!< AST filters being used. Empty vector otherwise
   std::map<std::string, std::string> user_data;  //!< Format-dependent metadata of the first input
                                                  //!< file as key-values pairs (deprecated)
   std::vector<std::unordered_map<std::string, std::string>>
     per_file_user_data;  //!< Per file format-dependent metadata as key-values pairs
+
+  // The following variables are currently only computed for the Parquet reader
+  size_type num_input_row_groups{0};  //!< Total number of input row groups across all data sources
+  std::optional<size_type>
+    num_row_groups_after_stats_filter;  //!< Number of remaining row groups after the stats filter.
+                                        //!< std::nullopt if no filtering was done. Currently only
+                                        //!< reported by Parquet readers
+  std::optional<size_type>
+    num_row_groups_after_bloom_filter;  //!< Number of remaining row groups after the bloom filter.
+                                        //!< std::nullopt if no filtering was done. Currently only
+                                        //!< reported by Parquet readers
 };
 
 /**
diff --git a/cpp/src/io/parquet/bloom_filter_reader.cu b/cpp/src/io/parquet/bloom_filter_reader.cu
index af524e1f70a..a883981a467 100644
--- a/cpp/src/io/parquet/bloom_filter_reader.cu
+++ b/cpp/src/io/parquet/bloom_filter_reader.cu
@@ -599,9 +599,11 @@ std::vector<Type> aggregate_reader_metadata::get_parquet_types(
   return parquet_types;
 }
 
-std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::apply_bloom_filters(
+std::pair<std::optional<std::vector<std::vector<size_type>>>, bool>
+aggregate_reader_metadata::apply_bloom_filters(
   host_span<std::unique_ptr<datasource> const> sources,
   host_span<std::vector<size_type> const> input_row_group_indices,
+  size_type total_row_groups,
   host_span<data_type const> output_dtypes,
   host_span<int const> output_column_schemas,
   std::reference_wrapper<ast::expression const> filter,
   rmm::cuda_stream_view stream) const
@@ -610,17 +612,6 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::ap
   // Number of input table columns
   auto const num_input_columns = static_cast<size_type>(output_dtypes.size());
 
-  // Total number of row groups after StatsAST filtration
-  auto const total_row_groups = std::accumulate(
-    input_row_group_indices.begin(),
-    input_row_group_indices.end(),
-    size_t{0},
-    [](size_t sum, auto const& per_file_row_groups) { return sum + per_file_row_groups.size(); });
-
-  // Check if we have less than 2B total row groups.
-  CUDF_EXPECTS(total_row_groups <= std::numeric_limits<size_type>::max(),
-               "Total number of row groups exceed the size_type's limit");
-
   // Collect equality literals for each input table column
   auto const equality_literals =
     equality_literals_collector{filter.get(), num_input_columns}.get_equality_literals();
@@ -635,7 +626,7 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::ap
     [](auto& eq_literals) { return not eq_literals.empty(); });
 
   // Return early if no column with equality predicate(s)
-  if (equality_col_schemas.empty()) { return std::nullopt; }
+  if (equality_col_schemas.empty()) { return {std::nullopt, false}; }
 
   // Required alignment:
   // https://github.com/NVIDIA/cuCollections/blob/deab5799f3e4226cb8a49acf2199c03b14941ee4/include/cuco/detail/bloom_filter/bloom_filter_impl.cuh#L55-L67
@@ -654,8 +645,8 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::ap
   auto bloom_filter_data = read_bloom_filters(
     sources, input_row_group_indices, equality_col_schemas, total_row_groups, stream, aligned_mr);
 
-  // No bloom filter buffers, return the original row group indices
-  if (bloom_filter_data.empty()) { return std::nullopt; }
+  // No bloom filter buffers, return early
+  if (bloom_filter_data.empty()) { return {std::nullopt, false}; }
 
   // Get parquet types for the predicate columns
   auto const parquet_types = get_parquet_types(input_row_group_indices, equality_col_schemas);
@@ -676,8 +667,10 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::ap
     h_bloom_filter_spans, stream, cudf::get_current_device_resource_ref());
 
   // Create a bloom filter query table caster
-  bloom_filter_caster const bloom_filter_col{
-    bloom_filter_spans, parquet_types, total_row_groups, equality_col_schemas.size()};
+  bloom_filter_caster const bloom_filter_col{bloom_filter_spans,
+                                             parquet_types,
+                                             static_cast<size_t>(total_row_groups),
+                                             equality_col_schemas.size()};
 
   // Converts bloom filter membership for equality predicate columns to a table
   // containing a column for each `col[i] == literal` predicate to be evaluated.
@@ -714,10 +707,11 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::ap
   // Filter bloom filter membership table with the BloomfilterAST expression and collect
   // filtered row group indices
-  return collect_filtered_row_group_indices(bloom_filter_membership_table,
-                                            bloom_filter_expr.get_bloom_filter_expr(),
-                                            input_row_group_indices,
-                                            stream);
+  return {collect_filtered_row_group_indices(bloom_filter_membership_table,
+                                             bloom_filter_expr.get_bloom_filter_expr(),
+                                             input_row_group_indices,
+                                             stream),
+          true};
 }
 
 }  // namespace cudf::io::parquet::detail
diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp
index 0e307bac097..1508b7eef8b 100644
--- a/cpp/src/io/parquet/predicate_pushdown.cpp
+++ b/cpp/src/io/parquet/predicate_pushdown.cpp
@@ -388,40 +388,17 @@ class stats_expression_converter : public ast::detail::expression_transformer {
 };
 }  // namespace
 
-std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::filter_row_groups(
+std::pair<std::optional<std::vector<std::vector<size_type>>>, surviving_row_group_metrics>
+aggregate_reader_metadata::filter_row_groups(
   host_span<std::unique_ptr<datasource> const> sources,
-  host_span<std::vector<size_type> const> row_group_indices,
+  host_span<std::vector<size_type> const> input_row_group_indices,
+  size_type total_row_groups,
   host_span<data_type const> output_dtypes,
   host_span<int const> output_column_schemas,
   std::reference_wrapper<ast::expression const> filter,
   rmm::cuda_stream_view stream) const
 {
   auto mr = cudf::get_current_device_resource_ref();
-  // Create row group indices.
-  std::vector<std::vector<size_type>> all_row_group_indices;
-  host_span<std::vector<size_type> const> input_row_group_indices;
-  if (row_group_indices.empty()) {
-    std::transform(per_file_metadata.cbegin(),
-                   per_file_metadata.cend(),
-                   std::back_inserter(all_row_group_indices),
-                   [](auto const& file_meta) {
-                     std::vector<size_type> rg_idx(file_meta.row_groups.size());
-                     std::iota(rg_idx.begin(), rg_idx.end(), 0);
-                     return rg_idx;
-                   });
-    input_row_group_indices = host_span<std::vector<size_type> const>(all_row_group_indices);
-  } else {
-    input_row_group_indices = row_group_indices;
-  }
-  auto const total_row_groups = std::accumulate(
-    input_row_group_indices.begin(),
-    input_row_group_indices.end(),
-    size_t{0},
-    [](size_t sum, auto const& per_file_row_groups) { return sum + per_file_row_groups.size(); });
-
-  // Check if we have less than 2B total row groups.
-  CUDF_EXPECTS(total_row_groups <= std::numeric_limits<size_type>::max(),
-               "Total number of row groups exceed the size_type's limit");
 
   // Converts Column chunk statistics to a table
   // where min(col[i]) = columns[i*2], max(col[i])=columns[i*2+1]
@@ -451,16 +428,22 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
   // Converts AST to StatsAST with reference to min, max columns in above `stats_table`.
   stats_expression_converter const stats_expr{filter.get(),
                                               static_cast<size_type>(output_dtypes.size())};
-  auto stats_ast     = stats_expr.get_stats_expr();
-  auto predicate_col = cudf::detail::compute_column(stats_table, stats_ast.get(), stream, mr);
-  auto predicate     = predicate_col->view();
-  CUDF_EXPECTS(predicate.type().id() == cudf::type_id::BOOL8,
-               "Filter expression must return a boolean column");
 
   // Filter stats table with StatsAST expression and collect filtered row group indices
   auto const filtered_row_group_indices = collect_filtered_row_group_indices(
     stats_table, stats_expr.get_stats_expr(), input_row_group_indices, stream);
 
+  // Number of surviving row groups after applying stats filter
+  auto const num_stats_filtered_row_groups =
+    filtered_row_group_indices.has_value()
+      ? std::accumulate(filtered_row_group_indices.value().cbegin(),
+                        filtered_row_group_indices.value().cend(),
+                        size_type{0},
+                        [](auto& sum, auto const& per_file_row_groups) {
+                          return sum + per_file_row_groups.size();
+                        })
+      : total_row_groups;
+
   // Span of row groups to apply bloom filtering on.
   auto const bloom_filter_input_row_groups =
     filtered_row_group_indices.has_value()
@@ -468,12 +451,32 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
       : input_row_group_indices;
 
   // Apply bloom filtering on the bloom filter input row groups
-  auto const bloom_filtered_row_groups = apply_bloom_filters(
-    sources, bloom_filter_input_row_groups, output_dtypes, output_column_schemas, filter, stream);
+  auto const [bloom_filtered_row_groups, bloom_filters_exist] =
+    apply_bloom_filters(sources,
+                        bloom_filter_input_row_groups,
+                        num_stats_filtered_row_groups,
+                        output_dtypes,
+                        output_column_schemas,
+                        filter,
+                        stream);
+
+  // Number of surviving row groups after applying bloom filter
+  auto const num_bloom_filtered_row_groups =
+    bloom_filters_exist
+      ? (bloom_filtered_row_groups.has_value()
+           ? std::make_optional(std::accumulate(bloom_filtered_row_groups.value().cbegin(),
+                                                bloom_filtered_row_groups.value().cend(),
+                                                size_type{0},
+                                                [](auto& sum, auto const& per_file_row_groups) {
+                                                  return sum + per_file_row_groups.size();
+                                                }))
+           : std::make_optional(num_stats_filtered_row_groups))
+      : std::nullopt;
 
   // Return bloom filtered row group indices iff collected
-  return bloom_filtered_row_groups.has_value() ? bloom_filtered_row_groups
-                                               : filtered_row_group_indices;
+  return {
+    bloom_filtered_row_groups.has_value() ? bloom_filtered_row_groups : filtered_row_group_indices,
+    {std::make_optional(num_stats_filtered_row_groups), num_bloom_filtered_row_groups}};
 }
 
 // convert column named expression to column index reference expression
diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp
index 9dd4e19de52..87e358e89f8 100644
--- a/cpp/src/io/parquet/reader_impl.cpp
+++ b/cpp/src/io/parquet/reader_impl.cpp
@@ -610,6 +610,17 @@ table_with_metadata reader::impl::read_chunk_internal(read_mode mode)
   auto out_columns = std::vector<std::unique_ptr<column>>{};
   out_columns.reserve(_output_buffers.size());
 
+  // Copy the total number of input row groups and the number of surviving row groups from
+  // predicate pushdown.
+  out_metadata.num_input_row_groups = _file_itm_data.num_input_row_groups;
+  // Copy the number of surviving row groups from each predicate pushdown filter only if the
+  // filter has a value.
+  if (_expr_conv.get_converted_expr().has_value()) {
+    out_metadata.num_row_groups_after_stats_filter =
+      _file_itm_data.surviving_row_groups.after_stats_filter;
+    out_metadata.num_row_groups_after_bloom_filter =
+      _file_itm_data.surviving_row_groups.after_bloom_filter;
+  }
+
   // no work to do (this can happen on the first pass if we have no rows to read)
   if (!has_more_work()) {
     // Check if number of rows per source should be included in output metadata.
diff --git a/cpp/src/io/parquet/reader_impl_chunking.hpp b/cpp/src/io/parquet/reader_impl_chunking.hpp
index 4a773fbced1..294eaf9ac16 100644
--- a/cpp/src/io/parquet/reader_impl_chunking.hpp
+++ b/cpp/src/io/parquet/reader_impl_chunking.hpp
@@ -47,6 +47,11 @@ struct file_intermediate_data {
   // partial sum of the number of rows per data source
   std::vector<std::size_t> exclusive_sum_num_rows_per_source{};
 
+  size_type num_input_row_groups{0};  // total number of input row groups across all data sources
+
+  // struct containing the number of remaining row groups after each predicate pushdown filter
+  surviving_row_group_metrics surviving_row_groups;
+
   size_t _current_input_pass{0};  // current input pass index
   size_t _output_chunk_count{0};  // how many output chunks we have produced
 
diff --git a/cpp/src/io/parquet/reader_impl_helpers.cpp b/cpp/src/io/parquet/reader_impl_helpers.cpp
index 7d3b6a39d5b..768ca384352 100644
--- a/cpp/src/io/parquet/reader_impl_helpers.cpp
+++ b/cpp/src/io/parquet/reader_impl_helpers.cpp
@@ -408,10 +408,16 @@ int64_t aggregate_reader_metadata::calc_num_rows() const
 
 size_type aggregate_reader_metadata::calc_num_row_groups() const
 {
-  return std::accumulate(
-    per_file_metadata.cbegin(), per_file_metadata.cend(), 0, [](auto& sum, auto& pfm) {
+  auto const total_row_groups = std::accumulate(
+    per_file_metadata.cbegin(), per_file_metadata.cend(), size_t{0}, [](size_t& sum, auto& pfm) {
       return sum + pfm.row_groups.size();
     });
+
+  // Check if we have less than 2B total row groups.
+  CUDF_EXPECTS(total_row_groups <= std::numeric_limits<size_type>::max(),
+               "Total number of row groups exceeds the size_type's limit");
+
+  return static_cast<size_type>(total_row_groups);
 }
 
 // Copies info from the column and offset indexes into the passed in row_group_info.
@@ -1029,7 +1035,12 @@ std::vector<std::string> aggregate_reader_metadata::get_pandas_index_names() con
   return names;
 }
 
-std::tuple<int64_t, size_type, std::vector<row_group_info>, std::vector<size_t>>
+std::tuple<int64_t,
+           size_type,
+           std::vector<row_group_info>,
+           std::vector<size_t>,
+           size_type,
+           surviving_row_group_metrics>
 aggregate_reader_metadata::select_row_groups(
   host_span<std::unique_ptr<datasource> const> sources,
   host_span<std::vector<size_type> const> row_group_indices,
@@ -1040,17 +1051,63 @@ aggregate_reader_metadata::select_row_groups(
   std::optional<std::reference_wrapper<ast::expression const>> filter,
   rmm::cuda_stream_view stream) const
 {
+  // Compute the total number of input row groups
+  size_type total_row_groups = [&]() {
+    if (not row_group_indices.empty()) {
+      size_t const total_row_groups =
+        std::accumulate(row_group_indices.begin(),
+                        row_group_indices.end(),
+                        size_t{0},
+                        [](size_t& sum, auto const& pfm) { return sum + pfm.size(); });
+
+      // Check if we have less than 2B total row groups.
+      CUDF_EXPECTS(total_row_groups <= std::numeric_limits<size_type>::max(),
+                   "Total number of row groups exceeds the size_type's limit");
+      return static_cast<size_type>(total_row_groups);
+    } else {
+      return num_row_groups;
+    }
+  }();
+
+  // Struct to store the number of row groups surviving the stats and bloom filters respectively;
+  // both counts start out empty (std::nullopt).
+  surviving_row_group_metrics num_row_groups_after_filters{};
+
   std::optional<std::vector<std::vector<size_type>>> filtered_row_group_indices;
   // if filter is not empty, then gather row groups to read after predicate pushdown
   if (filter.has_value()) {
-    filtered_row_group_indices = filter_row_groups(
-      sources, row_group_indices, output_dtypes, output_column_schemas, filter.value(), stream);
+    // Span of input row group indices for predicate pushdown
+    host_span<std::vector<size_type> const> input_row_group_indices;
+    std::vector<std::vector<size_type>> all_row_group_indices;
+    if (row_group_indices.empty()) {
+      std::transform(per_file_metadata.cbegin(),
+                     per_file_metadata.cend(),
+                     std::back_inserter(all_row_group_indices),
+                     [](auto const& file_meta) {
+                       std::vector<size_type> rg_idx(file_meta.row_groups.size());
+                       std::iota(rg_idx.begin(), rg_idx.end(), 0);
+                       return rg_idx;
+                     });
+      input_row_group_indices = host_span<std::vector<size_type> const>(all_row_group_indices);
+    } else {
+      input_row_group_indices = row_group_indices;
+    }
+    // Predicate pushdown: filter row groups using stats and bloom filters
+    std::tie(filtered_row_group_indices, num_row_groups_after_filters) =
+      filter_row_groups(sources,
+                        input_row_group_indices,
+                        total_row_groups,
+                        output_dtypes,
+                        output_column_schemas,
+                        filter.value(),
+                        stream);
     if (filtered_row_group_indices.has_value()) {
       row_group_indices =
         host_span<std::vector<size_type> const>(filtered_row_group_indices.value());
     }
   }
-  std::vector<row_group_info> selection;
+
+  // Compute the number of rows to read and skip
   auto [rows_to_skip, rows_to_read] = [&]() {
     if (not row_group_indices.empty()) { return std::pair<int64_t, size_type>{}; }
     auto const from_opts = cudf::io::detail::skip_rows_num_rows_from_options(
@@ -1061,7 +1118,9 @@ aggregate_reader_metadata::select_row_groups(
                      static_cast<size_type>(from_opts.second)};
   }();
 
-  // Get number of rows in each data source
+  // Vector to hold the `row_group_info` of selected row groups
+  std::vector<row_group_info> selection;
+  // Number of rows in each data source
   std::vector<size_t> num_rows_per_source(per_file_metadata.size(), 0);
 
   if (!row_group_indices.empty()) {
@@ -1083,6 +1142,10 @@
       }
     }
   } else {
+    // Reset and recompute the input row group count to adjust for num_rows and skip_rows; here,
+    // the output from predicate pushdown was empty, i.e., no row groups were filtered.
+    total_row_groups = 0;
+
     size_type count = 0;
     for (size_t src_idx = 0; src_idx < per_file_metadata.size(); ++src_idx) {
       auto const& fmd = per_file_metadata[src_idx];
@@ -1093,6 +1156,9 @@
         auto const chunk_start_row = count;
         count += rg.num_rows;
         if (count > rows_to_skip || count == 0) {
+          // Keep this row group, increase count
+          total_row_groups++;
+
           // start row of this row group adjusted with rows_to_skip
           num_rows_per_source[src_idx] += count;
           num_rows_per_source[src_idx] -=
@@ -1113,9 +1179,24 @@
       }
     }
   }
+
+  // If the filter had a value and no row groups were filtered, set the number of row groups after
+  // filters to the number of adjusted input row groups
+  auto const after_stats_filter = num_row_groups_after_filters.after_stats_filter.has_value()
+                                    ? std::make_optional(total_row_groups)
+                                    : std::nullopt;
+  auto const after_bloom_filter = num_row_groups_after_filters.after_bloom_filter.has_value()
+                                    ? std::make_optional(total_row_groups)
+                                    : std::nullopt;
+  num_row_groups_after_filters = {after_stats_filter, after_bloom_filter};
 }
 
-  return {rows_to_skip, rows_to_read, std::move(selection), std::move(num_rows_per_source)};
+  return {rows_to_skip,
+          rows_to_read,
+          std::move(selection),
+          std::move(num_rows_per_source),
+          total_row_groups,
+          std::move(num_row_groups_after_filters)};
 }
 
 std::tuple<std::vector<input_column_info>,
diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp
index ba5e53e3104..c4372b2c1ff 100644
--- a/cpp/src/io/parquet/reader_impl_helpers.hpp
+++ b/cpp/src/io/parquet/reader_impl_helpers.hpp
@@ -125,6 +125,14 @@ struct arrow_schema_data_types {
   data_type type{type_id::EMPTY};
 };
 
+/**
+ * @brief Struct to store the number of row groups surviving each predicate pushdown filter.
+ */
+struct surviving_row_group_metrics {
+  std::optional<size_type> after_stats_filter;  // number of surviving row groups after stats filter
+  std::optional<size_type> after_bloom_filter;  // number of surviving row groups after bloom filter
+};
+
 class aggregate_reader_metadata {
   std::vector<metadata> per_file_metadata;
   std::vector<std::unordered_map<std::string, std::string>> keyval_maps;
@@ -358,40 +366,47 @@ class aggregate_reader_metadata {
    * @brief Filters the row groups based on predicate filter
    *
    * @param sources Lists of input datasources
-   * @param row_group_indices Lists of row groups to read, one per source
+   * @param input_row_group_indices Lists of input row groups, one per source
+   * @param total_row_groups Total number of row groups in `input_row_group_indices`
    * @param output_dtypes Datatypes of output columns
    * @param output_column_schemas schema indices of output columns
   * @param filter AST expression to filter row groups based on Column chunk statistics
   * @param stream CUDA stream used for device memory operations and kernel launches
-   * @return Filtered row group indices, if any are filtered
+   * @return A pair of a list of filtered row group indices, if any are filtered, and a struct
+   *         containing the number of row groups surviving each predicate pushdown filter
    */
-  [[nodiscard]] std::optional<std::vector<std::vector<size_type>>> filter_row_groups(
-    host_span<std::unique_ptr<datasource> const> sources,
-    host_span<std::vector<size_type> const> row_group_indices,
-    host_span<data_type const> output_dtypes,
-    host_span<int const> output_column_schemas,
-    std::reference_wrapper<ast::expression const> filter,
-    rmm::cuda_stream_view stream) const;
+  [[nodiscard]] std::pair<std::optional<std::vector<std::vector<size_type>>>,
+                          surviving_row_group_metrics>
+  filter_row_groups(host_span<std::unique_ptr<datasource> const> sources,
+                    host_span<std::vector<size_type> const> input_row_group_indices,
+                    size_type total_row_groups,
+                    host_span<data_type const> output_dtypes,
+                    host_span<int const> output_column_schemas,
+                    std::reference_wrapper<ast::expression const> filter,
+                    rmm::cuda_stream_view stream) const;
 
   /**
    * @brief Filters the row groups using bloom filters
    *
    * @param sources Dataset sources
-   * @param row_group_indices Lists of input row groups to read, one per source
+   * @param input_row_group_indices Lists of input row groups, one per source
+   * @param total_row_groups Total number of row groups in `input_row_group_indices`
    * @param output_dtypes Datatypes of output columns
    * @param output_column_schemas schema indices of output columns
    * @param filter AST expression to filter row groups based on bloom filter membership
    * @param stream CUDA stream used for device memory operations and kernel launches
    *
-   * @return Filtered row group indices, if any are filtered
+   * @return A pair of filtered row group indices, if any are filtered, and a boolean indicating
+   *         whether bloom filtering was applied
    */
-  [[nodiscard]] std::optional<std::vector<std::vector<size_type>>> apply_bloom_filters(
-    host_span<std::unique_ptr<datasource> const> sources,
-    host_span<std::vector<size_type> const> input_row_group_indices,
-    host_span<data_type const> output_dtypes,
-    host_span<int const> output_column_schemas,
-    std::reference_wrapper<ast::expression const> filter,
-    rmm::cuda_stream_view stream) const;
+  [[nodiscard]] std::pair<std::optional<std::vector<std::vector<size_type>>>, bool>
+  apply_bloom_filters(host_span<std::unique_ptr<datasource> const> sources,
+                      host_span<std::vector<size_type> const> input_row_group_indices,
+                      size_type total_row_groups,
+                      host_span<data_type const> output_dtypes,
+                      host_span<int const> output_column_schemas,
+                      std::reference_wrapper<ast::expression const> filter,
+                      rmm::cuda_stream_view stream) const;
 
   /**
   * @brief Filters and reduces down to a selection of row groups
@@ -408,9 +423,15 @@ class aggregate_reader_metadata {
   * @param filter Optional AST expression to filter row groups based on Column chunk statistics
   * @param stream CUDA stream used for device memory operations and kernel launches
   * @return A tuple of corrected row_start, row_count, list of row group indexes and its
-   *         starting row, and list of number of rows per source
+   *         starting row, list of number of rows per source, number of input row groups, and a
+   *         struct containing the number of row groups surviving each predicate pushdown filter
    */
-  [[nodiscard]] std::tuple<int64_t, size_type, std::vector<row_group_info>, std::vector<size_t>>
+  [[nodiscard]] std::tuple<int64_t,
+                           size_type,
+                           std::vector<row_group_info>,
+                           std::vector<size_t>,
+                           size_type,
+                           surviving_row_group_metrics>
   select_row_groups(host_span<std::unique_ptr<datasource> const> sources,
                     host_span<std::vector<size_type> const> row_group_indices,
                     int64_t row_start,
diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu
index 3874346e471..b6134947b0c 100644
--- a/cpp/src/io/parquet/reader_impl_preprocess.cu
+++ b/cpp/src/io/parquet/reader_impl_preprocess.cu
@@ -1285,7 +1285,9 @@ void reader::impl::preprocess_file(read_mode mode)
   std::tie(_file_itm_data.global_skip_rows,
            _file_itm_data.global_num_rows,
            _file_itm_data.row_groups,
-           _file_itm_data.num_rows_per_source) =
+           _file_itm_data.num_rows_per_source,
+           _file_itm_data.num_input_row_groups,
+           _file_itm_data.surviving_row_groups) =
     _metadata->select_row_groups(_sources,
                                  _options.row_group_indices,
                                  _options.skip_rows,
diff --git a/cpp/tests/io/parquet_reader_test.cpp b/cpp/tests/io/parquet_reader_test.cpp
index 177e6163d4f..b96c423917a 100644
--- a/cpp/tests/io/parquet_reader_test.cpp
+++ b/cpp/tests/io/parquet_reader_test.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -1328,6 +1328,26 @@ TEST_F(ParquetReaderTest, ReorderedReadMultipleFiles)
   CUDF_TEST_EXPECT_TABLES_EQUAL(sliced[1], swapped2);
 }
 
+TEST_F(ParquetReaderTest, NoFilter)
+{
+  srand(31337);
+  auto expected = create_random_fixed_table<int>(9, 9, false);
+
+  auto filepath = temp_env->get_temp_filepath("FilterSimple.parquet");
+  cudf::io::parquet_writer_options args =
+    cudf::io::parquet_writer_options::builder(cudf::io::sink_info{filepath}, *expected);
+  cudf::io::write_parquet(args);
+
+  cudf::io::parquet_reader_options read_opts =
+    cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath});
+  auto result = cudf::io::read_parquet(read_opts);
+
+  CUDF_TEST_EXPECT_TABLES_EQUAL(*result.tbl, *expected);
+  EXPECT_EQ(result.metadata.num_input_row_groups, 1);
+  EXPECT_FALSE(result.metadata.num_row_groups_after_stats_filter.has_value());
+  EXPECT_FALSE(result.metadata.num_row_groups_after_bloom_filter.has_value());
+}
+
 TEST_F(ParquetReaderTest, FilterSimple)
 {
   srand(31337);
@@ -2681,52 +2701,107 @@ TYPED_TEST(ParquetReaderPredicatePushdownTest, FilterTyped)
   auto const [src, filepath] = create_parquet_typed_with_stats<T>("FilterTyped.parquet");
   auto const written_table   = src.view();
+  auto const col_name_0      = cudf::ast::column_name_reference("col0");
+  auto const col_ref_0       = cudf::ast::column_reference(0);
 
-  // Filtering AST
-  auto literal_value = []() {
-    if constexpr (cudf::is_timestamp<T>()) {
-      // table[0] < 10000 timestamp days/seconds/milliseconds/microseconds/nanoseconds
-      return cudf::timestamp_scalar<T>(T(typename T::duration(10000)));  // i (0-20,000)
-    } else if constexpr (cudf::is_duration<T>()) {
-      // table[0] < 10000 day/seconds/milliseconds/microseconds/nanoseconds
-      return cudf::duration_scalar<T>(T(10000));  // i (0-20,000)
-    } else if constexpr (std::is_same_v<T, cudf::string_view>) {
-      // table[0] < "000010000"
-      return cudf::string_scalar("000010000");  // i (0-20,000)
+  auto const test_predicate_pushdown = [&](cudf::ast::operation const& filter_expression,
+                                           cudf::ast::operation const& ref_filter,
+                                           cudf::size_type expected_total_row_groups,
+                                           cudf::size_type expected_stats_filtered_row_groups) {
+    // Expected result
+    auto const predicate = cudf::compute_column(written_table, ref_filter);
+    EXPECT_EQ(predicate->view().type().id(), cudf::type_id::BOOL8)
+      << "Predicate filter should return a boolean";
+    auto const expected = cudf::apply_boolean_mask(written_table, *predicate);
+
+    // Reading with Predicate Pushdown
+    cudf::io::parquet_reader_options read_opts =
+      cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath})
+        .filter(filter_expression);
+    auto const result       = cudf::io::read_parquet(read_opts);
+    auto const result_table = result.tbl->view();
+
+    // Tests
+    EXPECT_EQ(static_cast<int>(written_table.column(0).type().id()),
+              static_cast<int>(result_table.column(0).type().id()))
+      << "col0 type mismatch";
+
+    // To make sure the AST filters out some elements if row groups must be filtered
+    if (expected_stats_filtered_row_groups < expected_total_row_groups) {
+      EXPECT_LT(expected->num_rows(), written_table.num_rows());
     } else {
-      // table[0] < 0 or 100u
-      return cudf::numeric_scalar<T>((100 - 100 * std::is_signed_v<T>));  // i/100 (-100-100/ 0-200)
+      EXPECT_LE(expected->num_rows(), written_table.num_rows());
     }
-  }();
-  auto literal           = cudf::ast::literal(literal_value);
-  auto col_name_0        = cudf::ast::column_name_reference("col0");
-  auto filter_expression = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_name_0, literal);
-  auto col_ref_0         = cudf::ast::column_reference(0);
-  auto ref_filter        = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, literal);
-
-  // Expected result
-  auto predicate = cudf::compute_column(written_table, ref_filter);
-  EXPECT_EQ(predicate->view().type().id(), cudf::type_id::BOOL8)
-    << "Predicate filter should return a boolean";
-  auto expected = cudf::apply_boolean_mask(written_table, *predicate);
+    CUDF_TEST_EXPECT_TABLES_EQUAL(expected->view(), result_table);
+    EXPECT_EQ(result.metadata.num_input_row_groups, expected_total_row_groups);
+    EXPECT_TRUE(result.metadata.num_row_groups_after_stats_filter.has_value());
+    EXPECT_EQ(result.metadata.num_row_groups_after_stats_filter.value(),
+              expected_stats_filtered_row_groups);
+    EXPECT_FALSE(result.metadata.num_row_groups_after_bloom_filter.has_value());
+  };
 
-  // Reading with Predicate Pushdown
-  cudf::io::parquet_reader_options read_opts =
-    cudf::io::parquet_reader_options::builder(cudf::io::source_info{filepath})
-      .filter(filter_expression);
-  auto result       = cudf::io::read_parquet(read_opts);
-  auto result_table = result.tbl->view();
+  // The `literal_value` and stats should filter out 2 out of 4 row groups.
+  {
+    auto constexpr expected_total_row_groups          = 4;
+    auto constexpr expected_stats_filtered_row_groups = 2;
+
+    // Filtering AST
+    auto literal_value = []() {
+      if constexpr (cudf::is_timestamp<T>()) {
+        // table[0] < 10000 timestamp days/seconds/milliseconds/microseconds/nanoseconds
+        return cudf::timestamp_scalar<T>(T(typename T::duration(10000)));  // i (0-20,000)
+      } else if constexpr (cudf::is_duration<T>()) {
+        // table[0] < 10000 day/seconds/milliseconds/microseconds/nanoseconds
+        return cudf::duration_scalar<T>(T(10000));  // i (0-20,000)
+      } else if constexpr (std::is_same_v<T, cudf::string_view>) {
+        // table[0] < "000010000"
+        return cudf::string_scalar("000010000");  // i (0-20,000)
+      } else {
+        // table[0] < 0 or 100u
+        return cudf::numeric_scalar<T>(
+          (100 - 100 * std::is_signed_v<T>));  // i/100 (-100-100/ 0-200)
+      }
+    }();
+
+    auto const literal = cudf::ast::literal(literal_value);
+    auto const filter_expression =
+      cudf::ast::operation(cudf::ast::ast_operator::LESS, col_name_0, literal);
+    auto const ref_filter = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, literal);
+    test_predicate_pushdown(
+      filter_expression, ref_filter, expected_total_row_groups, expected_stats_filtered_row_groups);
+  }
 
-  // tests
-  EXPECT_EQ(int(written_table.column(0).type().id()), int(result_table.column(0).type().id()))
-    << "col0 type mismatch";
-  // To make sure AST filters out some elements
-  EXPECT_LT(expected->num_rows(), written_table.num_rows());
-  EXPECT_EQ(result_table.num_rows(), expected->num_rows());
-  EXPECT_EQ(result_table.num_columns(), expected->num_columns());
-  CUDF_TEST_EXPECT_TABLES_EQUAL(expected->view(), result_table);
+  // The `literal_value` and stats should not filter out any of the 4 row groups.
+  {
+    auto constexpr expected_total_row_groups          = 4;
+    auto constexpr expected_stats_filtered_row_groups = 4;
+
+    // Filtering AST
+    auto literal_value = []() {
+      if constexpr (cudf::is_timestamp<T>()) {
+        return cudf::timestamp_scalar<T>(T(typename T::duration(20000)));
+      } else if constexpr (cudf::is_duration<T>()) {
+        return cudf::duration_scalar<T>(T(20000));
+      } else if constexpr (std::is_same_v<T, cudf::string_view>) {
+        return cudf::string_scalar("000020000");
+      } else {
+        return cudf::numeric_scalar<T>(std::numeric_limits<T>::max());
+      }
+    }();
+
+    auto const literal = cudf::ast::literal(literal_value);
+    auto const filter_expression =
+      cudf::ast::operation(cudf::ast::ast_operator::LESS_EQUAL, col_name_0, literal);
+    auto const ref_filter =
+      cudf::ast::operation(cudf::ast::ast_operator::LESS_EQUAL, col_ref_0, literal);
+    test_predicate_pushdown(
+      filter_expression, ref_filter, expected_total_row_groups, expected_stats_filtered_row_groups);
+  }
 }
 
+//////////////////////
+// wide tables tests
+
 // The test below requires several minutes to complete with memcheck, thus it is disabled by
 // default.
 TEST_F(ParquetReaderTest, DISABLED_ListsWideTable)
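Usage sketch (not part of the patch): how a caller might consume the new row-group metrics this change adds to cudf::io::table_metadata. The file name "sales.parquet" and the col0 < 42 predicate are hypothetical stand-ins; the reader options, AST expression types, and read_parquet call are existing libcudf interfaces, and the three metadata fields are the ones introduced above. Note that both optional counts stay std::nullopt when no filter is supplied, and the bloom count also stays empty when the file carries no bloom filters.

// Reads a Parquet file with a pushdown filter, then reports how many row groups
// survived each filtering stage. Assumes "sales.parquet" has an INT32 column "col0".
#include <cudf/ast/expressions.hpp>
#include <cudf/io/parquet.hpp>
#include <cudf/scalar/scalar.hpp>

#include <iostream>

int main()
{
  // Build the predicate col0 < 42. The scalar must outlive the literal, and the
  // expression must outlive the read, since the options hold a reference to it.
  auto value   = cudf::numeric_scalar<int32_t>(42);
  auto literal = cudf::ast::literal(value);
  auto col0    = cudf::ast::column_name_reference("col0");
  auto filter  = cudf::ast::operation(cudf::ast::ast_operator::LESS, col0, literal);

  auto const options =
    cudf::io::parquet_reader_options::builder(cudf::io::source_info{"sales.parquet"})
      .filter(filter)
      .build();
  auto const result = cudf::io::read_parquet(options);

  // Total row groups across all data sources; always populated by the Parquet reader.
  std::cout << "input row groups: " << result.metadata.num_input_row_groups << '\n';

  // Per-stage survivor counts; std::nullopt when the stage did not run.
  if (result.metadata.num_row_groups_after_stats_filter) {
    std::cout << "after stats filter: " << *result.metadata.num_row_groups_after_stats_filter
              << '\n';
  }
  if (result.metadata.num_row_groups_after_bloom_filter) {
    std::cout << "after bloom filter: " << *result.metadata.num_row_groups_after_bloom_filter
              << '\n';
  }
  return 0;
}

Comparing num_input_row_groups against the two survivor counts gives a quick measure of how effective the stats and bloom filter pushdown was for a given query, which is the observability gap this patch is meant to close.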