Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-20339: [C++] Add residual filter support to swiss join #39487

Merged
merged 34 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7d87de6
Sketch basic filter logic
zanmato1984 Jan 1, 2024
cccf9e1
Implement materialize and evaluation for residual filter for swiss join
zanmato1984 Jan 2, 2024
8bfb8d7
Finish impl
zanmato1984 Jan 2, 2024
4322db5
Init residual filter in probe processor
zanmato1984 Jan 2, 2024
f30a382
Add match bitvector update for left joins
zanmato1984 Jan 3, 2024
ec304e4
Bug fix
zanmato1984 Jan 3, 2024
2e496db
Fix
zanmato1984 Jan 3, 2024
cbf1b58
Fix filter bitvector init timing
zanmato1984 Jan 4, 2024
5cf8edf
Refine
zanmato1984 Jan 4, 2024
8fc9f99
Refine
zanmato1984 Jan 4, 2024
50136d8
Fix many bugs
zanmato1984 Jan 6, 2024
80b5765
Revert cmake change
zanmato1984 Jan 6, 2024
d29de58
Revert cmake change
zanmato1984 Jan 6, 2024
1d86852
Remove file
zanmato1984 Jan 6, 2024
c58d125
Add some comments
zanmato1984 Jan 7, 2024
dda56e3
Refine structure
zanmato1984 Jan 7, 2024
d000621
Refine structure and add docs
zanmato1984 Jan 8, 2024
595f96f
WIP
zanmato1984 Jan 11, 2024
a687bb0
Some tests
zanmato1984 Jan 11, 2024
8f9db83
Literal false and null
zanmato1984 Jan 11, 2024
183bf86
Scalar true, false and null
zanmato1984 Jan 11, 2024
a1ade8e
More test
zanmato1984 Jan 12, 2024
9417120
Fix windows build
zanmato1984 Jan 12, 2024
e45be43
Fix build issue
zanmato1984 Jan 12, 2024
8e50c18
Fix comment
zanmato1984 Jan 12, 2024
04f3d19
Fix ubsan
zanmato1984 Jan 12, 2024
ed1dd4b
Fix bug
zanmato1984 Jan 12, 2024
245a6a6
Minor fix
zanmato1984 Jan 23, 2024
d651fd9
Minor fix
zanmato1984 Jan 23, 2024
7f14432
Minor fix
zanmato1984 Jan 23, 2024
7a49012
Add benchmark (#2)
zanmato1984 Jan 24, 2024
533c238
Fix lint
zanmato1984 Jan 24, 2024
629cdd7
Merge remote-tracking branch 'origin/main' into new-swiss-join-filter
zanmato1984 Feb 28, 2024
1070f4d
Address comments
zanmato1984 Feb 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 192 additions & 5 deletions cpp/src/arrow/acero/hash_join_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ struct BenchmarkSettings {
double null_percentage = 0.0;
double cardinality = 1.0; // Proportion of distinct keys in build side
double selectivity = 1.0; // Probability of a match for a given row
int var_length_min = 2; // Minimal length of any var length types
int var_length_max = 20; // Maximum length of any var length types

Expression residual_filter = literal(true);
};

class JoinBenchmark {
Expand Down Expand Up @@ -79,8 +83,8 @@ class JoinBenchmark {
build_metadata["null_probability"] = std::to_string(settings.null_percentage);
build_metadata["min"] = std::to_string(min_build_value);
build_metadata["max"] = std::to_string(max_build_value);
build_metadata["min_length"] = "2";
build_metadata["max_length"] = "20";
build_metadata["min_length"] = settings.var_length_min;
build_metadata["max_length"] = settings.var_length_max;

std::unordered_map<std::string, std::string> probe_metadata;
probe_metadata["null_probability"] = std::to_string(settings.null_percentage);
Expand Down Expand Up @@ -126,10 +130,9 @@ class JoinBenchmark {
stats_.num_probe_rows = settings.num_probe_batches * settings.batch_size;

schema_mgr_ = std::make_unique<HashJoinSchema>();
Expression filter = literal(true);
DCHECK_OK(schema_mgr_->Init(settings.join_type, *l_batches_with_schema.schema,
left_keys, *r_batches_with_schema.schema, right_keys,
filter, "l_", "r_"));
settings.residual_filter, "l_", "r_"));

if (settings.use_basic_implementation) {
join_ = *HashJoinImpl::MakeBasic();
Expand Down Expand Up @@ -158,7 +161,7 @@ class JoinBenchmark {

DCHECK_OK(join_->Init(
&ctx_, settings.join_type, settings.num_threads, &(schema_mgr_->proj_maps[0]),
&(schema_mgr_->proj_maps[1]), std::move(key_cmp), std::move(filter),
&(schema_mgr_->proj_maps[1]), std::move(key_cmp), settings.residual_filter,
std::move(register_task_group_callback), std::move(start_task_group_callback),
[](int64_t, ExecBatch) { return Status::OK(); },
[](int64_t) { return Status::OK(); }));
Expand Down Expand Up @@ -308,6 +311,60 @@ static void BM_HashJoinBasic_NullPercentage(benchmark::State& st) {

HashJoinBasicBenchmarkImpl(st, settings);
}

template <typename... Args>
static void BM_HashJoinBasic_TrivialResidualFilter(benchmark::State& st,
JoinType join_type,
Expression residual_filter,
Args&&...) {
BenchmarkSettings settings;
settings.join_type = join_type;
settings.build_payload_types = {binary()};
settings.probe_payload_types = {binary()};

settings.use_basic_implementation = st.range(0);

settings.num_build_batches = 1024;
settings.num_probe_batches = 1024;

// Let payload column length from 1 to 100.
settings.var_length_min = 1;
settings.var_length_max = 100;

settings.residual_filter = std::move(residual_filter);

HashJoinBasicBenchmarkImpl(st, settings);
}

template <typename... Args>
static void BM_HashJoinBasic_ComplexResidualFilter(benchmark::State& st,
JoinType join_type, Args&&...) {
BenchmarkSettings settings;
settings.join_type = join_type;
settings.build_payload_types = {binary()};
settings.probe_payload_types = {binary()};

settings.use_basic_implementation = st.range(0);

settings.num_build_batches = 1024;
settings.num_probe_batches = 1024;

// Let payload column length from 1 to 100.
settings.var_length_min = 1;
settings.var_length_max = 100;

// Create filter referring payload columns from both sides.
// binary_length(probe_payload) + binary_length(build_payload) <= 2 * selectivity
settings.selectivity = static_cast<double>(st.range(1)) / 100.0;
using arrow::compute::call;
using arrow::compute::field_ref;
settings.residual_filter =
call("less_equal", {call("plus", {call("binary_length", {field_ref("lp0")}),
call("binary_length", {field_ref("rp0")})}),
literal(2 * settings.selectivity)});

HashJoinBasicBenchmarkImpl(st, settings);
}
#endif

std::vector<int64_t> hashtable_krows = benchmark::CreateRange(1, 4096, 8);
Expand Down Expand Up @@ -435,6 +492,136 @@ BENCHMARK(BM_HashJoinBasic_BuildParallelism)
BENCHMARK(BM_HashJoinBasic_NullPercentage)
->ArgNames({"Null Percentage"})
->DenseRange(0, 100, 10);

const char* use_basic_argname = "Use basic";
std::vector<int64_t> use_basic_arg = benchmark::CreateDenseRange(0, 1, 1);

std::vector<std::string> trivial_residual_filter_argnames = {use_basic_argname};
std::vector<std::vector<int64_t>> trivial_residual_filter_args = {use_basic_arg};

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Inner/Literal(true)",
JoinType::INNER, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Semi/Literal(true)",
JoinType::LEFT_SEMI, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Semi/Literal(true)",
JoinType::RIGHT_SEMI, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Anti/Literal(true)",
JoinType::LEFT_ANTI, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Anti/Literal(true)",
JoinType::RIGHT_ANTI, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Outer/Literal(true)",
JoinType::LEFT_OUTER, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Outer/Literal(true)",
JoinType::RIGHT_OUTER, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Full Outer/Literal(true)",
JoinType::FULL_OUTER, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Inner/Literal(false)",
JoinType::INNER, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Semi/Literal(false)",
JoinType::LEFT_SEMI, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Semi/Literal(false)",
JoinType::RIGHT_SEMI, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Anti/Literal(false)",
JoinType::LEFT_ANTI, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Anti/Literal(false)",
JoinType::RIGHT_ANTI, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Outer/Literal(false)",
JoinType::LEFT_OUTER, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Outer/Literal(false)",
JoinType::RIGHT_OUTER, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Full Outer/Literal(false)",
JoinType::FULL_OUTER, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

std::vector<std::string> complex_residual_filter_argnames = {use_basic_argname,
"Selectivity"};
std::vector<std::vector<int64_t>> complex_residual_filter_args = {
use_basic_arg, benchmark::CreateDenseRange(0, 100, 20)};

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Inner", JoinType::INNER)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Left Semi",
JoinType::LEFT_SEMI)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Right Semi",
JoinType::RIGHT_SEMI)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Left Anti",
JoinType::LEFT_ANTI)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Right Anti",
JoinType::RIGHT_ANTI)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Left Outer",
JoinType::LEFT_OUTER)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Right Outer",
JoinType::RIGHT_OUTER)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Full Outer",
JoinType::FULL_OUTER)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);
#else

BENCHMARK_CAPTURE(BM_HashJoinBasic_KeyTypes, "{int32}", {int32()})
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/arrow/acero/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -740,13 +740,11 @@ class HashJoinNode : public ExecNode, public TracedNode {
// Create hash join implementation object
// SwissJoin does not support:
// a) 64-bit string offsets
// b) residual predicates
// c) dictionaries
// b) dictionaries
//
bool use_swiss_join;
#if ARROW_LITTLE_ENDIAN
use_swiss_join = (filter == literal(true)) && !schema_mgr->HasDictionaries() &&
!schema_mgr->HasLargeBinary();
use_swiss_join = !schema_mgr->HasDictionaries() && !schema_mgr->HasLargeBinary();
#else
use_swiss_join = false;
#endif
Expand Down
Loading
Loading