From 2164d4dae9671e3d55bddc8e045a3375cf0870dc Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 7 Dec 2023 16:47:59 +0100 Subject: [PATCH] refactor: add option to store prototype --- R/Learner.R | 2 +- R/worker.R | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/R/Learner.R b/R/Learner.R index 8ca69a4b5..8318dfa6f 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -242,7 +242,7 @@ Learner = R6Class("Learner", train_row_ids = if (!is.null(row_ids)) row_ids else task$row_roles$use test_row_ids = task$row_roles$test - learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode) + learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode, store_prototype = TRUE) # store the task w/o the data self$state$train_task = task_rm_backend(task$clone(deep = TRUE)) diff --git a/R/worker.R b/R/worker.R index 5beb0f568..848516d76 100644 --- a/R/worker.R +++ b/R/worker.R @@ -1,4 +1,4 @@ -learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NULL, mode = "train") { +learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NULL, mode = "train", store_prototype = FALSE) { # This wrapper calls learner$train, and additionally performs some basic # checks that the training was successful. # Exceptions here are possibly encapsulated, so that they get captured @@ -68,18 +68,21 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL log = append_log(NULL, "train", result$log$class, result$log$msg) train_time = result$elapsed - proto = task$data(rows = integer()) learner$state = insert_named(learner$state, list( model = result$result, log = log, train_time = train_time, param_vals = learner$param_set$values, task_hash = task$hash, - data_prototype = proto, - task_prototype = proto, mlr3_version = mlr_reflections$package_version )) + if (store_prototype) { + proto = task$data(rows = integer())) + learner$state$data_prototype = proto + learner$state$task_prototype = proto + } + if (is.null(result$result)) { lg$debug("Learner '%s' on task '%s' failed to %s a model", learner$id, task$id, mode, learner = learner$clone(), messages = result$log$msg) @@ -249,10 +252,6 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, learner_hash = learner$hash learner = learner_train(learner, task, sets[["train"]], sets[["test"]], mode = mode) - # repeated saving of the prototype leads to large ResultData objects if the task contains many columns, factor levels or attributes - learner$state$data_prototype = NULL - learner$state$task_prototype = NULL - # predict for each set sets = sets[learner$predict_sets] pdatas = Map(function(set, row_ids) {