Skip to content

Commit

Permalink
Multi-threaded implementation of HyFD algorithm
Browse files Browse the repository at this point in the history
The HyFD implementation was originally single-threaded,
which contradicts the algorithm's intended multi-threaded design.
If the ThreadNumber option is specified during the algorithm configuration step,
the algorithm will use the specified number of threads.
  • Loading branch information
cone-forest authored and chernishev committed Jan 15, 2025
1 parent 31b0216 commit 00910d6
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 10 deletions.
14 changes: 11 additions & 3 deletions src/core/algorithms/fd/hyfd/hyfd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@

#include "algorithms/fd/hycommon/preprocessor.h"
#include "algorithms/fd/hycommon/util/pli_util.h"
#include "config/names.h"
#include "config/thread_number/option.h"
#include "inductor.h"
#include "sampler.h"
#include "validator.h"

namespace algos::hyfd {

// Constructs the algorithm: forwards relation_manager to the PliBasedFDAlgorithm base and
// registers the option built by config::kThreadNumberOpt, which is given the address of
// threads_num_.
HyFD::HyFD(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({}, relation_manager) {}
: PliBasedFDAlgorithm({}, relation_manager) {
RegisterOption(config::kThreadNumberOpt(&threads_num_));
}

void HyFD::MakeExecuteOptsAvailable() {
MakeOptionsAvailable({config::names::kThreads});
}

unsigned long long HyFD::ExecuteInternal() {
using namespace hy;
Expand All @@ -30,12 +38,12 @@ unsigned long long HyFD::ExecuteInternal() {
auto const plis_shared = std::make_shared<PLIs>(std::move(plis));
auto const pli_records_shared = std::make_shared<Rows>(std::move(pli_records));

Sampler sampler(plis_shared, pli_records_shared);
Sampler sampler(plis_shared, pli_records_shared, threads_num_);

auto const positive_cover_tree =
std::make_shared<fd_tree::FDTree>(GetRelation().GetNumColumns());
Inductor inductor(positive_cover_tree);
Validator validator(positive_cover_tree, plis_shared, pli_records_shared);
Validator validator(positive_cover_tree, plis_shared, pli_records_shared, threads_num_);

IdPairs comparison_suggestions;

Expand Down
5 changes: 5 additions & 0 deletions src/core/algorithms/fd/hyfd/hyfd.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "algorithms/fd/hycommon/types.h"
#include "algorithms/fd/pli_based_fd_algorithm.h"
#include "algorithms/fd/raw_fd.h"
#include "config/thread_number/type.h"
#include "model/table/position_list_index.h"

namespace algos::hyfd {
Expand Down Expand Up @@ -43,6 +44,10 @@ class HyFD : public PliBasedFDAlgorithm {

void RegisterFDs(std::vector<RawFD>&& fds, std::vector<algos::hy::ClusterId> const& og_mapping);

void MakeExecuteOptsAvailable() override;

config::ThreadNumType threads_num_ = 1;

public:
HyFD(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);
};
Expand Down
4 changes: 2 additions & 2 deletions src/core/algorithms/fd/hyfd/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class Sampler {
hy::Sampler sampler_;

public:
// Builds the wrapped hy::Sampler (sampler_) from the given PLIs and PLI records, both moved
// into it; the thread count defaults to 1 and is passed through unchanged.
Sampler(hy::PLIsPtr plis, hy::RowsPtr pli_records)
: sampler_(std::move(plis), std::move(pli_records)) {}
Sampler(hy::PLIsPtr plis, hy::RowsPtr pli_records, config::ThreadNumType threads_num = 1)
: sampler_(std::move(plis), std::move(pli_records), threads_num) {}

NonFDList GetNonFDs(hy::IdPairs const& comparison_suggestions) {
return sampler_.GetAgreeSets(comparison_suggestions);
Expand Down
35 changes: 32 additions & 3 deletions src/core/algorithms/fd/hyfd/validator.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#include "validator.h"

#include <algorithm>
#include <string>
#include <tuple>
#include <future>
#include <utility>
#include <vector>

#include <boost/asio/post.hpp>
#include <boost/asio/thread_pool.hpp>
#include <boost/dynamic_bitset.hpp>
#include <easylogging++.h>

Expand Down Expand Up @@ -249,6 +250,29 @@ Validator::FDValidations Validator::ValidateAndExtendSeq(std::vector<LhsPair> co
return result;
}

// Validates all vertices of a level using a thread pool of threads_num_ workers.
//
// One task per vertex is posted to the pool; each task calls GetValidations(vertex) and its
// outcome is retrieved through a std::future. After the pool has been joined, the per-vertex
// results are folded into a single FDValidations in the same order as `vertices`.
//
// NOTE(review): this relies on GetValidations being safe to invoke concurrently from
// several threads -- confirm against its implementation.
Validator::FDValidations Validator::ValidateAndExtendPar(std::vector<LhsPair> const& vertices) {
    boost::asio::thread_pool workers(threads_num_);

    std::vector<std::future<FDValidations>> pending;
    pending.reserve(vertices.size());

    for (auto const& vertex : vertices) {
        // Capturing `vertex` by reference is fine: the pool is joined below, before
        // this function returns and while `vertices` is still alive.
        std::packaged_task<FDValidations()> job(
                [this, &vertex]() { return GetValidations(vertex); });
        pending.emplace_back(job.get_future());
        boost::asio::post(workers, std::move(job));
    }

    // Wait for every posted job to finish before touching the futures.
    workers.join();

    FDValidations merged;
    for (auto& finished : pending) {
        assert(finished.valid());
        merged.Add(finished.get());
    }

    return merged;
}

algos::hy::IdPairs Validator::ValidateAndExtendCandidates() {
size_t const num_attributes = plis_->size();

Expand All @@ -263,7 +287,12 @@ algos::hy::IdPairs Validator::ValidateAndExtendCandidates() {
size_t previous_num_invalid_fds = 0;
algos::hy::IdPairs comparison_suggestions;
while (!cur_level_vertices.empty()) {
auto const result = ValidateAndExtendSeq(cur_level_vertices);
FDValidations result;
if (threads_num_ > 1) {
result = ValidateAndExtendPar(cur_level_vertices);
} else {
result = ValidateAndExtendSeq(cur_level_vertices);
}

comparison_suggestions.insert(comparison_suggestions.end(),
result.ComparisonSuggestions().begin(),
Expand Down
10 changes: 8 additions & 2 deletions src/core/algorithms/fd/hyfd/validator.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "algorithms/fd/hycommon/primitive_validations.h"
#include "algorithms/fd/hyfd/model/fd_tree.h"
#include "algorithms/fd/raw_fd.h"
#include "config/thread_number/type.h"
#include "model/table/position_list_index.h"
#include "types.h"

Expand All @@ -33,16 +34,21 @@ class Validator {

FDValidations ValidateAndExtendSeq(std::vector<LhsPair> const& vertices);

FDValidations ValidateAndExtendPar(std::vector<LhsPair> const& vertices);

// Returns the value of current_level_number_.
[[nodiscard]] unsigned GetLevelNum() const { return current_level_number_; }

config::ThreadNumType threads_num_ = 1;

public:
// Moves the FD tree pointer, the PLIs and the compressed records into the corresponding
// members and stores the requested thread count in threads_num_.
Validator(std::shared_ptr<fd_tree::FDTree> fds, hy::PLIsPtr plis,
hy::RowsPtr compressed_records) noexcept
hy::RowsPtr compressed_records, config::ThreadNumType threads_num) noexcept
: fds_(std::move(fds)),
plis_(std::move(plis)),
compressed_records_(std::move(compressed_records)) {}
compressed_records_(std::move(compressed_records)),
threads_num_(threads_num) {}

hy::IdPairs ValidateAndExtendCandidates();
};
Expand Down

0 comments on commit 00910d6

Please sign in to comment.