From 8cd42ca38aff422dc9fef2cbe015d8b1963a74bb Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 13 Dec 2023 15:52:48 +0100 Subject: [PATCH] feat: reduce blas threads to 1 --- DESCRIPTION | 1 + R/worker.R | 9 +++++++-- R/zzz.R | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 1f1248aa7..298b875c4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -56,6 +56,7 @@ Imports: parallelly, palmerpenguins, paradox (>= 0.10.0), + RhpcBLASctl, uuid Suggests: Matrix, diff --git a/R/worker.R b/R/worker.R index bc45df2b8..324eb7652 100644 --- a/R/worker.R +++ b/R/worker.R @@ -224,8 +224,13 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) } - # reduce data.table threads to 1 - if (!is_sequential) setDTthreads(1) + # reduce data.table and blas threads to 1 + if (!is_sequential) { + setDTthreads(1, restore_after_fork = TRUE) + old_blas_threads = blas_get_num_procs() + on.exit(blas_set_num_threads(old_blas_threads), add = TRUE) + blas_set_num_threads(1) + } # restore logger thresholds for (package in names(lgr_threshold)) { diff --git a/R/zzz.R b/R/zzz.R index caef148c6..1b3991935 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -10,6 +10,7 @@ #' @importFrom uuid UUIDgenerate #' @importFrom parallelly availableCores #' @importFrom future nbrOfWorkers plan +#' @importFrom RhpcBLASctl blas_set_num_threads #' #' @section Learn mlr3: #' * Book on mlr3: \url{https://mlr3book.mlr-org.com}