Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 20, 2024
1 parent 92b9ddd commit 41028cd
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
9 changes: 4 additions & 5 deletions R/ResultData.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@ ResultData = R6Class("ResultData",
#' An alternative construction method is provided by [as_result_data()].
#'
#' @param data ([data.table::data.table()] | `NULL`)\cr
#' Do not initialize this object yourself, use [as_result_data()] instead.
#' Do not initialize this object yourself, use [as_result_data()] instead.
#' @param data_extra (`list()`)\cr
#' Additional data to store.
#' This can be used to store additional information for each iteration.
#'
#' Additional data to store.
#' This can be used to store additional information for each iteration.
#' @param store_backends (`logical(1)`)\cr
#' If set to `FALSE`, the backends of the [Task]s provided in `data` are removed.
#' If set to `FALSE`, the backends of the [Task]s provided in `data` are removed.
initialize = function(data = NULL, data_extra = NULL, store_backends = TRUE) {
assert_flag(store_backends)

Expand Down
45 changes: 23 additions & 22 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,9 @@ workhorse = function(
call_back("on_resample_begin", callbacks, ctx)

if (!is.null(pb)) {
pb(sprintf("%s|%s|i:%i", ctx$task$id, ctx$learner$id, ctx$iteration))
pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration))
}
if ("internal_valid" %in% ctx$learner$predict_sets && is.null(ctx$task$internal_valid_task) && is.null(get0("validate", ctx$learner))) {
if ("internal_valid" %in% learner$predict_sets && is.null(task$internal_valid_task) && is.null(get0("validate", learner))) {
stopf("Cannot set the predict_type field of learner '%s' to 'internal_valid' if there is no internal validation task configured", learner$id)
}

Expand Down Expand Up @@ -306,74 +306,75 @@ workhorse = function(
}

lg$info("%s learner '%s' on task '%s' (iter %i/%i)",
if (mode == "train") "Applying" else "Hotstarting", ctx$learner$id, ctx$task$id, ctx$iteration, ctx$resampling$iters)
if (mode == "train") "Applying" else "Hotstarting", learner$id, task$id, iteration, resampling$iters)

sets = list(
train = ctx$resampling$train_set(ctx$iteration),
test = ctx$resampling$test_set(ctx$iteration)
train = resampling$train_set(iteration),
test = resampling$test_set(iteration)
)

# train model
ctx$learner = ctx$learner$clone()
# use `learner` reference instead of `ctx$learner` to avoid going through the active binding
ctx$learner = learner = ctx$learner$clone()
if (length(param_values)) {
ctx$learner$param_set$values = list()
ctx$learner$param_set$set_values(.values = param_values)
learner$param_set$values = list()
learner$param_set$set_values(.values = param_values)
}
learner_hash = ctx$learner$hash
learner_hash = learner$hash

validate = get0("validate", ctx$learner)
validate = get0("validate", learner)

test_set = if (identical(validate, "test")) sets$test

call_back("on_resample_before_train", callbacks, ctx)

train_result = learner_train(ctx$learner, ctx$task, sets[["train"]], test_set, mode = mode)
ctx$learner = train_result$learner
train_result = learner_train(learner, task, sets[["train"]], test_set, mode = mode)
ctx$learner = learner = train_result$learner

# process the model so it can be used for prediction (e.g. marshal for callr prediction), but also
# keep a copy of the model in current form in case this is the format that we want to send back to the main process
# and not the format that we need for prediction
model_copy_or_null = process_model_before_predict(
learner = ctx$learner, store_models = store_models, is_sequential = is_sequential, unmarshal = unmarshal
learner = learner, store_models = store_models, is_sequential = is_sequential, unmarshal = unmarshal
)

# predict for each set
predict_sets = ctx$learner$predict_sets
predict_sets = learner$predict_sets

# creates the tasks and row_ids for all selected predict sets
pred_data = prediction_tasks_and_sets(ctx$task, train_result, validate, sets, predict_sets)
pred_data = prediction_tasks_and_sets(task, train_result, validate, sets, predict_sets)
call_back("on_resample_before_predict", callbacks, ctx)

pdatas = Map(function(set, row_ids, task) {
lg$debug("Creating Prediction for predict set '%s'", set)

learner_predict(ctx$learner, task, row_ids)
learner_predict(learner, task, row_ids)
}, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks)

if (!length(predict_sets)) {
ctx$learner$state$predict_time = 0L
learner$state$predict_time = 0L
}
ctx$pdatas = discard(pdatas, is.null)

# set the model slot after prediction so it can be sent back to the main process
process_model_after_predict(
learner = ctx$learner, store_models = store_models, is_sequential = is_sequential, model_copy = model_copy_or_null,
learner = learner, store_models = store_models, is_sequential = is_sequential, model_copy = model_copy_or_null,
unmarshal = unmarshal
)

call_back("on_resample_end", callbacks, ctx)

if (!store_models) {
lg$debug("Erasing stored model for learner '%s'", ctx$learner$id)
ctx$learner$state$model = NULL
lg$debug("Erasing stored model for learner '%s'", learner$id)
learner$state$model = NULL
}

learner_state = set_class(ctx$learner$state, c("learner_state", "list"))
learner_state = set_class(learner$state, c("learner_state", "list"))

list(
learner_state = learner_state,
prediction = ctx$pdatas,
param_values = ctx$learner$param_set$values,
param_values = learner$param_set$values,
learner_hash = learner_hash,
data_extra = ctx$data_extra)
}
Expand Down

0 comments on commit 41028cd

Please sign in to comment.