Skip to content

Commit

Permalink
More overflow safe swiss table.
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Feb 12, 2025
1 parent 18e8f50 commit af1c470
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 205 deletions.
52 changes: 25 additions & 27 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -643,37 +643,36 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
//
int source_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks());
uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits);
int64_t source_block_bytes = source_group_id_bits + 8;
int source_block_bytes =
SwissTable::num_block_bytes_from_num_groupid_bits(source_group_id_bits);
ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0);

// Compute index of the last block in target that corresponds to the given
// partition.
//
ARROW_DCHECK(num_partition_bits <= target->log_blocks());
int64_t target_max_block_id =
uint32_t target_max_block_id =
((partition_id + 1) << (target->log_blocks() - num_partition_bits)) - 1;

overflow_group_ids->clear();
overflow_hashes->clear();

// For each source block...
int64_t source_blocks = 1LL << source->log_blocks();
for (int64_t block_id = 0; block_id < source_blocks; ++block_id) {
uint8_t* block_bytes = source->blocks() + block_id * source_block_bytes;
uint32_t source_blocks = 1 << source->log_blocks();
for (uint32_t block_id = 0; block_id < source_blocks; ++block_id) {
const uint8_t* block_bytes = source->block_data(block_id, source_block_bytes);
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);

// For each non-empty source slot...
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
constexpr int kSlotsPerBlock = 8;
int num_full_slots =
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
int num_full_slots = SwissTable::kSlotsPerBlock -
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) {
// Read group id and hash for this slot.
//
uint64_t group_id =
source->extract_group_id(block_bytes, local_slot_id, source_group_id_mask);
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
uint32_t group_id =
source->extract_group_id(block_bytes, local_slot_id, source_group_id_bits);
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
uint32_t hash = source->hashes()[global_slot_id];
// Insert partition id into the highest bits of hash, shifting the
// remaining hash bits right.
Expand All @@ -696,17 +695,18 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
}
}

inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_id,
uint32_t hash, int64_t max_block_id) {
inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint32_t group_id,
uint32_t hash, uint32_t max_block_id) {
// Load the first block to visit for this hash
//
int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks());
int64_t block_id_mask = ((1LL << target->log_blocks()) - 1);
uint32_t block_id = SwissTable::block_id_from_hash(hash, target->log_blocks());
uint32_t block_id_mask = (1 << target->log_blocks()) - 1;
int num_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks());
int64_t num_block_bytes = num_group_id_bits + sizeof(uint64_t);
int num_block_bytes =
SwissTable::num_block_bytes_from_num_groupid_bits(num_group_id_bits);
ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0);
uint8_t* block_bytes = target->blocks() + block_id * num_block_bytes;
const uint8_t* block_bytes = target->block_data(block_id, num_block_bytes);
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);

// Search for the first block with empty slots.
Expand All @@ -715,25 +715,23 @@ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_i
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
while ((block & kHighBitOfEachByte) == 0 && block_id < max_block_id) {
block_id = (block_id + 1) & block_id_mask;
block_bytes = target->blocks() + block_id * num_block_bytes;
block_bytes = target->block_data(block_id, num_block_bytes);
block = *reinterpret_cast<const uint64_t*>(block_bytes);
}
if ((block & kHighBitOfEachByte) == 0) {
return false;
}
constexpr int kSlotsPerBlock = 8;
int local_slot_id =
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
target->insert_into_empty_slot(static_cast<uint32_t>(global_slot_id), hash,
static_cast<uint32_t>(group_id));
int local_slot_id = SwissTable::kSlotsPerBlock -
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
target->insert_into_empty_slot(global_slot_id, hash, group_id);
return true;
}

void SwissTableMerge::InsertNewGroups(SwissTable* target,
const std::vector<uint32_t>& group_ids,
const std::vector<uint32_t>& hashes) {
int64_t num_blocks = 1LL << target->log_blocks();
uint32_t num_blocks = 1 << target->log_blocks();
for (size_t i = 0; i < group_ids.size(); ++i) {
std::ignore = InsertNewGroup(target, group_ids[i], hashes[i], num_blocks);
}
Expand Down Expand Up @@ -1191,7 +1189,7 @@ Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
// We want each partition to correspond to a range of block indices,
// so we also partition on the highest bits of the hash.
//
return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1;
return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_);
},
[&locals](int64_t i, int pos) {
locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/acero/swiss_join_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ class SwissTableMerge {
// Max block id value greater or equal to the number of blocks guarantees that
// the search will not be stopped.
//
static inline bool InsertNewGroup(SwissTable* target, uint64_t group_id, uint32_t hash,
int64_t max_block_id);
static inline bool InsertNewGroup(SwissTable* target, uint32_t group_id, uint32_t hash,
uint32_t max_block_id);
};

struct SwissTableWithKeys {
Expand Down
Loading

0 comments on commit af1c470

Please sign in to comment.