Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
Dominik Rosch committed Jan 31, 2025
2 parents 9bf12bc + 567e19f commit 16005c7
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 50 deletions.
17 changes: 4 additions & 13 deletions kaminpar-shm/coarsening/sparsification/ThresholdSampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,20 @@ template <typename Score> class ThresholdSampler : public ScoreBacedSampler<Scor
scores = this->_score_function->scores(g);
}

Score threshold;
utils::K_SmallestInfo<Score> threshold;
{
SCOPED_TIMER("Find Threshold with qselect");
threshold =
utils::quickselect_k_smallest<Score>(target_edge_amount, scores.begin(), scores.end());
}
EdgeID edges_less_than_threshold = 0;

double inclusion_probaility_if_equal = (target_edge_amount / 2 - threshold.number_of_elements_smaller) / threshold.number_of_elemtns_equal;
utils::parallel_for_upward_edges(g, [&](EdgeID e) {
if (scores[e] < threshold) {
if (scores[e] < threshold.value || (scores[e] == threshold.value && Random::instance().random_bool(inclusion_probaility_if_equal))) {
sample[e] = g.edge_weight(e);
__atomic_fetch_add(&edges_less_than_threshold, 1, __ATOMIC_RELAXED);
}
});

std::atomic_int64_t edges_at_thresholds_to_include = target_edge_amount / 2 - edges_less_than_threshold;

utils::parallel_for_upward_edges(g, [&](EdgeID e) {
if (scores[e] == threshold) {
if (edges_at_thresholds_to_include-- > 0) {
sample[e] = g.edge_weight(e);
}
}
});
return sample;
}

Expand Down
117 changes: 81 additions & 36 deletions kaminpar-shm/coarsening/sparsification/sparsification_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
#include <functional>

#include <oneapi/tbb/concurrent_vector.h>
#include <oneapi/tbb/enumerable_thread_specific.h>
#include <oneapi/tbb/parallel_sort.h>

#include "kaminpar-shm/datastructures/csr_graph.h"
#include "kaminpar-shm/kaminpar.h"

#include "kaminpar-common/datastructures/static_array.h"
#include "kaminpar-common/parallel/algorithm.h"
#include "kaminpar-common/random.h"

namespace kaminpar::shm::sparsification::utils {
Expand Down Expand Up @@ -44,7 +46,6 @@ inline void parallel_for_downward_edges(const CSRGraph &g, Lambda function) {
});
}

// Forward declaration: quickselect_k_smallest() below calls medians_of_medians() to pick its
// pivot before the definition of medians_of_medians() appears in this header.
template <typename T, typename Iterator> T medians_of_medians(Iterator begin, Iterator end);

template <typename T, typename Iterator>
T sortselect_k_smallest(size_t k, Iterator begin, Iterator end) {
Expand All @@ -57,40 +58,15 @@ T sortselect_k_smallest(size_t k, Iterator begin, Iterator end) {
return sorted[k - 1];
}

// Returns the k-th smallest element (1-based, see sortselect_k_smallest) of [begin, end).
//
// Ranges of at most 20 elements are delegated to sortselect_k_smallest(). Larger ranges are
// split around a pivot obtained from medians_of_medians(); the function then recurses into the
// side that contains the requested rank, or returns the pivot if the rank falls on it.
template <typename T, typename Iterator>
T quickselect_k_smallest(size_t k, Iterator begin, Iterator end) {
size_t size = std::distance(begin, end);
// Small input: select by sorting.
if (size <= 20)
return sortselect_k_smallest<T, Iterator>(k, begin, end);
T pivot = medians_of_medians<T, Iterator>(begin, end);
// Elements strictly below / above the pivot are collected concurrently; elements equal to the
// pivot are only counted.
tbb::concurrent_vector<T> less = {}, greater = {};
size_t number_equal_to_pivot = 0;
tbb::parallel_for(0ul, size, [&](size_t i) {
T x = begin[i];
if (x < pivot)
less.push_back(x);
else if (x > pivot)
greater.push_back(x);
else // equal
// Plain size_t counter updated from several threads, hence the atomic builtin.
__atomic_add_fetch(&number_equal_to_pivot, 1, __ATOMIC_RELAXED);
});

// Sanity check on the pivot quality: neither side may exceed 0.7 * size + 2 elements.
KASSERT(
less.size() <= 0.7 * size + 2 && greater.size() <= 0.7 * size + 2,
"median of medians privot guarantee does not hold",
assert::always
);
// The k-th smallest element lies among the elements below the pivot.
if (k <= less.size())
return quickselect_k_smallest<T, typename tbb::concurrent_vector<T>::iterator>(
k, less.begin(), less.end()
);
// The k-th smallest element lies above the pivot; its rank inside `greater` is reduced by the
// number of elements that were discarded (those below and those equal to the pivot).
else if (less.size() + number_equal_to_pivot < k)
return quickselect_k_smallest<T, typename tbb::concurrent_vector<T>::iterator>(
k - number_equal_to_pivot - less.size(), greater.begin(), greater.end()
);
// Otherwise the requested rank falls onto the block of elements equal to the pivot.
else
return pivot;
}
// Result of quickselect_k_smallest(): the selected value together with two element counts.
template <typename T>
struct K_SmallestInfo {
// The selected (k-th smallest) value.
T value;
// Number of elements that compare less than `value`.
size_t number_of_elements_smaller;
// Number of elements that compare equal to `value`.
// NOTE(review): the field name is misspelled ("elemtns"); callers reference it under this
// exact spelling, so renaming it requires updating every use.
size_t number_of_elemtns_equal;
};
// Forward declaration of the selection routine that returns a K_SmallestInfo; it is declared
// here so that code appearing before its definition in this header can already call it.
template <typename T, typename Iterator>
K_SmallestInfo<T> quickselect_k_smallest(size_t k, Iterator begin, Iterator end) ;

template <typename T, typename Iterator> T median(Iterator begin, Iterator end) {
size_t size = std::distance(begin, end);
Expand All @@ -106,7 +82,7 @@ template <typename T, typename Iterator> T median(Iterator begin, Iterator end)
}
}

template <typename T, typename Iterator> T medians_of_medians(Iterator begin, Iterator end) {
template <typename T, typename Iterator> T median_of_medians(Iterator begin, Iterator end) {
size_t size = std::distance(begin, end);
if (size <= 10)
return median<T, Iterator>(begin, end);
Expand All @@ -119,7 +95,76 @@ template <typename T, typename Iterator> T medians_of_medians(Iterator begin, It

return quickselect_k_smallest<T, typename StaticArray<T>::iterator>(
(number_of_sections + 1) / 2, medians.begin(), medians.end()
);
).value;
}

/**
 * Selects the k-th smallest element (1-based) of the range [begin, end) and reports how many
 * elements of that range are strictly smaller than it and how many are equal to it.
 *
 * Ranges of at most 20 elements are handled by sortselect_k_smallest() followed by a linear
 * count. Larger ranges are partitioned around the pivot returned by median_of_medians(); the
 * search then continues only in the part that contains the requested rank.
 *
 * @tparam T Element type; must support <, > and ==.
 * @tparam Iterator Iterator type of the input; it is indexed via begin[i].
 * @param k Rank of the requested element. Assumed to satisfy 1 <= k <= distance(begin, end);
 *          this is not checked here.
 * @param begin Start of the input range.
 * @param end End of the input range.
 * @return The selected value plus the number of smaller and equal elements, both counted with
 *         respect to the full range [begin, end).
 */
template <typename T, typename Iterator>
K_SmallestInfo<T> quickselect_k_smallest(size_t k, Iterator begin, Iterator end) {
  const size_t size = std::distance(begin, end);

  if (size <= 20) {
    const T k_smallest = sortselect_k_smallest<T, Iterator>(k, begin, end);

    // Both counters have to start at zero: they are only ever incremented below.
    size_t number_equal = 0;
    size_t number_less = 0;
    for (auto it = begin; it != end; ++it) {
      if (*it == k_smallest) {
        ++number_equal;
      } else if (*it < k_smallest) {
        ++number_less;
      }
    }
    return {k_smallest, number_less, number_equal};
  }

  const T pivot = median_of_medians<T, Iterator>(begin, end);

  // 0/1 markers per input position; after a prefix sum they give each element its (1-based)
  // slot in the compacted "less" respectively "greater" array.
  StaticArray<size_t> less(size);
  StaticArray<size_t> greater(size);
  tbb::enumerable_thread_specific<size_t> thread_specific_number_equal;
  tbb::enumerable_thread_specific<size_t> thread_specific_number_less;
  tbb::parallel_for(size_t{0}, size, [&](size_t i) {
    const bool is_less = begin[i] < pivot;
    const bool is_greater = !is_less && begin[i] > pivot;

    // Write both markers explicitly so that the prefix sums do not depend on how StaticArray
    // initializes its storage.
    less[i] = is_less ? 1 : 0;
    greater[i] = is_greater ? 1 : 0;

    if (is_less) {
      thread_specific_number_less.local()++;
    } else if (!is_greater) {
      thread_specific_number_equal.local()++;
    }
  });

  auto add = [](size_t a, size_t b) {
    return a + b;
  };
  const size_t number_equal = thread_specific_number_equal.combine(add);
  const size_t number_less = thread_specific_number_less.combine(add);
  const size_t number_greater = size - number_equal - number_less;

  if (k <= number_less) {
    // The requested element is smaller than the pivot: compact the smaller elements and recurse.
    parallel::prefix_sum(less.begin(), less.end(), less.begin());
    KASSERT(less[size - 1] == number_less, "prefix sum does not work", assert::always);

    StaticArray<T> elements_less(number_less);
    tbb::parallel_for(size_t{0}, size, [&](size_t i) {
      if (begin[i] < pivot) {
        elements_less[less[i] - 1] = begin[i];
      }
    });

    // Every element smaller than or equal to the result is itself smaller than the pivot, so the
    // counts of the recursive call are already valid for the full range.
    return quickselect_k_smallest<T>(k, elements_less.begin(), elements_less.end());
  }

  if (k > number_less + number_equal) {
    // The requested element is greater than the pivot: compact the greater elements and recurse
    // with the rank reduced by the number of discarded elements.
    parallel::prefix_sum(greater.begin(), greater.end(), greater.begin());
    KASSERT(greater[size - 1] == number_greater, "prefix sum does not work", assert::always);

    StaticArray<T> elements_greater(number_greater);
    tbb::parallel_for(size_t{0}, size, [&](size_t i) {
      if (begin[i] > pivot) {
        elements_greater[greater[i] - 1] = begin[i];
      }
    });

    K_SmallestInfo<T> result = quickselect_k_smallest<T>(
        k - number_equal - number_less, elements_greater.begin(), elements_greater.end()
    );
    // The recursive call only saw the elements above the pivot. All discarded elements (those
    // below and those equal to the pivot) are smaller than the result and must be added so that
    // the count refers to the full range, consistent with the other return paths.
    result.number_of_elements_smaller += number_less + number_equal;
    return result;
  }

  // The requested rank falls onto the block of elements equal to the pivot.
  return {pivot, number_less, number_equal};
}

} // namespace kaminpar::shm::sparsification::utils
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ TEST(SparsificationUtils, MedianOfMedians) {
42 // median 42
};
ASSERT_EQ(
sparsification::utils::medians_of_medians<int>(
sparsification::utils::median_of_medians<int>(
numbers_with_mom_2.begin(), numbers_with_mom_2.end()
),
2
Expand Down

0 comments on commit 16005c7

Please sign in to comment.