Skip to content

Commit

Permalink
Adopt some of the improvements in subpar::parallelize.
Browse files Browse the repository at this point in the history
This includes the use of exception_ptr instead of strings, as well as
improved worksharing with more precise jobs-to-worker calculations.
  • Loading branch information
LTLA committed Aug 26, 2024
1 parent b0dbc58 commit 85d4d91
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 25 deletions.
56 changes: 31 additions & 25 deletions include/tatami_r/parallelize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,67 +40,73 @@ inline manticore::Executor& executor() {

/**
* @tparam Function_ Function to be executed.
* @tparam Index_ Integer type for the job indices.
*
* @param njobs Number of jobs to be executed.
* @param fun Function to run in each thread.
* This is a lambda that should accept three arguments:
* - Integer containing the thread ID
* - Integer containing the thread ID.
* - Integer specifying the index of the first job to be executed in a thread.
* - Integer specifying the number of jobs to be executed in a thread.
* @param njobs Number of jobs to be executed.
* @param nthreads Number of threads to parallelize over.
*
* This function is a drop-in replacement for `tatami::parallelize()`.
* The series of integers from 0 to `njobs - 1` is split into `nthreads` contiguous ranges.
* Each range is used as input to `fun` within the corresponding thread.
* It is assumed that the execution of any given job is independent of the next.
*
* This function is only available if `TATAMI_R_PARALLELIZE_UNKNOWN` is defined.
*/
template<class Function_>
void parallelize(Function_ fun, size_t njobs, size_t nthreads) {
if (nthreads == 1 || njobs == 1) {
template<class Function_, class Index_>
void parallelize(Function_ fun, Index_ njobs, int nthreads) {
if (njobs == 0) {
return;
}

if (nthreads <= 1 || njobs == 1) {
fun(0, 0, njobs);
return;
}

Index_ jobs_per_worker = njobs / nthreads;
int remainder = njobs % nthreads;
if (jobs_per_worker == 0) {
jobs_per_worker = 1;
remainder = 0;
nthreads = njobs;
}

auto& mexec = executor();
mexec.initialize(nthreads, "failed to execute R command");

size_t jobs_per_worker = (njobs / nthreads) + (njobs % nthreads > 0);
size_t start = 0;

std::vector<std::thread> runners;
runners.reserve(nthreads);
std::vector<std::string> errors(nthreads);
std::vector<std::exception_ptr> errors(nthreads);

for (size_t w = 0; w < nthreads; ++w) {
if (start == njobs) {
mexec.finish_thread(false);
continue;
}
size_t end = start + std::min(njobs - start, jobs_per_worker);
Index_ start = 0;
for (int w = 0; w < nthreads; ++w) {
Index_ length = jobs_per_worker + (w < remainder);

runners.emplace_back([&](size_t id, size_t s, size_t l) -> void {
runners.emplace_back([&](int id, Index_ s, Index_ l) {
try {
fun(id, s, l);
} catch (std::exception& x) {
// No throw here, we need to make sure we mark the
// thread as being completed so that the main loop can quit.
errors[id] = x.what();
} catch (...) {
errors[id] = std::current_exception();
}
mexec.finish_thread();
}, w, start, end - start);
}, w, start, length);

start = end;
start += length;
}

mexec.listen();
for (auto& x : runners) {
x.join();
}

for (auto err : errors) {
if (!err.empty()) {
throw std::runtime_error(err);
for (const auto& err : errors) {
if (err) {
std::rethrow_exception(err);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions tests/src/Makevars
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ PKG_CPPFLAGS=-I../../include \
-I../../build/_deps/tatami-src/include/ \
-I../../build/_deps/tatami_chunked-src/include/ \
-I../../build/_deps/manticore-src/include/ \
-I../../build/_deps/subpar-src/include/ \
-DTEST_CUSTOM_PARALLEL \
-fstack-protector-strong \
-Wformat \
Expand Down
13 changes: 13 additions & 0 deletions tests/tests/testthat/test-miscellaneous.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@ test_that("works correctly with the default cache size", {
expect_identical(raticate.tests::num_columns(z), 20L)
})

# parallelization handles edge cases correctly
{
library(Matrix)
y <- Matrix(runif(1), 1, 1) # only one job in either dimension.
parallel_test_suite(y, 0.1)

y <- Matrix(runif(0), 0, 0) # no jobs in either dimension.
parallel_test_suite(y, 0.1)

y <- Matrix(runif(4), 2, 2) # fewer jobs than threads (for workers = 3).
parallel_test_suite(y, 0.1)
}

test_that("Behaves correctly with R-side errors", {
setClass("MyFailMatrix", contains="dgeMatrix")
setMethod("extract_array", "MyFailMatrix", function(x, index) stop("HEY!"))
Expand Down

0 comments on commit 85d4d91

Please sign in to comment.