Skip to content

Commit

Permalink
make threshold sampler sample exactly the target amount
Browse files Browse the repository at this point in the history
  • Loading branch information
Dominik Rosch committed Oct 28, 2024
1 parent abf486d commit db5d3e5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 34 deletions.
40 changes: 16 additions & 24 deletions kaminpar-shm/coarsening/sparsification/ThresholdSampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ namespace kaminpar::shm::sparsification {

template <typename Score> class ThresholdSampler : public ScoreBacedSampler<Score> {
public:
ThresholdSampler(std::unique_ptr<ScoreFunction<Score>> scoreFunction, bool noApprox = false)
: ScoreBacedSampler<Score>(std::move(scoreFunction)),
_noApprox(noApprox) {}
ThresholdSampler(std::unique_ptr<ScoreFunction<Score>> scoreFunction)
: ScoreBacedSampler<Score>(std::move(scoreFunction)) {}

StaticArray<EdgeWeight> sample(const CSRGraph &g, EdgeID target_edge_amount) override {
SCOPED_TIMER("Threshold Sampling");
Expand All @@ -24,35 +23,29 @@ template <typename Score> class ThresholdSampler : public ScoreBacedSampler<Scor
scores = this->_score_function->scores(g);
}

if (_noApprox) {
auto [threshold, numEdgesAtThresholdScoreToInclude] =
find_threshold(scores, target_edge_amount);

utils::parallel_for_upward_edges(g, [&](EdgeID e) {
if (scores[e] > threshold) {
sample[e] = g.edge_weight(e);
} else if (scores[e] == threshold && numEdgesAtThresholdScoreToInclude > 0) {
sample[e] = g.edge_weight(e);
__atomic_add_fetch(&numEdgesAtThresholdScoreToInclude, -1, __ATOMIC_RELAXED);
}
});

KASSERT(
numEdgesAtThresholdScoreToInclude == 0,
"not all nessary edges with threshold score included"
);
} else {
Score threshold;
{
SCOPED_TIMER("Find Threshold with qselect");
threshold =
utils::quickselect_k_smallest<Score>(target_edge_amount, scores.begin(), scores.end());
}
EdgeID edges_less_than_threshold = 0;
utils::parallel_for_upward_edges(g, [&](EdgeID e) {
if (scores[e] <= threshold)
if (scores[e] < threshold) {
sample[e] = g.edge_weight(e);
__atomic_fetch_add(&edges_less_than_threshold, 1, __ATOMIC_RELAXED);
}
});

std::atomic_int64_t edges_at_thresholds_to_include = target_edge_amount / 2 - edges_less_than_threshold;

utils::parallel_for_upward_edges(g, [&](EdgeID e) {
if (scores[e] == threshold) {
if (edges_at_thresholds_to_include-- > 0) {
sample[e] = g.edge_weight(e);
}
}
});
}
return sample;
}

Expand All @@ -72,6 +65,5 @@ template <typename Score> class ThresholdSampler : public ScoreBacedSampler<Scor
EdgeID numEdgesAtThresholdScoreToInclude = indexOfFirstLagerScore - indexOfThreshold / 2;
return std::make_pair(threshold, numEdgesAtThresholdScoreToInclude);
};
bool _noApprox;
};
} // namespace kaminpar::shm::sparsification
16 changes: 6 additions & 10 deletions kaminpar-shm/factories.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ std::unique_ptr<sparsification::Sampler> create_sampler(const Context &ctx) {
return std::make_unique<sparsification::UnbiasedThesholdSampler>();
case SparsificationAlgorithm::WEIGHT_THRESHOLD:
return std::make_unique<sparsification::ThresholdSampler<EdgeWeight>>(
std::make_unique<WeightFunction>(), ctx.sparsification.no_approx
std::make_unique<WeightFunction>()

);
case SparsificationAlgorithm::EFFECTIVE_RESISTANCE:
Expand Down Expand Up @@ -238,31 +238,27 @@ std::unique_ptr<sparsification::Sampler> create_sampler(const Context &ctx) {
std::make_unique<
sparsification::NetworKitScoreAdapter<NetworKit::ForestFireScore, double>>(
[](const NetworKit::Graph &g) { return NetworKit::ForestFireScore(g, 0.95, 5); }
),
ctx.sparsification.no_approx
)
);
case ScoreFunctionSection::NETWORKIT_WEIGHTED_FOREST_FIRE:
return std::make_unique<sparsification::ThresholdSampler<double>>(
std::make_unique<sparsification::NetworKitScoreAdapter<
sparsification::NetworKitWeightedForestFireScore,
double>>([](const NetworKit::Graph &g) {
return sparsification::NetworKitWeightedForestFireScore(g, 0.95, 5);
}),
ctx.sparsification.no_approx
})
);
case ScoreFunctionSection::WEIGHTED_FOREST_FIRE:
return std::make_unique<sparsification::ThresholdSampler<EdgeID>>(
std::make_unique<sparsification::WeightedForestFireScore>(0.95, 5),
ctx.sparsification.no_approx
std::make_unique<sparsification::WeightedForestFireScore>(0.95, 5)
);
case ScoreFunctionSection::EFFECTIVE_RESISTANCE:
return std::make_unique<sparsification::ThresholdSampler<double>>(
std::make_unique<sparsification::EffectiveResistanceScore>(4),
ctx.sparsification.no_approx
std::make_unique<sparsification::EffectiveResistanceScore>(4)
);
case ScoreFunctionSection::WEIGHT:
return std::make_unique<sparsification::ThresholdSampler<EdgeWeight>>(
std::make_unique<WeightFunction>(), ctx.sparsification.no_approx
std::make_unique<WeightFunction>()

);
}
Expand Down

0 comments on commit db5d3e5

Please sign in to comment.