Skip to content

Commit

Permalink
refactor: add option to store prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 7, 2023
1 parent 2c77828 commit 2164d4d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ Learner = R6Class("Learner",
train_row_ids = if (!is.null(row_ids)) row_ids else task$row_roles$use
test_row_ids = task$row_roles$test

learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode)
learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode, store_prototype = TRUE)

# store the task w/o the data
self$state$train_task = task_rm_backend(task$clone(deep = TRUE))
Expand Down
15 changes: 7 additions & 8 deletions R/worker.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NULL, mode = "train") {
learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NULL, mode = "train", store_prototype = FALSE) {
# This wrapper calls learner$train, and additionally performs some basic
# checks that the training was successful.
# Exceptions here are possibly encapsulated, so that they get captured
Expand Down Expand Up @@ -68,18 +68,21 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
log = append_log(NULL, "train", result$log$class, result$log$msg)
train_time = result$elapsed

proto = task$data(rows = integer())
learner$state = insert_named(learner$state, list(
model = result$result,
log = log,
train_time = train_time,
param_vals = learner$param_set$values,
task_hash = task$hash,
data_prototype = proto,
task_prototype = proto,
mlr3_version = mlr_reflections$package_version
))

if (store_prototype) {
proto = task$data(rows = integer()))
learner$state$data_prototype = proto
learner$state$task_prototype = proto
}

if (is.null(result$result)) {
lg$debug("Learner '%s' on task '%s' failed to %s a model",
learner$id, task$id, mode, learner = learner$clone(), messages = result$log$msg)
Expand Down Expand Up @@ -249,10 +252,6 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL,
learner_hash = learner$hash
learner = learner_train(learner, task, sets[["train"]], sets[["test"]], mode = mode)

# repeated saving of the prototype leads to large ResultData objects if the task contains many columns, factor levels or attributes
learner$state$data_prototype = NULL
learner$state$task_prototype = NULL

# predict for each set
sets = sets[learner$predict_sets]
pdatas = Map(function(set, row_ids) {
Expand Down

0 comments on commit 2164d4d

Please sign in to comment.