From 2fe8746f71aaa55e12fff49f7934c96cb0d29ee8 Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Fri, 20 Dec 2024 16:52:13 +0100 Subject: [PATCH 01/10] refactor: predict_type and predict_types (#1233) * refactor: predict_type and predict_types * ... * ... --- NEWS.md | 3 +++ R/Learner.R | 21 +++++++++++++-------- man/Learner.Rd | 13 ++++++++----- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/NEWS.md b/NEWS.md index 83d6dfbb0..6992b98cd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # mlr3 (development version) +* BREAKING CHANGE: `Learner$predict_types` is read-only now. +* docs: Clear up behavior of `Learner$predict_type` after training. + # mlr3 0.22.1 * fix: Extend `assert_measure()` with checks for trained models in `assert_scorable()`. diff --git a/R/Learner.R b/R/Learner.R index d1b7e7863..461dd150f 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -157,11 +157,6 @@ Learner = R6Class("Learner", #' @template field_task_type task_type = NULL, - #' @field predict_types (`character()`)\cr - #' Stores the possible predict types the learner is capable of. - #' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections]. - predict_types = NULL, - #' @field feature_types (`character()`)\cr #' Stores the feature types the learner can handle, e.g. `"logical"`, `"numeric"`, or `"factor"`. #' A complete list of candidate feature types, grouped by task type, is stored in [`mlr_reflections$task_feature_types`][mlr_reflections]. 
@@ -214,7 +209,7 @@ Learner = R6Class("Learner", self$task_type = assert_choice(task_type, mlr_reflections$task_types$type) private$.param_set = assert_param_set(param_set) self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types") - self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]), + private$.predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]), empty.ok = FALSE, .var.name = "predict_types") private$.predict_type = predict_types[1L] self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]])) @@ -627,6 +622,8 @@ Learner = R6Class("Learner", #' @field predict_type (`character(1)`)\cr #' Stores the currently active predict type, e.g. `"response"`. #' Must be an element of `$predict_types`. + #' A few learners already use the predict type during training. + #' So there is no guarantee that changing the predict type after training will have any effect or does not lead to errors. predict_type = function(rhs) { if (missing(rhs)) { return(private$.predict_type) @@ -648,8 +645,6 @@ Learner = R6Class("Learner", private$.param_set }, - - #' @field fallback ([Learner])\cr #' Returns the fallback learner set with `$encapsulate()`. fallback = function(rhs) { @@ -672,6 +667,15 @@ Learner = R6Class("Learner", } assert_r6(rhs, "HotstartStack", null.ok = TRUE) private$.hotstart_stack = rhs + }, + + #' @field predict_types (`character()`)\cr + #' Stores the possible predict types the learner is capable of. + #' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections]. + #' This field is read-only. 
+ predict_types = function(rhs) { + assert_ro_binding(rhs) + return(private$.predict_types) } ), @@ -679,6 +683,7 @@ Learner = R6Class("Learner", .encapsulation = c(train = "none", predict = "none"), .fallback = NULL, .predict_type = NULL, + .predict_types = NULL, .param_set = NULL, .hotstart_stack = NULL, diff --git a/man/Learner.Rd b/man/Learner.Rd index f5ede1944..70d7cf8f9 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -195,10 +195,6 @@ Task type, e.g. \code{"classif"} or \code{"regr"}. For a complete list of possible task types (depending on the loaded packages), see \code{\link[=mlr_reflections]{mlr_reflections$task_types$type}}.} -\item{\code{predict_types}}{(\code{character()})\cr -Stores the possible predict types the learner is capable of. -A complete list of candidate predict types, grouped by task type, is stored in \code{\link[=mlr_reflections]{mlr_reflections$learner_predict_types}}.} - \item{\code{feature_types}}{(\code{character()})\cr Stores the feature types the learner can handle, e.g. \code{"logical"}, \code{"numeric"}, or \code{"factor"}. A complete list of candidate feature types, grouped by task type, is stored in \code{\link[=mlr_reflections]{mlr_reflections$task_feature_types}}.} @@ -289,7 +285,9 @@ Hash (unique identifier) for this partial object, excluding some components whic \item{\code{predict_type}}{(\code{character(1)})\cr Stores the currently active predict type, e.g. \code{"response"}. -Must be an element of \verb{$predict_types}.} +Must be an element of \verb{$predict_types}. +A few learners already use the predict type during training. +So there is no guarantee that changing the predict type after training will have any effect or does not lead to errors.} \item{\code{param_set}}{(\link[paradox:ParamSet]{paradox::ParamSet})\cr Set of hyperparameters.} @@ -302,6 +300,11 @@ Returns the encapsulation settings set with \verb{$encapsulate()}.} \item{\code{hotstart_stack}}{(\link{HotstartStack})\cr. 
Stores \code{HotstartStack}.} + +\item{\code{predict_types}}{(\code{character()})\cr +Stores the possible predict types the learner is capable of. +A complete list of candidate predict types, grouped by task type, is stored in \code{\link[=mlr_reflections]{mlr_reflections$learner_predict_types}}. +This field is read-only.} } \if{html}{\out{}} } From 89544ffafa789eba982ef42197c45f1beedb476f Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Fri, 20 Dec 2024 16:53:15 +0100 Subject: [PATCH 02/10] refactor: allow uft8 in column names by default (#1234) --- NEWS.md | 2 ++ R/DataBackendRename.R | 2 +- R/Task.R | 10 +++------- R/helper.R | 5 ----- R/zzz.R | 4 ---- inst/testthat/helper_autotest.R | 2 -- man/mlr3-package.Rd | 4 ---- tests/testthat/test_Task.R | 9 --------- 8 files changed, 6 insertions(+), 32 deletions(-) diff --git a/NEWS.md b/NEWS.md index 6992b98cd..69e7292de 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # mlr3 (development version) +* Column names with UTF-8 characters are now allowed by default. +The option `mlr3.allow_utf8_names` is removed. * BREAKING CHANGE: `Learner$predict_types` is read-only now. * docs: Clear up behavior of `Learner$predict_type` after training. 
diff --git a/R/DataBackendRename.R b/R/DataBackendRename.R index 81617a283..d4838d5e1 100644 --- a/R/DataBackendRename.R +++ b/R/DataBackendRename.R @@ -9,7 +9,7 @@ DataBackendRename = R6Class("DataBackendRename", inherit = DataBackend, cloneabl assert_character(old, any.missing = FALSE, unique = TRUE) assert_subset(old, b$colnames) assert_character(new, any.missing = FALSE, len = length(old)) - assert_names(new, if (allow_utf8_names()) "unique" else "strict") + assert_names(new, "unique") ii = old != new old = old[ii] diff --git a/R/Task.R b/R/Task.R index a081f15ca..0f077c477 100644 --- a/R/Task.R +++ b/R/Task.R @@ -128,13 +128,9 @@ Task = R6Class("Task", cn = self$backend$colnames rn = self$backend$rownames - if (allow_utf8_names()) { - assert_names(cn, "unique", .var.name = "column names") - if (any(grepl("%", cn, fixed = TRUE))) { - stopf("Column names may not contain special character '%%'") - } - } else { - assert_names(cn, "strict", .var.name = "column names") + assert_names(cn, "unique", .var.name = "column names") + if (any(grepl("%", cn, fixed = TRUE))) { + stopf("Column names may not contain special character '%%'") } self$col_info = col_info(self$backend) diff --git a/R/helper.R b/R/helper.R index b27ecb192..39460eb08 100644 --- a/R/helper.R +++ b/R/helper.R @@ -4,11 +4,6 @@ translate_types = function(x) { factor(map_values(x, r_types, p_types), levels = p_types) } - -allow_utf8_names = function() { - isTRUE(getOption("mlr3.allow_utf8_names")) -} - get_featureless_learner = function(task_type) { if (!is.na(task_type)) { id = paste0(task_type, ".featureless") diff --git a/R/zzz.R b/R/zzz.R index b34c219ba..84f0077b3 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -50,10 +50,6 @@ #' * `"mlr3.debug"`: If set to `TRUE`, parallelization via \CRANpkg{future} is disabled to simplify #' debugging and provide more concise tracebacks. #' Note that results computed in debug mode use a different seeding mechanism and are **not reproducible**. 
-#' * `"mlr3.allow_utf8_names"`: If set to `TRUE`, checks on the feature names are relaxed, allowing -#' non-ascii characters in column names. This is an experimental and temporal option to -#' pave the way for text analysis, and will likely be removed in a future version of the package. -#' analysis. #' * `"mlr3.warn_version_mismatch"`: Set to `FALSE` to silence warnings raised during predict if a learner has been #' trained with a different version version of mlr3. #' diff --git a/inst/testthat/helper_autotest.R b/inst/testthat/helper_autotest.R index 377e7d152..f1f647b34 100644 --- a/inst/testthat/helper_autotest.R +++ b/inst/testthat/helper_autotest.R @@ -81,8 +81,6 @@ generate_generic_tasks = function(learner, proto) { # task with non-ascii feature names if (p > 0L) { - opts = options(mlr3.allow_utf8_names = TRUE) - on.exit(options(opts)) sel = proto$feature_types[list(learner$feature_types), "id", on = "type", with = FALSE, nomatch = NULL][[1L]] tasks$utf8_feature_names = proto$clone(deep = TRUE)$select(sel) old = sel[1L] diff --git a/man/mlr3-package.Rd b/man/mlr3-package.Rd index fb5185e56..9c4f5f689 100644 --- a/man/mlr3-package.Rd +++ b/man/mlr3-package.Rd @@ -63,10 +63,6 @@ parallelization with \CRANpkg{future}. Defaults to 1. \item \code{"mlr3.debug"}: If set to \code{TRUE}, parallelization via \CRANpkg{future} is disabled to simplify debugging and provide more concise tracebacks. Note that results computed in debug mode use a different seeding mechanism and are \strong{not reproducible}. -\item \code{"mlr3.allow_utf8_names"}: If set to \code{TRUE}, checks on the feature names are relaxed, allowing -non-ascii characters in column names. This is an experimental and temporal option to -pave the way for text analysis, and will likely be removed in a future version of the package. -analysis. 
\item \code{"mlr3.warn_version_mismatch"}: Set to \code{FALSE} to silence warnings raised during predict if a learner has been trained with a different version version of mlr3. } diff --git a/tests/testthat/test_Task.R b/tests/testthat/test_Task.R index b1005c81d..880b2726b 100644 --- a/tests/testthat/test_Task.R +++ b/tests/testthat/test_Task.R @@ -552,15 +552,6 @@ test_that("set_levels", { }) test_that("special chars in feature names (#697)", { - prev = options(mlr3.allow_utf8_names = FALSE) - on.exit(options(prev)) - - expect_error( - TaskRegr$new("test", data.table(`%^` = 1:3, t = 3:1), target = "t"), - "comply" - ) - options(mlr3.allow_utf8_names = TRUE) - expect_error( TaskRegr$new("test", data.table(`%asd` = 1:3, t = 3:1), target = "t") , From 54ec9b16379fef0e3bd5fca0c1dc05ae266a6c01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCcke?= Date: Mon, 23 Dec 2024 09:38:14 +0100 Subject: [PATCH 03/10] perf: use %chin% (#1224) --- R/DataBackendRename.R | 2 +- R/HotstartStack.R | 4 ++-- R/LearnerClassifDebug.R | 8 ++++---- R/LearnerClassifRpart.R | 6 +++--- R/LearnerRegrRpart.R | 2 +- R/Measure.R | 10 +++++----- R/PredictionClassif.R | 2 +- R/PredictionDataClassif.R | 4 ++-- R/PredictionDataRegr.R | 8 ++++---- R/PredictionRegr.R | 8 ++++---- R/Task.R | 18 +++++++++--------- R/assertions.R | 6 +++--- R/benchmark.R | 6 +++--- R/resample.R | 6 +++--- R/set_validate.R | 2 +- R/worker.R | 4 ++-- tests/testthat/test_HotstartStack.R | 12 ++++++------ tests/testthat/test_PredictionRegr.R | 4 ++-- tests/testthat/test_Task.R | 16 ++++++++-------- tests/testthat/test_benchmark.R | 10 +++++----- tests/testthat/test_convert_task.R | 4 ++-- tests/testthat/test_resample.R | 2 +- 22 files changed, 72 insertions(+), 72 deletions(-) diff --git a/R/DataBackendRename.R b/R/DataBackendRename.R index d4838d5e1..5024682a9 100644 --- a/R/DataBackendRename.R +++ b/R/DataBackendRename.R @@ -15,7 +15,7 @@ DataBackendRename = R6Class("DataBackendRename", inherit = DataBackend, 
cloneabl old = old[ii] new = new[ii] - if (self$primary_key %in% old) { + if (self$primary_key %chin% old) { stopf("Renaming the primary key is not supported") } diff --git a/R/HotstartStack.R b/R/HotstartStack.R index 90096abff..208ca3a68 100644 --- a/R/HotstartStack.R +++ b/R/HotstartStack.R @@ -202,9 +202,9 @@ calculate_cost = function(start_learner, learner, hotstart_id) { cost = learner$param_set$values[[hotstart_id]] - start_learner$param_set$values[[hotstart_id]] if (cost == 0) return(-1) - if ("hotstart_backward" %in% learner$properties && "hotstart_forward" %in% learner$properties) { + if ("hotstart_backward" %chin% learner$properties && "hotstart_forward" %chin% learner$properties) { if (cost < 0) 0 else cost - } else if ("hotstart_backward" %in% learner$properties) { + } else if ("hotstart_backward" %chin% learner$properties) { if (cost < 0) 0 else NA_real_ } else { if (cost > 0) cost else NA_real_ diff --git a/R/LearnerClassifDebug.R b/R/LearnerClassifDebug.R index 908bb7965..4f6984f1d 100644 --- a/R/LearnerClassifDebug.R +++ b/R/LearnerClassifDebug.R @@ -163,7 +163,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif, pv = self$param_set$get_values(tags = "train") pv$count_marshaling = pv$count_marshaling %??% FALSE roll = function(name) { - name %in% names(pv) && pv[[name]] > runif(1L) + name %chin% names(pv) && pv[[name]] > runif(1L) } if (!is.null(pv$sleep_train)) { @@ -248,7 +248,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif, n = task$nrow pv = self$param_set$get_values(tags = "predict") roll = function(name) { - name %in% names(pv) && pv[[name]] > runif(1L) + name %chin% names(pv) && pv[[name]] > runif(1L) } if (!is.null(pv$sleep_predict)) { @@ -281,7 +281,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif, response = prob = NULL missing_type = pv$predict_missing_type %??% "na" - if ("response" %in% self$predict_type) { + if ("response" %chin% 
self$predict_type) { response = rep.int(unclass(model$response), n) if (!is.null(pv$predict_missing)) { ii = sample.int(n, n * pv$predict_missing) @@ -292,7 +292,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif, } } - if ("prob" %in% self$predict_type) { + if ("prob" %chin% self$predict_type) { cl = task$class_names prob = matrix(runif(n * length(cl)), nrow = n) prob = prob / rowSums(prob) diff --git a/R/LearnerClassifRpart.R b/R/LearnerClassifRpart.R index 02070150f..05342335c 100644 --- a/R/LearnerClassifRpart.R +++ b/R/LearnerClassifRpart.R @@ -77,7 +77,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif, .train = function(task) { pv = self$param_set$get_values(tags = "train") names(pv) = replace(names(pv), names(pv) == "keep_model", "model") - if ("weights" %in% task$properties) { + if ("weights" %chin% task$properties) { pv = insert_named(pv, list(weights = task$weights$weight)) } @@ -89,11 +89,11 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif, newdata = task$data(cols = task$feature_names) response = prob = NULL - if ("response" %in% self$predict_type) { + if ("response" %chin% self$predict_type) { response = invoke(predict, self$model, newdata = newdata, type = "class", .opts = allow_partial_matching, .args = pv) response = unname(response) - } else if ("prob" %in% self$predict_type) { + } else if ("prob" %chin% self$predict_type) { prob = invoke(predict, self$model, newdata = newdata, type = "prob", .opts = allow_partial_matching, .args = pv) rownames(prob) = NULL diff --git a/R/LearnerRegrRpart.R b/R/LearnerRegrRpart.R index 5910fbcd0..f78785799 100644 --- a/R/LearnerRegrRpart.R +++ b/R/LearnerRegrRpart.R @@ -77,7 +77,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr, .train = function(task) { pv = self$param_set$get_values(tags = "train") names(pv) = replace(names(pv), names(pv) == "keep_model", "model") - if ("weights" %in% 
task$properties) { + if ("weights" %chin% task$properties) { pv = insert_named(pv, list(weights = task$weights$weight)) } diff --git a/R/Measure.R b/R/Measure.R index 5c58b85e8..a4cd9fcf1 100644 --- a/R/Measure.R +++ b/R/Measure.R @@ -200,11 +200,11 @@ Measure = R6Class("Measure", # check should be added to assert_measure() # except when the checks are superfluous for rr$score() and bmr$score() # these checks should be added bellow - if ("requires_task" %in% self$properties && is.null(task)) { + if ("requires_task" %chin% self$properties && is.null(task)) { stopf("Measure '%s' requires a task", self$id) } - if ("requires_learner" %in% self$properties && is.null(learner)) { + if ("requires_learner" %chin% self$properties && is.null(learner)) { stopf("Measure '%s' requires a learner", self$id) } @@ -212,7 +212,7 @@ Measure = R6Class("Measure", stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type) } - if ("requires_train_set" %in% self$properties && is.null(train_set)) { + if ("requires_train_set" %chin% self$properties && is.null(train_set)) { stopf("Measure '%s' requires the train_set", self$id) } @@ -258,7 +258,7 @@ Measure = R6Class("Measure", #' @template field_predict_sets predict_sets = function(rhs) { if (!missing(rhs)) { - private$.predict_sets = assert_subset(rhs, mlr_reflections$predict_sets, empty.ok = "requires_no_prediction" %in% self$properties) + private$.predict_sets = assert_subset(rhs, mlr_reflections$predict_sets, empty.ok = "requires_no_prediction" %chin% self$properties) } private$.predict_sets }, @@ -385,7 +385,7 @@ score_single_measure = function(measure, task, learner, train_set, prediction) { #' @noRd score_measures = function(obj, measures, reassemble = TRUE, view = NULL, iters = NULL) { reassemble_learners = reassemble || - some(measures, function(m) any(c("requires_learner", "requires_model") %in% m$properties)) + some(measures, function(m) any(c("requires_learner", "requires_model") %chin% m$properties)) 
tab = get_private(obj)$.data$as_data_table(view = view, reassemble_learners = reassemble_learners, convert_predictions = FALSE) if (!is.null(iters)) { diff --git a/R/PredictionClassif.R b/R/PredictionClassif.R index bc8aa9335..2b452a663 100644 --- a/R/PredictionClassif.R +++ b/R/PredictionClassif.R @@ -175,7 +175,7 @@ PredictionClassif = R6Class("PredictionClassif", inherit = Prediction, as.data.table.PredictionClassif = function(x, ...) { # nolint tab = as.data.table(x$data[c("row_ids", "truth", "response")]) - if ("prob" %in% x$predict_types) { + if ("prob" %chin% x$predict_types) { prob = as.data.table(x$data$prob) setnames(prob, names(prob), paste0("prob.", names(prob))) tab = rcbind(tab, prob) diff --git a/R/PredictionDataClassif.R b/R/PredictionDataClassif.R index 94206354e..06fb0e554 100644 --- a/R/PredictionDataClassif.R +++ b/R/PredictionDataClassif.R @@ -132,11 +132,11 @@ create_empty_prediction_data.TaskClassif = function(task, learner) { truth = factor(character(), levels = cn) ) - if ("response" %in% predict_types) { + if ("response" %chin% predict_types) { pdata$response = pdata$truth } - if ("prob" %in% predict_types) { + if ("prob" %chin% predict_types) { pdata$prob = matrix(numeric(), nrow = 0L, ncol = length(cn), dimnames = list(NULL, cn)) } diff --git a/R/PredictionDataRegr.R b/R/PredictionDataRegr.R index 080dd549b..bfe49c28a 100644 --- a/R/PredictionDataRegr.R +++ b/R/PredictionDataRegr.R @@ -99,7 +99,7 @@ c.PredictionDataRegr = function(..., keep_duplicates = TRUE) { # nolint result = as.list(tab) result$quantiles = quantiles - if ("distr" %in% predict_types[[1L]]) { + if ("distr" %chin% predict_types[[1L]]) { require_namespaces("distr6", msg = "To predict probability distributions, please install %s") result$distr = do.call(c, map(dots, "distr")) } @@ -137,15 +137,15 @@ create_empty_prediction_data.TaskRegr = function(task, learner) { truth = numeric() ) - if ("response" %in% predict_types) { + if ("response" %chin% predict_types) { 
pdata$response = pdata$truth } - if ("se" %in% predict_types) { + if ("se" %chin% predict_types) { pdata$se = pdata$truth } - if ("distr" %in% predict_types) { + if ("distr" %chin% predict_types) { pdata$distr = list() } diff --git a/R/PredictionRegr.R b/R/PredictionRegr.R index ddf07d774..63c73fa68 100644 --- a/R/PredictionRegr.R +++ b/R/PredictionRegr.R @@ -61,7 +61,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction, self$data = pdata predict_types = intersect(names(mlr_reflections$learner_predict_types[["regr"]]), names(pdata)) # response is in saved in quantiles matrix - if ("quantiles" %in% predict_types) predict_types = union(predict_types, "response") + if ("quantiles" %chin% predict_types) predict_types = union(predict_types, "response") self$predict_types = predict_types if (is.null(pdata$response)) private$.quantile_response = attr(quantiles, "response") } @@ -94,7 +94,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction, #' Access the stored vector distribution. #' Requires package `distr6`(in repository \url{https://raphaels1.r-universe.dev}) . distr = function() { - if ("distr" %in% self$predict_types) { + if ("distr" %chin% self$predict_types) { require_namespaces("distr6", msg = "To predict probability distributions, please install %s") } return(self$data$distr) @@ -111,12 +111,12 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction, as.data.table.PredictionRegr = function(x, ...) 
{ # nolint tab = as.data.table(x$data[c("row_ids", "truth", "response", "se")]) - if ("quantiles" %in% x$predict_types) { + if ("quantiles" %chin% x$predict_types) { tab = rcbind(tab, as.data.table(x$data$quantiles)) set(tab, j = "response", value = x$response) } - if ("distr" %in% x$predict_types) { + if ("distr" %chin% x$predict_types) { require_namespaces("distr6", msg = "To predict probability distributions, please install %s") tab$distr = list(x$distr) } diff --git a/R/Task.R b/R/Task.R index 0f077c477..2bfbc34f3 100644 --- a/R/Task.R +++ b/R/Task.R @@ -222,7 +222,7 @@ Task = R6Class("Task", # print additional columns as specified in reflections before = mlr_reflections$task_print_col_roles$before - iwalk(before[before %in% names(roles)], function(role, str) { + iwalk(before[before %chin% names(roles)], function(role, str) { catn(str_indent(sprintf("* %s:", str), roles[[role]])) }) @@ -242,7 +242,7 @@ Task = R6Class("Task", # print additional columns are specified in reflections after = mlr_reflections$task_print_col_roles$after - iwalk(after[after %in% names(roles)], function(role, str) { + iwalk(after[after %chin% names(roles)], function(role, str) { catn(str_indent(sprintf("* %s:", str), roles[[role]])) }) @@ -367,7 +367,7 @@ Task = R6Class("Task", levels = function(cols = NULL) { if (is.null(cols)) { cols = unlist(private$.col_roles[c("target", "feature")], use.names = FALSE) - cols = self$col_info[get("id") %in% cols & get("type") %in% c("factor", "ordered"), "id", with = FALSE][[1L]] + cols = self$col_info[get("id") %chin% cols & get("type") %chin% c("factor", "ordered"), "id", with = FALSE][[1L]] } else { assert_subset(cols, self$col_info$id) } @@ -465,7 +465,7 @@ Task = R6Class("Task", type_check = TRUE if (is.data.frame(data)) { - pk_in_backend = pk %in% names(data) + pk_in_backend = pk %chin% names(data) type_check = FALSE # done by auto-converter keep_cols = intersect(names(data), self$col_info$id) @@ -517,7 +517,7 @@ Task = R6Class("Task", } # 
merge factor levels - ii = tab[type %in% c("factor", "ordered"), which = TRUE] + ii = tab[type %chin% c("factor", "ordered"), which = TRUE] for (i in ii) { x = tab[["levels"]][[i]] y = tab[["levels_y"]][[i]] @@ -726,7 +726,7 @@ Task = R6Class("Task", #' @return Modified `self`. droplevels = function(cols = NULL) { assert_has_backend(self) - tab = self$col_info[get("type") %in% c("factor", "ordered"), c("id", "levels", "fix_factor_levels"), with = FALSE] + tab = self$col_info[get("type") %chin% c("factor", "ordered"), c("id", "levels", "fix_factor_levels"), with = FALSE] if (!is.null(cols)) { tab = tab[list(cols), on = "id", nomatch = NULL] } @@ -926,7 +926,7 @@ Task = R6Class("Task", assert_has_backend(self) assert_list(rhs, .var.name = "row_roles") - if ("test" %in% names(rhs) || "holdout" %in% names(rhs)) { + if ("test" %chin% names(rhs) || "holdout" %chin% names(rhs)) { stopf("Setting row roles 'test'/'holdout' is no longer possible.") } assert_names(names(rhs), "unique", permutation.of = mlr_reflections$task_row_roles, .var.name = "names of row_roles") @@ -1333,7 +1333,7 @@ col_info = function(x, ...) { #' @export col_info.data.table = function(x, primary_key = character(), ...) { # nolint types = map_chr(x, function(x) class(x)[1L]) - discrete = setdiff(names(types)[types %in% c("factor", "ordered")], primary_key) + discrete = setdiff(names(types)[types %chin% c("factor", "ordered")], primary_key) levels = insert_named(named_list(names(types)), lapply(x[, discrete, with = FALSE], distinct_values, drop = FALSE)) data.table(id = names(types), type = unname(types), levels = levels, key = "id") } @@ -1342,7 +1342,7 @@ col_info.data.table = function(x, primary_key = character(), ...) { # nolint #' @export col_info.DataBackend = function(x, ...) 
{ # nolint types = map_chr(x$head(1L), function(x) class(x)[1L]) - discrete = setdiff(names(types)[types %in% c("factor", "ordered")], x$primary_key) + discrete = setdiff(names(types)[types %chin% c("factor", "ordered")], x$primary_key) levels = insert_named(named_list(names(types)), x$distinct(rows = NULL, cols = discrete)) data.table(id = names(types), type = unname(types), levels = levels, key = "id") } diff --git a/R/assertions.R b/R/assertions.R index fa6618dbd..d86b5ce43 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -181,7 +181,7 @@ assert_predictable = function(task, learner) { stopf("Learner '%s' has received tasks with different columns in train and predict.", learner$id) } - ids = train_task$col_info[get("id") %in% cols_train, "id"]$id + ids = train_task$col_info[get("id") %chin% cols_train, "id"]$id ci_predict = task$col_info[list(ids), c("id", "type", "levels"), on = "id"] ci_train = train_task$col_info[list(ids), c("id", "type", "levels"), on = "id"] @@ -260,11 +260,11 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL #' @param prediction ([Prediction]). 
#' @rdname mlr_assertions assert_scorable = function(measure, task, learner, prediction = NULL, .var.name = vname(measure)) { - if ("requires_model" %in% measure$properties && is.null(learner$model)) { + if ("requires_model" %chin% measure$properties && is.null(learner$model)) { stopf("Measure '%s' requires the trained model", measure$id) } - if ("requires_model" %in% measure$properties && is_marshaled_model(learner$model)) { + if ("requires_model" %chin% measure$properties && is_marshaled_model(learner$model)) { stopf("Measure '%s' requires the trained model, but model is in marshaled form", measure$id) } diff --git a/R/benchmark.R b/R/benchmark.R index 5f26c38a6..63b41f180 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -110,13 +110,13 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps setDT(design) task = learner = resampling = NULL - if ("task" %in% clone) { + if ("task" %chin% clone) { design[, "task" := list(list(task[[1L]]$clone())), by = list(hashes(task))] } - if ("learner" %in% clone) { + if ("learner" %chin% clone) { design[, "learner" := list(list(learner[[1L]]$clone())), by = list(hashes(learner))] } - if ("resampling" %in% clone) { + if ("resampling" %chin% clone) { design[, "resampling" := list(list(resampling[[1L]]$clone())), by = list(hashes(resampling))] } diff --git a/R/resample.R b/R/resample.R index cc1bb88f2..c82ffe4ea 100644 --- a/R/resample.R +++ b/R/resample.R @@ -57,9 +57,9 @@ #' print(bmr1$combine(bmr2)) resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) { assert_subset(clone, c("task", "learner", "resampling")) - task = assert_task(as_task(task, clone = "task" %in% clone)) - learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE)) - resampling = assert_resampling(as_resampling(resampling, clone = "resampling" 
%in% clone)) + task = assert_task(as_task(task, clone = "task" %chin% clone)) + learner = assert_learner(as_learner(learner, clone = "learner" %chin% clone, discard_state = TRUE)) + resampling = assert_resampling(as_resampling(resampling, clone = "resampling" %chin% clone)) assert_flag(store_models) assert_flag(store_backends) # this does not check the internal validation task as it might not be set yet diff --git a/R/set_validate.R b/R/set_validate.R index 5e8825b26..c89a61e00 100644 --- a/R/set_validate.R +++ b/R/set_validate.R @@ -26,7 +26,7 @@ set_validate = function(learner, validate, ...) { #' @export set_validate.Learner = function(learner, validate, ...) { - if (!"validation" %in% learner$properties) { + if (!"validation" %chin% learner$properties) { stopf("Learner '%s' does not support validation.", learner$id) } learner$validate = validate diff --git a/R/worker.R b/R/worker.R index e52daa23f..b5a794329 100644 --- a/R/worker.R +++ b/R/worker.R @@ -255,7 +255,7 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, if (!is.null(pb)) { pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) } - if ("internal_valid" %in% learner$predict_sets && is.null(task$internal_valid_task) && is.null(get0("validate", learner))) { + if ("internal_valid" %chin% learner$predict_sets && is.null(task$internal_valid_task) && is.null(get0("validate", learner))) { stopf("Cannot set the predict_type field of learner '%s' to 'internal_valid' if there is no internal validation task configured", learner$id) } @@ -351,7 +351,7 @@ prediction_tasks_and_sets = function(task, train_result, validate, sets, predict return(list(tasks = tasks[predict_sets], sets = sets[predict_sets])) } - if ("internal_valid" %in% predict_sets) { + if ("internal_valid" %chin% predict_sets) { if (is.numeric(validate) || identical(validate, "test")) { # in this scenario, the internal_valid_task was created during learner_train, which means that it used the # primary task. 
The selected ids are returned via the train result diff --git a/tests/testthat/test_HotstartStack.R b/tests/testthat/test_HotstartStack.R index 55148c8c3..a1e1dc3c3 100644 --- a/tests/testthat/test_HotstartStack.R +++ b/tests/testthat/test_HotstartStack.R @@ -134,7 +134,7 @@ test_that("HotstartStack works with backward target learner and decreased hotsta learner_1$train(task) learner = lrn("classif.debug", iter = 1) - learner$properties[learner$properties %in% "hotstart_forward"] = "hotstart_backward" + learner$properties[learner$properties %chin% "hotstart_forward"] = "hotstart_backward" hot = HotstartStack$new(list(learner_1)) expect_equal(hot$start_cost(learner, task$hash), 0) @@ -151,7 +151,7 @@ test_that("HotstartStack works with backward target learner when cost of hotstar learner_2$train(task) learner = lrn("classif.debug", iter = 3) - learner$properties[learner$properties %in% "hotstart_forward"] = "hotstart_backward" + learner$properties[learner$properties %chin% "hotstart_forward"] = "hotstart_backward" hot = HotstartStack$new(list(learner_1, learner_2)) expect_equal(hot$start_cost(learner, task$hash), c(0, 0)) @@ -166,7 +166,7 @@ test_that("HotstartStack works when hotstart values of hotstart learners are low learner_2$train(task) learner = lrn("classif.debug", iter = 2) - learner$properties[learner$properties %in% "hotstart_forward"] = "hotstart_backward" + learner$properties[learner$properties %chin% "hotstart_forward"] = "hotstart_backward" hot = HotstartStack$new(list(learner_1, learner_2)) expect_equal(hot$start_cost(learner, task$hash), c(0, NA_real_)) @@ -181,7 +181,7 @@ test_that("HotstartStack works when backward hotstart and target learner are equ learner_1$train(task) learner = lrn("classif.debug", iter = 1) - learner$properties[learner$properties %in% "hotstart_forward"] = "hotstart_backward" + learner$properties[learner$properties %chin% "hotstart_forward"] = "hotstart_backward" hot = HotstartStack$new(list(learner_1)) 
expect_equal(hot$start_cost(learner, task$hash), -1) @@ -197,7 +197,7 @@ test_that("HotstartStack works with backward target learner when hotstart values learner_1$train(task) learner = lrn("classif.debug", iter = 2) - learner$properties[learner$properties %in% "hotstart_forward"] = "hotstart_backward" + learner$properties[learner$properties %chin% "hotstart_forward"] = "hotstart_backward" hot = HotstartStack$new(list(learner_1)) expect_equal(hot$start_cost(learner, task$hash), NA_real_) @@ -218,7 +218,7 @@ test_that("HotstartStack works with backward target learner when hotstart learne learner_4$train(task) learner = lrn("classif.debug", iter = 2) - learner$properties[learner$properties %in% "hotstart_forward"] = "hotstart_backward" + learner$properties[learner$properties %chin% "hotstart_forward"] = "hotstart_backward" hot = HotstartStack$new(list(learner_1, learner_2, learner_3, learner_4)) expect_equal(hot$start_cost(learner, task$hash), c(NA_real_, -1, 0, NA_real_)) diff --git a/tests/testthat/test_PredictionRegr.R b/tests/testthat/test_PredictionRegr.R index 0005e3a72..90a47b738 100644 --- a/tests/testthat/test_PredictionRegr.R +++ b/tests/testthat/test_PredictionRegr.R @@ -45,9 +45,9 @@ test_that("c drops se (#250)", { pred = do.call(c, rr$predictions()) expect_null(pred$data$se) - expect_false("se" %in% pred$predict_types) + expect_false("se" %chin% pred$predict_types) expect_true(allMissing(pred$se)) - expect_false("se" %in% names(as.data.table(pred))) + expect_false("se" %chin% names(as.data.table(pred))) }) test_that("distr", { diff --git a/tests/testthat/test_Task.R b/tests/testthat/test_Task.R index 880b2726b..fbfba18eb 100644 --- a/tests/testthat/test_Task.R +++ b/tests/testthat/test_Task.R @@ -237,11 +237,11 @@ test_that("rename works", { test_that("stratify works", { task = tsk("iris") - expect_false("strata" %in% task$properties) + expect_false("strata" %chin% task$properties) expect_null(task$strata) task$col_roles$stratum = task$target_names - 
expect_true("strata" %in% task$properties) + expect_true("strata" %chin% task$properties) tab = task$strata expect_data_table(tab, ncols = 2, nrows = 3) expect_list(tab$row_id, "integer") @@ -252,8 +252,8 @@ test_that("groups/weights work", { task = TaskRegr$new("test", b, target = "y") task$set_row_roles(16:20, character()) - expect_false("groups" %in% task$properties) - expect_false("weights" %in% task$properties) + expect_false("groups" %chin% task$properties) + expect_false("weights" %chin% task$properties) expect_null(task$groups) expect_null(task$weights) @@ -434,7 +434,7 @@ test_that("col roles getters/setters", { }) task$col_roles$feature = setdiff(task$col_roles$feature, "Sepal.Length") - expect_false("Sepal.Length" %in% task$feature_names) + expect_false("Sepal.Length" %chin% task$feature_names) }) test_that("Task$row_names", { @@ -471,7 +471,7 @@ test_that("Task$set_col_roles", { task$set_col_roles("mass", add_to = "feature") expect_equal(task$n_features, 8L) - expect_true("mass" %in% task$feature_names) + expect_true("mass" %chin% task$feature_names) task$set_col_roles("age", roles = "weight") expect_equal(task$n_features, 7L) @@ -480,7 +480,7 @@ test_that("Task$set_col_roles", { task$set_col_roles("age", add_to = "feature", remove_from = "weight") expect_equal(task$n_features, 8L) - expect_true("age" %in% task$feature_names) + expect_true("age" %chin% task$feature_names) expect_null(task$weights) }) @@ -650,7 +650,7 @@ test_that("cbind supports non-standard primary key (#961)", { b = as_data_backend(tbl, primary_key = "myid") task = as_task_regr(b, target = "y") task$cbind(data.table(x1 = 10:1)) - expect_true("x1" %in% task$feature_names) + expect_true("x1" %chin% task$feature_names) }) test_that("$select changes hash", { diff --git a/tests/testthat/test_benchmark.R b/tests/testthat/test_benchmark.R index d62a4928d..8a14a6015 100644 --- a/tests/testthat/test_benchmark.R +++ b/tests/testthat/test_benchmark.R @@ -85,9 +85,9 @@ test_that("bmr$combine()", 
{ expect_data_table(get_private(bmr_new)$.data$data$fact, nrows = 6L) expect_data_table(get_private(bmr_combined)$.data$data$fact, nrows = 24L) - expect_false("pima" %in% bmr$tasks$task_id) - expect_true("pima" %in% bmr_new$tasks$task_id) - expect_true("pima" %in% bmr_combined$tasks$task_id) + expect_false("pima" %chin% bmr$tasks$task_id) + expect_true("pima" %chin% bmr_new$tasks$task_id) + expect_true("pima" %chin% bmr_combined$tasks$task_id) } rr = resample(tsk("zoo"), lrn("classif.rpart"), rsmp("holdout")) @@ -212,7 +212,7 @@ test_that("extract params", { aggr = bmr$aggregate(params = TRUE) expect_list(aggr$params[[1]], names = "unique", len = 0L) - expect_true(all(c("warnings", "errors") %in% names(bmr$score(conditions = TRUE)))) + expect_true(all(c("warnings", "errors") %chin% names(bmr$score(conditions = TRUE)))) }) test_that("benchmark_grid", { @@ -480,7 +480,7 @@ test_that("param_values in benchmark", { x } trained = bmr$learners$learner - ii = which(map_lgl(trained, function(x) "cp" %in% names(x$param_set$values))) # find learner with cp + ii = which(map_lgl(trained, function(x) "cp" %chin% names(x$param_set$values))) # find learner with cp expect_count(ii) expect_equal(sortnames(bmr$learners$learner[-ii][[1]]$param_set$values), list(minbucket = 2, minsplit = 12, xval = 0)) diff --git a/tests/testthat/test_convert_task.R b/tests/testthat/test_convert_task.R index 9758f2dad..da87eebfc 100644 --- a/tests/testthat/test_convert_task.R +++ b/tests/testthat/test_convert_task.R @@ -109,7 +109,7 @@ test_that("convert_task reconstructs task", { tsk2 = convert_task(task2) expect_equal(task2$nrow, tsk2$nrow) expect_equal(task2$ncol, tsk2$ncol) - expect_true("twoclass" %in% tsk2$properties) + expect_true("twoclass" %chin% tsk2$properties) task3 = task2 task3$row_roles$use = 1:150 @@ -117,7 +117,7 @@ test_that("convert_task reconstructs task", { tsk3$man = "mlr3::mlr_tasks_iris" expect_equal(task3$nrow, tsk3$nrow) expect_equal(task3$ncol, tsk3$ncol) - 
expect_true("multiclass" %in% tsk3$properties) + expect_true("multiclass" %chin% tsk3$properties) expect_equal(task, tsk3, ignore_attr = TRUE) }) diff --git a/tests/testthat/test_resample.R b/tests/testthat/test_resample.R index 5668734a9..a7cc0102c 100644 --- a/tests/testthat/test_resample.R +++ b/tests/testthat/test_resample.R @@ -94,7 +94,7 @@ test_that("empty train/predict sets", { }) test_that("conditions are returned", { - expect_true(all(c("warnings", "errors") %in% names(rr$score(conditions = TRUE)))) + expect_true(all(c("warnings", "errors") %chin% names(rr$score(conditions = TRUE)))) }) test_that("save/load roundtrip", { From 52deba9d9ab3d5124b4e9a3adec44b0e79608886 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCcke?= Date: Mon, 6 Jan 2025 16:57:09 +0100 Subject: [PATCH 04/10] feat(task): add missing ordered in properties (#1238) --- R/Task.R | 4 +++- man/Task.Rd | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/R/Task.R b/R/Task.R index 2bfbc34f3..66b7cafc5 100644 --- a/R/Task.R +++ b/R/Task.R @@ -896,6 +896,7 @@ Task = R6Class("Task", #' * `"strata"`: The task is resampled using one or more stratification variables (role `"stratum"`). #' * `"groups"`: The task comes with grouping/blocking information (role `"group"`). #' * `"weights"`: The task comes with observation weights (role `"weight"`). + #' * `"ordered"`: The task has columns which define the row order (role `"order"`). #' #' Note that above listed properties are calculated from the `$col_roles` and may not be set explicitly. 
properties = function(rhs) { @@ -905,7 +906,8 @@ Task = R6Class("Task", private$.properties, if (length(col_roles$group)) "groups" else NULL, if (length(col_roles$stratum)) "strata" else NULL, - if (length(col_roles$weight)) "weights" else NULL + if (length(col_roles$weight)) "weights" else NULL, + if (length(col_roles$order)) "ordered" else NULL ) } else { private$.properties = assert_set(rhs, .var.name = "properties") diff --git a/man/Task.Rd b/man/Task.Rd index 4e925c1ad..53a7afe97 100644 --- a/man/Task.Rd +++ b/man/Task.Rd @@ -198,6 +198,7 @@ The following properties are currently standardized and understood by tasks in \ \item \code{"strata"}: The task is resampled using one or more stratification variables (role \code{"stratum"}). \item \code{"groups"}: The task comes with grouping/blocking information (role \code{"group"}). \item \code{"weights"}: The task comes with observation weights (role \code{"weight"}). +\item \code{"ordered"}: The task has columns which define the row order (role \code{"order"}). 
} Note that above listed properties are calculated from the \verb{$col_roles} and may not be set explicitly.} From 47a3a51d75ab3f92a1bc23b5072b23b20ac17706 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCcke?= Date: Mon, 6 Jan 2025 17:01:01 +0100 Subject: [PATCH 05/10] feat(task): allow date feature (#1237) --- R/auto_convert.R | 29 +++++++++++++++++++++++++++++ R/mlr_reflections.R | 2 +- inst/testthat/helper_autotest.R | 3 ++- tests/testthat/test_auto_convert.R | 6 ++++++ 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/R/auto_convert.R b/R/auto_convert.R index 999625e15..5c10a9a90 100644 --- a/R/auto_convert.R +++ b/R/auto_convert.R @@ -28,6 +28,10 @@ ee[["logical___POSIXct"]] = function(value, type, levels) { if (allMissing(value)) .POSIXct(value, tz = "") else value } +ee[["logical___Date"]] = + function(value, type, levels) { + if (allMissing(value)) as.Date(value) else value + } ## from: integer ee[["integer___logical"]] = @@ -44,6 +48,8 @@ ee[["integer___ordered"]] = ee[["logical___ordered"]] ee[["integer___POSIXct"]] = ee[["logical___POSIXct"]] +ee[["integer___Date"]] = + ee[["logical___Date"]] ## from: numeric ee[["numeric___logical"]] = @@ -60,6 +66,8 @@ ee[["numeric___ordered"]] = ee[["logical___ordered"]] ee[["numeric___POSIXct"]] = ee[["logical___POSIXct"]] +ee[["numeric___Date"]] = + ee[["logical___Date"]] ## from: character ee[["character___logical"]] = @@ -83,6 +91,11 @@ ee[["character___POSIXct"]] = x = try(as.POSIXct(value, ""), silent = TRUE) if (inherits(x, "try-error")) value else x } +ee[["character___Date"]] = + function(value, type, levels) { + x = try(as.Date(value), silent = TRUE) + if (inherits(x, "try-error")) value else x + } ## from: factor ee[["factor___logical"]] = @@ -99,6 +112,8 @@ ee[["factor___ordered"]] = } ee[["factor___POSIXct"]] = ee[["character___POSIXct"]] +ee[["factor___Date"]] = + ee[["character___Date"]] ## from: ordered ee[["ordered___character"]] = @@ -109,6 +124,20 @@ 
ee[["ordered___ordered"]] = ee[["ordered___ordered"]] ee[["ordered___POSIXct"]] = ee[["character___POSIXct"]] +ee[["ordered___Date"]] = + ee[["character___Date"]] + +## from: POSIXct +ee[["POSIXct___Date"]] = + function(value, type, levels) { + as.Date(value) + } + +## from: Date +ee[["Date___POSIXct"]] = + function(value, type, levels) { + as.POSIXct(value) + } rm(ee) # nolint end diff --git a/R/mlr_reflections.R b/R/mlr_reflections.R index 283562658..4ac532741 100644 --- a/R/mlr_reflections.R +++ b/R/mlr_reflections.R @@ -87,7 +87,7 @@ local({ ) mlr_reflections$task_feature_types = c( - lgl = "logical", int = "integer", dbl = "numeric", chr = "character", fct = "factor", ord = "ordered", pxc = "POSIXct" + lgl = "logical", int = "integer", dbl = "numeric", chr = "character", fct = "factor", ord = "ordered", pxc = "POSIXct", dte = "Date" ) mlr_reflections$task_row_roles = c( diff --git a/inst/testthat/helper_autotest.R b/inst/testthat/helper_autotest.R index f1f647b34..fc3fcc41b 100644 --- a/inst/testthat/helper_autotest.R +++ b/inst/testthat/helper_autotest.R @@ -118,7 +118,8 @@ generate_data = function(learner, N) { character = sample(rep_len(letters[1:2], N)), factor = sample(factor(rep_len(c("f1", "f2"), N), levels = c("f1", "f2"))), ordered = sample(ordered(rep_len(c("o1", "o2"), N), levels = c("o1", "o2"))), - POSIXct = Sys.time() - runif(N, min = 0, max = 10 * 365 * 24 * 60 * 60) + POSIXct = Sys.time() - runif(N, min = 0, max = 10 * 365 * 24 * 60 * 60), + Date = Sys.Date() - runif(N, min = 0, max = 10 * 365) ) } types = unique(learner$feature_types) diff --git a/tests/testthat/test_auto_convert.R b/tests/testthat/test_auto_convert.R index d17a9ac24..c7129793a 100644 --- a/tests/testthat/test_auto_convert.R +++ b/tests/testthat/test_auto_convert.R @@ -243,4 +243,10 @@ test_that("POSIXct", { auto_convert("2020-01-20 10:00:00", "x", "POSIXct", character()), as.POSIXct("2020-01-20 10:00:00", "") ) + + expect_date(auto_convert(Sys.time(), "x", "Date", 
character())) +}) + +test_that("Date", { + expect_date(auto_convert("2021-04-21", "x", "Date", character())) }) From 9c953178059216032e8d4868e06296614a54063b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCcke?= Date: Mon, 6 Jan 2025 17:01:29 +0100 Subject: [PATCH 06/10] refactor: use more set_values() (#1239) * refactor: use more set_values() * docs: adjust text to reference set_values for changing param set --- R/Learner.R | 4 ++-- R/LearnerClassifDebug.R | 2 +- R/LearnerClassifFeatureless.R | 2 +- R/LearnerClassifRpart.R | 2 +- R/LearnerRegrFeatureless.R | 2 +- R/LearnerRegrRpart.R | 2 +- R/MeasureClassifCosts.R | 2 +- R/MeasureDebug.R | 2 +- R/MeasureSelectedFeatures.R | 2 +- R/Resampling.R | 2 +- R/ResamplingBootstrap.R | 2 +- R/ResamplingCV.R | 2 +- R/ResamplingHoldout.R | 2 +- R/ResamplingRepeatedCV.R | 2 +- R/ResamplingSubsampling.R | 2 +- R/TaskGeneratorMoons.R | 2 +- man/Learner.Rd | 4 ++-- man/Resampling.Rd | 2 +- man/mlr_learners_classif.debug.Rd | 2 +- 19 files changed, 21 insertions(+), 21 deletions(-) diff --git a/R/Learner.R b/R/Learner.R index 461dd150f..84fd7a075 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -72,10 +72,10 @@ #' All information about hyperparameters is stored in the slot `param_set` which is a [paradox::ParamSet]. #' The printer gives an overview about the ids of available hyperparameters, their storage type, lower and upper bounds, #' possible levels (for factors), default values and assigned values. -#' To set hyperparameters, assign a named list to the subslot `values`: +#' To set hyperparameters, call the `set_values()` method on the `param_set`: #' ``` #' lrn = lrn("classif.rpart") -#' lrn$param_set$values = list(minsplit = 3, cp = 0.01) +#' lrn$param_set$set_values(minsplit = 3, cp = 0.01) #' ``` #' Note that this operation replaces all previously set hyperparameter values. 
#' If you only intend to change one specific hyperparameter value and leave the others as-is, you can use the helper function [mlr3misc::insert_named()]: diff --git a/R/LearnerClassifDebug.R b/R/LearnerClassifDebug.R index 4f6984f1d..ed7ebfc92 100644 --- a/R/LearnerClassifDebug.R +++ b/R/LearnerClassifDebug.R @@ -37,7 +37,7 @@ #' @export #' @examples #' learner = lrn("classif.debug") -#' learner$param_set$values = list(message_train = 1, save_tasks = TRUE) +#' learner$param_set$set_values(message_train = 1, save_tasks = TRUE) #' #' # this should signal a message #' task = tsk("penguins") diff --git a/R/LearnerClassifFeatureless.R b/R/LearnerClassifFeatureless.R index 35c3aeb17..995839fb0 100644 --- a/R/LearnerClassifFeatureless.R +++ b/R/LearnerClassifFeatureless.R @@ -34,7 +34,7 @@ LearnerClassifFeatureless = R6Class("LearnerClassifFeatureless", inherit = Learn ps = ps( method = p_fct(c("mode", "sample", "weighted.sample"), default = "mode", tags = "predict") ) - ps$values = list(method = "mode") + ps$set_values(method = "mode") super$initialize( id = "classif.featureless", feature_types = mlr_reflections$task_feature_types, diff --git a/R/LearnerClassifRpart.R b/R/LearnerClassifRpart.R index 05342335c..c3fc6896d 100644 --- a/R/LearnerClassifRpart.R +++ b/R/LearnerClassifRpart.R @@ -37,7 +37,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif, usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"), xval = p_int(0L, default = 10L, tags = "train") ) - ps$values = list(xval = 0L) + ps$set_values(xval = 0L) super$initialize( id = "classif.rpart", diff --git a/R/LearnerRegrFeatureless.R b/R/LearnerRegrFeatureless.R index 501c48535..ec349d295 100644 --- a/R/LearnerRegrFeatureless.R +++ b/R/LearnerRegrFeatureless.R @@ -23,7 +23,7 @@ LearnerRegrFeatureless = R6Class("LearnerRegrFeatureless", inherit = LearnerRegr ps = ps( robust = p_lgl(default = TRUE, tags = "train") ) - ps$values = list(robust = FALSE) + ps$set_values(robust = FALSE) 
super$initialize( id = "regr.featureless", diff --git a/R/LearnerRegrRpart.R b/R/LearnerRegrRpart.R index f78785799..35e2f2f0e 100644 --- a/R/LearnerRegrRpart.R +++ b/R/LearnerRegrRpart.R @@ -37,7 +37,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr, usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"), xval = p_int(0L, default = 10L, tags = "train") ) - ps$values = list(xval = 0L) + ps$set_values(xval = 0L) super$initialize( id = "regr.rpart", diff --git a/R/MeasureClassifCosts.R b/R/MeasureClassifCosts.R index 63456d213..063780a85 100644 --- a/R/MeasureClassifCosts.R +++ b/R/MeasureClassifCosts.R @@ -45,7 +45,7 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts", #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function() { param_set = ps(normalize = p_lgl(tags = "required")) - param_set$values = list(normalize = TRUE) + param_set$set_values(normalize = TRUE) super$initialize( id = "classif.costs", diff --git a/R/MeasureDebug.R b/R/MeasureDebug.R index 1efb64329..8e02c0a72 100644 --- a/R/MeasureDebug.R +++ b/R/MeasureDebug.R @@ -27,7 +27,7 @@ MeasureDebugClassif = R6Class("MeasureDebugClassif", #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function() { param_set = ps(na_ratio = p_dbl(0, 1, tags = "required")) - param_set$values = list(na_ratio = 0) + param_set$set_values(na_ratio = 0) super$initialize( id = "debug_classif", param_set = param_set, diff --git a/R/MeasureSelectedFeatures.R b/R/MeasureSelectedFeatures.R index db4748608..d53ead467 100644 --- a/R/MeasureSelectedFeatures.R +++ b/R/MeasureSelectedFeatures.R @@ -31,7 +31,7 @@ MeasureSelectedFeatures = R6Class("MeasureSelectedFeatures", #' Creates a new instance of this [R6][R6::R6Class] class. 
initialize = function() { param_set = ps(normalize = p_lgl(tags = "required")) - param_set$values = list(normalize = FALSE) + param_set$set_values(normalize = FALSE) super$initialize( id = "selected_features", diff --git a/R/Resampling.R b/R/Resampling.R index 89aa3e89b..a33544d54 100644 --- a/R/Resampling.R +++ b/R/Resampling.R @@ -56,7 +56,7 @@ #' r$param_set$values #' #' # Do only 3 repeats on 10% of the data -#' r$param_set$values = list(ratio = 0.1, repeats = 3) +#' r$param_set$set_values(ratio = 0.1, repeats = 3) #' r$param_set$values #' #' # Instantiate on penguins task diff --git a/R/ResamplingBootstrap.R b/R/ResamplingBootstrap.R index b4e75e942..54667059d 100644 --- a/R/ResamplingBootstrap.R +++ b/R/ResamplingBootstrap.R @@ -49,7 +49,7 @@ ResamplingBootstrap = R6Class("ResamplingBootstrap", inherit = Resampling, ratio = p_dbl(0, upper = 1, tags = "required"), repeats = p_int(1L, tags = "required") ) - ps$values = list(ratio = 1, repeats = 30L) + ps$set_values(ratio = 1, repeats = 30L) super$initialize(id = "bootstrap", param_set = ps, duplicated_ids = TRUE, label = "Bootstrap", man = "mlr3::mlr_resamplings_bootstrap") diff --git a/R/ResamplingCV.R b/R/ResamplingCV.R index 4b36a71b0..34a0e60fc 100644 --- a/R/ResamplingCV.R +++ b/R/ResamplingCV.R @@ -44,7 +44,7 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling, ps = ps( folds = p_int(2L, tags = "required") ) - ps$values = list(folds = 10L) + ps$set_values(folds = 10L) super$initialize(id = "cv", param_set = ps, label = "Cross-Validation", man = "mlr3::mlr_resamplings_cv") diff --git a/R/ResamplingHoldout.R b/R/ResamplingHoldout.R index 0a5d7ad8d..0c47a12e9 100644 --- a/R/ResamplingHoldout.R +++ b/R/ResamplingHoldout.R @@ -45,7 +45,7 @@ ResamplingHoldout = R6Class("ResamplingHoldout", inherit = Resampling, ps = ps( ratio = p_dbl(0, 1, tags = "required") ) - ps$values = list(ratio = 2 / 3) + ps$set_values(ratio = 2 / 3) super$initialize(id = "holdout", param_set = ps, label = "Holdout", man = 
"mlr3::mlr_resamplings_holdout") diff --git a/R/ResamplingRepeatedCV.R b/R/ResamplingRepeatedCV.R index 212106eb8..bbf3b492c 100644 --- a/R/ResamplingRepeatedCV.R +++ b/R/ResamplingRepeatedCV.R @@ -56,7 +56,7 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling, folds = p_int(2L, tags = "required"), repeats = p_int(1L) ) - ps$values = list(repeats = 10L, folds = 10L) + ps$set_values(repeats = 10L, folds = 10L) super$initialize(id = "repeated_cv", param_set = ps, label = "Repeated Cross-Validation", man = "mlr3::mlr_resamplings_repeated_cv") }, diff --git a/R/ResamplingSubsampling.R b/R/ResamplingSubsampling.R index 99a447f3c..beeae7f1b 100644 --- a/R/ResamplingSubsampling.R +++ b/R/ResamplingSubsampling.R @@ -48,7 +48,7 @@ ResamplingSubsampling = R6Class("ResamplingSubsampling", inherit = Resampling, ratio = p_dbl(0, 1, tags = "required"), repeats = p_int(1, tags = "required") ) - ps$values = list(repeats = 30L, ratio = 2 / 3) + ps$set_values(repeats = 30L, ratio = 2 / 3) super$initialize(id = "subsampling", param_set = ps, label = "Subsampling", man = "mlr3::mlr_resamplings_subsampling") diff --git a/R/TaskGeneratorMoons.R b/R/TaskGeneratorMoons.R index e2036d2b6..7a1b7be7a 100644 --- a/R/TaskGeneratorMoons.R +++ b/R/TaskGeneratorMoons.R @@ -27,7 +27,7 @@ TaskGeneratorMoons = R6Class("TaskGeneratorMoons", ps = ps( sigma = p_dbl(0, tags = "required") ) - ps$values = list(sigma = 1) + ps$set_values(sigma = 1) super$initialize(id = "moons", task_type = "classif", param_set = ps, label = "Moons Classification", man = "mlr3::mlr_task_generators_moons") diff --git a/man/Learner.Rd b/man/Learner.Rd index 70d7cf8f9..817222e0f 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -61,10 +61,10 @@ If the learner is not trained yet, this returns \code{NULL}. All information about hyperparameters is stored in the slot \code{param_set} which is a \link[paradox:ParamSet]{paradox::ParamSet}. 
The printer gives an overview about the ids of available hyperparameters, their storage type, lower and upper bounds, possible levels (for factors), default values and assigned values. -To set hyperparameters, assign a named list to the subslot \code{values}: +To set hyperparameters, call the \code{set_values()} method on the \code{param_set}: \if{html}{\out{
}}\preformatted{lrn = lrn("classif.rpart") -lrn$param_set$values = list(minsplit = 3, cp = 0.01) +lrn$param_set$set_values(minsplit = 3, cp = 0.01) }\if{html}{\out{
}} Note that this operation replaces all previously set hyperparameter values. diff --git a/man/Resampling.Rd b/man/Resampling.Rd index 26692dd46..c2e62bb05 100644 --- a/man/Resampling.Rd +++ b/man/Resampling.Rd @@ -53,7 +53,7 @@ r = rsmp("subsampling") r$param_set$values # Do only 3 repeats on 10\% of the data -r$param_set$values = list(ratio = 0.1, repeats = 3) +r$param_set$set_values(ratio = 0.1, repeats = 3) r$param_set$values # Instantiate on penguins task diff --git a/man/mlr_learners_classif.debug.Rd b/man/mlr_learners_classif.debug.Rd index a6e5ca749..e8c89c46a 100644 --- a/man/mlr_learners_classif.debug.Rd +++ b/man/mlr_learners_classif.debug.Rd @@ -77,7 +77,7 @@ lrn("classif.debug") \examples{ learner = lrn("classif.debug") -learner$param_set$values = list(message_train = 1, save_tasks = TRUE) +learner$param_set$set_values(message_train = 1, save_tasks = TRUE) # this should signal a message task = tsk("penguins") From 5c24ba8ac1990b7e336cc6a0c378fc54001700d9 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 6 Jan 2025 17:04:59 +0100 Subject: [PATCH 07/10] Fix/predict newdata (#1240) * ci: fail on note * fix(predict): type conversion when predicting on new data * ... --------- Co-authored-by: Marc Becker <33069354+be-marc@users.noreply.github.com> --- NEWS.md | 4 +++- R/Learner.R | 11 +++++++++++ R/assertions.R | 2 +- man/Learner.Rd | 4 +++- tests/testthat/test_Learner.R | 27 +++++++++++++++++++++------ 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/NEWS.md b/NEWS.md index 69e7292de..441f1976e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,9 @@ # mlr3 (development version) +* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions if the input is a `data.frame` (#685) +* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning. * Column names with UTF-8 characters are now allowed by default. -The option `mlr3.allow_utf8_names` is removed. 
+ The option `mlr3.allow_utf8_names` is removed. * BREAKING CHANGE: `Learner$predict_types` is read-only now. * docs: Clear up behavior of `Learner$predict_type` after training. diff --git a/R/Learner.R b/R/Learner.R index 84fd7a075..de23f51e2 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -377,6 +377,8 @@ Learner = R6Class("Learner", #' `data.frame()` or [DataBackend]. #' If a [DataBackend] is provided as `newdata`, the row ids are preserved, #' otherwise they are set to to the sequence `1:nrow(newdata)`. + #' If the input is a `data.frame`, [`auto_convert`] is used for type-conversions to ensure compatability + #' of features between `$train()` and `$predict()`. #' #' @param task ([Task]). #' @@ -393,6 +395,14 @@ Learner = R6Class("Learner", task = task_rm_backend(task) } + if (is.data.frame(newdata)) { + keep_cols = intersect(names(newdata), task$col_info$id) + ci = task$col_info[list(keep_cols), on = "id"] + newdata = do.call(data.table, Map(auto_convert, + value = as.list(newdata)[ci$id], + id = ci$id, type = ci$type, levels = ci$levels)) + } + newdata = as_data_backend(newdata) assert_names(newdata$colnames, must.include = task$feature_names) @@ -409,6 +419,7 @@ Learner = R6Class("Learner", # do some type conversions if necessary task$backend = newdata + task$col_info = col_info(task$backend) task$row_roles$use = task$backend$rownames self$predict(task) }, diff --git a/R/assertions.R b/R/assertions.R index d86b5ce43..e14494bf4 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -189,7 +189,7 @@ assert_predictable = function(task, learner) { all(pmap_lgl(list(x = ci_train$levels, y = ci_predict$levels), identical)) if (!ok) { - lg$warn("Learner '%s' received task with different column info (feature type or level ordering) during train and predict.", learner$id) + stopf("Learner '%s' received task with different column info (feature type or factor level ordering) during train and predict.", learner$id) } } diff --git a/man/Learner.Rd b/man/Learner.Rd index 
817222e0f..c57dcbf14 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -518,7 +518,9 @@ New data to predict on. All data formats convertible by \code{\link[=as_data_backend]{as_data_backend()}} are supported, e.g. \code{data.frame()} or \link{DataBackend}. If a \link{DataBackend} is provided as \code{newdata}, the row ids are preserved, -otherwise they are set to to the sequence \code{1:nrow(newdata)}.} +otherwise they are set to to the sequence \code{1:nrow(newdata)}. +If the input is a \code{data.frame}, \code{\link{auto_convert}} is used for type-conversions to ensure compatability +of features between \verb{$train()} and \verb{$predict()}.} \item{\code{task}}{(\link{Task}).} } diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 8917f4452..901c6c271 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -255,13 +255,15 @@ test_that("learner cannot be trained with TuneToken present", { test_that("integer<->numeric conversion in newdata (#533)", { data = data.table(y = runif(10), x = 1:10) - newdata = data.table(y = runif(10), x = 1:10 + 0.1) + newdata1 = data.table(y = runif(10), x = as.double(1:10)) + newdata2 = data.table(y = runif(10), x = 1:10 + 0.1) task = TaskRegr$new("test", data, "y") learner = lrn("regr.featureless") learner$train(task) expect_prediction(learner$predict_newdata(data)) - expect_prediction(learner$predict_newdata(newdata)) + expect_prediction(learner$predict_newdata(newdata1)) + expect_error(learner$predict_newdata(newdata2), "failed to convert from class 'numeric'") }) test_that("weights", { @@ -575,10 +577,7 @@ test_that("column info is compared during predict", { task_other = as_task_classif(dother, target = "y") l = lrn("classif.rpart") l$train(task) - old_threshold = lg$threshold - lg$set_threshold("warn") - expect_output(l$predict(task_flip), "task with different column info") - lg$set_threshold(old_threshold) + expect_error(l$predict(task_flip), "task with different column 
info") expect_error(l$predict(task_other), "with different columns") }) @@ -663,3 +662,19 @@ test_that("configure method works", { expect_equal(learner$param_set$values$xval, 10) expect_equal(learner$predict_sets, "train") }) + +test_that("predict_newdata auto conversion (#685)", { + l = lrn("classif.debug", save_tasks = TRUE)$train(tsk("iris")$select(c("Sepal.Length", "Sepal.Width"))) + expect_error(l$predict_newdata(data.table(Sepal.Length = 1, Sepal.Width = "abc")), + "Incompatible types during auto-converting column 'Sepal.Width'", fixed = TRUE) + expect_error(l$predict_newdata(data.table(Sepal.Length = 1L)), + "but is missing elements") + + # New test for integerish value conversion to double + p1 = l$predict_newdata(data.table(Sepal.Length = 1, Sepal.Width = 2)) + p2 = l$predict_newdata(data.table(Sepal.Length = 1L, Sepal.Width = 2)) + expect_equal(l$model$task_predict$col_info[list("Sepal.Length")]$type, "numeric") + expect_double(l$model$task_predict$data(cols = "Sepal.Length")[[1]]) + + expect_equal(p1, p2) +}) From e21782a288f2dfbe38eb7034bdcf7321c133ef2d Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 6 Jan 2025 17:12:03 +0100 Subject: [PATCH 08/10] feat: improve docs for converters and better checks (#1231) * feat: improve docs for converters and better checks * ... 
* fix failing tests --- NAMESPACE | 2 ++ R/BenchmarkResult.R | 18 +++++++++++++++--- R/Prediction.R | 12 ++++++++++-- R/ResampleResult.R | 18 +++++++++++++++--- R/as_learner.R | 1 + R/as_measure.R | 2 ++ R/as_resampling.R | 3 ++- R/as_task.R | 8 ++++++++ R/as_task_classif.R | 2 +- R/as_task_regr.R | 2 +- R/assertions.R | 26 ++++++++++++++++++++++++++ man/as_resampling.Rd | 1 + man/as_task.Rd | 2 ++ man/as_task_classif.Rd | 2 +- man/as_task_regr.Rd | 2 +- man/assert_empty_ellipsis.Rd | 20 ++++++++++++++++++++ tests/testthat/test_Learner.R | 3 +-- tests/testthat/test_as_learner.R | 4 ++++ tests/testthat/test_as_measure.R | 4 ++++ tests/testthat/test_as_resampling.R | 4 ++++ tests/testthat/test_as_task.R | 4 ++++ tests/testthat/test_assertions.R | 8 ++++++++ 22 files changed, 133 insertions(+), 15 deletions(-) create mode 100644 man/assert_empty_ellipsis.Rd create mode 100644 tests/testthat/test_assertions.R diff --git a/NAMESPACE b/NAMESPACE index 8946cdd5e..d6625ea4e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -42,6 +42,7 @@ S3method(as_resampling,Resampling) S3method(as_resamplings,default) S3method(as_resamplings,list) S3method(as_task,Task) +S3method(as_task,default) S3method(as_task_classif,DataBackend) S3method(as_task_classif,Matrix) S3method(as_task_classif,TaskClassif) @@ -200,6 +201,7 @@ export(as_tasks) export(as_tasks_unsupervised) export(assert_backend) export(assert_benchmark_result) +export(assert_empty_ellipsis) export(assert_learnable) export(assert_learner) export(assert_learners) diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index 9ee35c3ce..d3abea9bc 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -175,7 +175,11 @@ BenchmarkResult = R6Class("BenchmarkResult", #' #' @return [data.table::data.table()]. 
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) { - measures = as_measures(measures, task_type = self$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } assert_flag(ids) assert_flag(conditions) assert_flag(predictions) @@ -230,7 +234,11 @@ BenchmarkResult = R6Class("BenchmarkResult", #' @param predict_sets (`character()`)\cr #' The predict sets. obs_loss = function(measures = NULL, predict_sets = "test") { - measures = as_measures(measures, task_type = private$.data$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } map_dtr(self$resample_results$resample_result, function(rr) { rr$obs_loss(measures, predict_sets) @@ -276,7 +284,11 @@ BenchmarkResult = R6Class("BenchmarkResult", #' #' @return [data.table::data.table()]. aggregate = function(measures = NULL, ids = TRUE, uhashes = FALSE, params = FALSE, conditions = FALSE) { - measures = assert_measures(as_measures(measures, task_type = self$task_type)) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } assert_flag(ids) assert_flag(uhashes) assert_flag(params) diff --git a/R/Prediction.R b/R/Prediction.R index 243397ea6..ad7a0c8ce 100644 --- a/R/Prediction.R +++ b/R/Prediction.R @@ -90,7 +90,11 @@ Prediction = R6Class("Prediction", #' #' @return [Prediction]. 
score = function(measures = NULL, task = NULL, learner = NULL, train_set = NULL) { - measures = as_measures(measures, task_type = self$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } scores = map_dbl(measures, function(m) m$score(prediction = self, task = task, learner = learner, train_set = train_set)) set_names(scores, ids(measures)) }, @@ -105,7 +109,11 @@ Prediction = R6Class("Prediction", #' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an #' additional transformation after aggregation, in this example taking the square-root. obs_loss = function(measures = NULL) { - measures = as_measures(measures, task_type = self$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } get_obs_loss(as.data.table(self), measures) }, diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 83b79ada3..895696d54 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -143,7 +143,11 @@ ResampleResult = R6Class("ResampleResult", #' #' @return [data.table::data.table()]. score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) { - measures = as_measures(measures, task_type = private$.data$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } assert_flag(ids) assert_flag(conditions) assert_flag(predictions) @@ -196,7 +200,11 @@ ResampleResult = R6Class("ResampleResult", #' @param predict_sets (`character()`)\cr #' The predict sets. 
obs_loss = function(measures = NULL, predict_sets = "test") { - measures = as_measures(measures, task_type = self$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } tab = map_dtr(self$predictions(predict_sets), as.data.table, .idcol = "iteration") get_obs_loss(tab, measures) }, @@ -208,7 +216,11 @@ ResampleResult = R6Class("ResampleResult", #' #' @return Named `numeric()`. aggregate = function(measures = NULL) { - measures = as_measures(measures, task_type = private$.data$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } resample_result_aggregate(self, measures) }, diff --git a/R/as_learner.R b/R/as_learner.R index 4c303f511..3ec806845 100644 --- a/R/as_learner.R +++ b/R/as_learner.R @@ -16,6 +16,7 @@ as_learner = function(x, ...) { # nolint #' Whether to discard the state. #' @rdname as_learner as_learner.Learner = function(x, clone = FALSE, discard_state = FALSE, ...) { # nolint + assert_empty_ellipsis(...) if (isTRUE(clone) && isTRUE(discard_state)) { clone_without(x, "state") } else if (isTRUE(clone)) { diff --git a/R/as_measure.R b/R/as_measure.R index d266d8f27..97f2798ef 100644 --- a/R/as_measure.R +++ b/R/as_measure.R @@ -17,12 +17,14 @@ as_measure = function(x, ...) { # nolint #' @export #' @rdname as_measure as_measure.NULL = function(x, task_type = NULL, ...) { # nolint + assert_empty_ellipsis(...) default_measures(task_type)[[1L]] } #' @export #' @rdname as_measure as_measure.Measure = function(x, clone = FALSE, ...) { # nolint + assert_empty_ellipsis(...) if (isTRUE(clone)) x$clone() else x } diff --git a/R/as_resampling.R b/R/as_resampling.R index 03c4e6cba..60e2453e6 100644 --- a/R/as_resampling.R +++ b/R/as_resampling.R @@ -2,7 +2,7 @@ #' #' @description #' Convert object to a [Resampling] or a list of [Resampling]. -#' +#' This method e.g. 
allows to convert an [`mlr3oml::OMLTask`] to a [`Resampling`]. #' @inheritParams as_task #' @export as_resampling = function(x, ...) { # nolint @@ -12,6 +12,7 @@ as_resampling = function(x, ...) { # nolint #' @export #' @rdname as_resampling as_resampling.Resampling = function(x, clone = FALSE, ...) { # nolint + assert_empty_ellipsis(...) if (isTRUE(clone)) x$clone() else x } diff --git a/R/as_task.R b/R/as_task.R index a6841dc39..1f7017f70 100644 --- a/R/as_task.R +++ b/R/as_task.R @@ -2,6 +2,8 @@ #' #' @description #' Convert object to a [Task] or a list of [Task]. +#' This method e.g. allows to convert an [`mlr3oml::OMLTask`] to a [`Task`] and additionally supports cloning. +#' In order to construct a [Task] from a `data.frame`, use task-specific converters such as [`as_task_classif()`] or [`as_task_regr()`]. #' #' @param x (any)\cr #' Object to convert. @@ -12,11 +14,17 @@ as_task = function(x, ...) { UseMethod("as_task") } +#' @export +as_task.default = function(x, ...) { + stopf("No method for class '%s'. To create a task from a `data.frame`, use dedicated converters such as `as_task_classif()` or `as_task_regr()`.", class(x)[1L]) +} + #' @rdname as_task #' @param clone (`logical(1)`)\cr #' If `TRUE`, ensures that the returned object is not the same as the input `x`. #' @export as_task.Task = function(x, clone = FALSE, ...) { # nolint + assert_empty_ellipsis(...) if (isTRUE(clone)) x$clone(deep = TRUE) else x } diff --git a/R/as_task_classif.R b/R/as_task_classif.R index cf75a0396..4f1c39bbd 100644 --- a/R/as_task_classif.R +++ b/R/as_task_classif.R @@ -4,7 +4,7 @@ #' Convert object to a [TaskClassif]. #' This is a S3 generic. mlr3 ships with methods for the following objects: #' -#' 1. [TaskClassif]: ensure the identity +#' 1. [TaskClassif]: returns the object as-is, possibly cloned. #' 2. [`formula`], [data.frame()], [matrix()], [Matrix::Matrix()] and [DataBackend]: provides an alternative to the constructor of [TaskClassif]. #' 3. 
[TaskRegr]: Calls [convert_task()]. #' diff --git a/R/as_task_regr.R b/R/as_task_regr.R index ce4f90d1a..ad1e682e9 100644 --- a/R/as_task_regr.R +++ b/R/as_task_regr.R @@ -4,7 +4,7 @@ #' Convert object to a [TaskRegr]. #' This is a S3 generic. mlr3 ships with methods for the following objects: #' -#' 1. [TaskRegr]: ensure the identity +#' 1. [TaskRegr]: returns the object as-is, possibly cloned. #' 2. [`formula`], [data.frame()], [matrix()], [Matrix::Matrix()] and [DataBackend]: provides an alternative to the constructor of [TaskRegr]. #' 3. [TaskClassif]: Calls [convert_task()]. #' diff --git a/R/assertions.R b/R/assertions.R index e14494bf4..bd1a529ba 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -405,3 +405,29 @@ assert_param_values = function(x, n_learners = NULL, .var.name = vname(x)) { } invisible(x) } + +#' @title Assert Empty Ellipsis +#' @description +#' Assert that `...` arguments are empty. +#' Use this function in S3-methods to ensure that misspelling of arguments does not go unnoticed. +#' @param ... (any)\cr +#' Ellipsis arguments to check. +#' @keywords internal +#' @return `NULL` +#' @export +assert_empty_ellipsis = function(...) 
{ + if (...length()) { + names = ...names() + if (is.null(names)) { + stopf("Received %i unnamed argument that was not used.", ...length()) + } else { + names2 = names[names != ""] + if (length(names2) == length(names)) { + stopf("Received the following named arguments that were unused: %s.", paste0(names2, collapse = ", ")) + } else { + stopf("Received unused arguments: %i unnamed, as well as named arguments %s.", length(names) - length(names2), paste0(names2, collapse = ", ")) + } + } + } + NULL +} diff --git a/man/as_resampling.Rd b/man/as_resampling.Rd index 846ea9ac3..02784a7ec 100644 --- a/man/as_resampling.Rd +++ b/man/as_resampling.Rd @@ -30,4 +30,5 @@ If \code{TRUE}, ensures that the returned object is not the same as the input \c } \description{ Convert object to a \link{Resampling} or a list of \link{Resampling}. +This method e.g. allows to convert an \code{\link[mlr3oml:oml_task]{mlr3oml::OMLTask}} to a \code{\link{Resampling}}. } diff --git a/man/as_task.Rd b/man/as_task.Rd index eba52ac8a..8aff6071e 100644 --- a/man/as_task.Rd +++ b/man/as_task.Rd @@ -30,4 +30,6 @@ If \code{TRUE}, ensures that the returned object is not the same as the input \c } \description{ Convert object to a \link{Task} or a list of \link{Task}. +This method e.g. allows to convert an \code{\link[mlr3oml:oml_task]{mlr3oml::OMLTask}} to a \code{\link{Task}} and additionally supports cloning. +In order to construct a \link{Task} from a \code{data.frame}, use task-specific converters such as \code{\link[=as_task_classif]{as_task_classif()}} or \code{\link[=as_task_regr]{as_task_regr()}}. } diff --git a/man/as_task_classif.Rd b/man/as_task_classif.Rd index 823cffaa8..4d633963c 100644 --- a/man/as_task_classif.Rd +++ b/man/as_task_classif.Rd @@ -106,7 +106,7 @@ Data frame containing all columns referenced in formula \code{x}.} Convert object to a \link{TaskClassif}. This is a S3 generic. 
mlr3 ships with methods for the following objects: \enumerate{ -\item \link{TaskClassif}: ensure the identity +\item \link{TaskClassif}: returns the object as-is, possibly cloned. \item \code{\link{formula}}, \code{\link[=data.frame]{data.frame()}}, \code{\link[=matrix]{matrix()}}, \code{\link[Matrix:Matrix]{Matrix::Matrix()}} and \link{DataBackend}: provides an alternative to the constructor of \link{TaskClassif}. \item \link{TaskRegr}: Calls \code{\link[=convert_task]{convert_task()}}. } diff --git a/man/as_task_regr.Rd b/man/as_task_regr.Rd index 35e57e84d..f2b77330c 100644 --- a/man/as_task_regr.Rd +++ b/man/as_task_regr.Rd @@ -100,7 +100,7 @@ Data frame containing all columns referenced in formula \code{x}.} Convert object to a \link{TaskRegr}. This is a S3 generic. mlr3 ships with methods for the following objects: \enumerate{ -\item \link{TaskRegr}: ensure the identity +\item \link{TaskRegr}: returns the object as-is, possibly cloned. \item \code{\link{formula}}, \code{\link[=data.frame]{data.frame()}}, \code{\link[=matrix]{matrix()}}, \code{\link[Matrix:Matrix]{Matrix::Matrix()}} and \link{DataBackend}: provides an alternative to the constructor of \link{TaskRegr}. \item \link{TaskClassif}: Calls \code{\link[=convert_task]{convert_task()}}. } diff --git a/man/assert_empty_ellipsis.Rd b/man/assert_empty_ellipsis.Rd new file mode 100644 index 000000000..21568d8b8 --- /dev/null +++ b/man/assert_empty_ellipsis.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/assertions.R +\name{assert_empty_ellipsis} +\alias{assert_empty_ellipsis} +\title{Assert Empty Ellipsis} +\usage{ +assert_empty_ellipsis(...) +} +\arguments{ +\item{...}{(any)\cr +Ellipsis arguments to check.} +} +\value{ +\code{NULL} +} +\description{ +Assert that \code{...} arguments are empty. +Use this function in S3-methods to ensure that misspelling of arguments does not go unnoticed. 
+} +\keyword{internal} diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 901c6c271..5ed9d803a 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -551,8 +551,7 @@ test_that("learner state contains internal valid task information", { test_that("validation task with 0 observations", { learner = lrn("classif.debug", validate = "predefined") task = tsk("iris") - task$internal_valid_task = integer(0) - expect_error({learner$train(task)}, "has 0 observations") + expect_warning({task$internal_valid_task = integer(0)}) }) test_that("column info is compared during predict", { diff --git a/tests/testthat/test_as_learner.R b/tests/testthat/test_as_learner.R index 3d34685a1..59cdc865a 100644 --- a/tests/testthat/test_as_learner.R +++ b/tests/testthat/test_as_learner.R @@ -21,3 +21,7 @@ test_that("discard_state", { as_learner(learner3, clone = FALSE, discard_state = TRUE) expect_null(learner3$state) }) + +test_that("error when arguments are misspelled", { + expect_error(as_learner(lrn("classif.rpart"), clone2 = TRUE), "Received the following") +}) diff --git a/tests/testthat/test_as_measure.R b/tests/testthat/test_as_measure.R index 647774fed..584962d66 100644 --- a/tests/testthat/test_as_measure.R +++ b/tests/testthat/test_as_measure.R @@ -14,3 +14,7 @@ test_that("as_measure conversion", { default = as_measures(NULL, task_type = "classif") expect_list(default, types = "Measure") }) + +test_that("error when arguments are misspelled", { + expect_error(as_measure(msr("classif.acc"), clone2 = TRUE), "Received the following") +}) diff --git a/tests/testthat/test_as_resampling.R b/tests/testthat/test_as_resampling.R index 40143a0c4..0718a7ff0 100644 --- a/tests/testthat/test_as_resampling.R +++ b/tests/testthat/test_as_resampling.R @@ -10,3 +10,7 @@ test_that("as_resampling conversion", { expect_list(as_resamplings(resampling), types = "Resampling") expect_list(as_resamplings(list(resampling)), types = "Resampling") }) + 
+test_that("error when arguments are misspelled", { + expect_error(as_resampling(rsmp("holdout"), clone2 = TRUE), "Received the following") +}) diff --git a/tests/testthat/test_as_task.R b/tests/testthat/test_as_task.R index 13b00ed82..9d763f221 100644 --- a/tests/testthat/test_as_task.R +++ b/tests/testthat/test_as_task.R @@ -22,3 +22,7 @@ test_that("as_task_xx error messages (#944)", { "subset of" ) }) + +test_that("error when arguments are misspelled", { + expect_error(as_task(tsk("iris"), clone2 = TRUE), "Received the following") +}) diff --git a/tests/testthat/test_assertions.R b/tests/testthat/test_assertions.R new file mode 100644 index 000000000..1c77d92b9 --- /dev/null +++ b/tests/testthat/test_assertions.R @@ -0,0 +1,8 @@ +test_that("assert_empty_ellipsis works", { + expect_error(assert_empty_ellipsis(1), "Received 1 unnamed argument") + expect_error(assert_empty_ellipsis(1, 2), "Received 2 unnamed argument") + expect_error(assert_empty_ellipsis(a = 1), "that were unused: a") + expect_error(assert_empty_ellipsis(a = 1, b = 2), "that were unused: a, b") + expect_error(assert_empty_ellipsis(a = 1, b = 1, 2), "1 unnamed, as well as named arguments a, b") + expect_null(assert_empty_ellipsis()) +}) From 694e21c4cf97c22e2f2ddc7ae16441917ba32f78 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 9 Jan 2025 09:51:29 +0100 Subject: [PATCH 09/10] fix: as_measures (#1242) --- R/BenchmarkResult.R | 12 ++---------- R/Prediction.R | 12 ++---------- R/ResampleResult.R | 18 +++--------------- R/as_measure.R | 14 +++++++------- man/as_measure.Rd | 16 ++++++++-------- man/mlr_learners_classif.featureless.Rd | 2 +- man/mlr_learners_regr.featureless.Rd | 2 +- 7 files changed, 24 insertions(+), 52 deletions(-) diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index d3abea9bc..1ebd980a0 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -175,11 +175,7 @@ BenchmarkResult = R6Class("BenchmarkResult", #' #' @return [data.table::data.table()]. 
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) assert_flag(ids) assert_flag(conditions) assert_flag(predictions) @@ -234,11 +230,7 @@ BenchmarkResult = R6Class("BenchmarkResult", #' @param predict_sets (`character()`)\cr #' The predict sets. obs_loss = function(measures = NULL, predict_sets = "test") { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) map_dtr(self$resample_results$resample_result, function(rr) { rr$obs_loss(measures, predict_sets) diff --git a/R/Prediction.R b/R/Prediction.R index ad7a0c8ce..c2000e5ec 100644 --- a/R/Prediction.R +++ b/R/Prediction.R @@ -90,11 +90,7 @@ Prediction = R6Class("Prediction", #' #' @return [Prediction]. score = function(measures = NULL, task = NULL, learner = NULL, train_set = NULL) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) scores = map_dbl(measures, function(m) m$score(prediction = self, task = task, learner = learner, train_set = train_set)) set_names(scores, ids(measures)) }, @@ -109,11 +105,7 @@ Prediction = R6Class("Prediction", #' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an #' additional transformation after aggregation, in this example taking the square-root. 
obs_loss = function(measures = NULL) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) get_obs_loss(as.data.table(self), measures) }, diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 895696d54..53da71cdc 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -143,11 +143,7 @@ ResampleResult = R6Class("ResampleResult", #' #' @return [data.table::data.table()]. score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) assert_flag(ids) assert_flag(conditions) assert_flag(predictions) @@ -200,11 +196,7 @@ ResampleResult = R6Class("ResampleResult", #' @param predict_sets (`character()`)\cr #' The predict sets. obs_loss = function(measures = NULL, predict_sets = "test") { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) tab = map_dtr(self$predictions(predict_sets), as.data.table, .idcol = "iteration") get_obs_loss(tab, measures) }, @@ -216,11 +208,7 @@ ResampleResult = R6Class("ResampleResult", #' #' @return Named `numeric()`. aggregate = function(measures = NULL) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) resample_result_aggregate(self, measures) }, diff --git a/R/as_measure.R b/R/as_measure.R index 97f2798ef..bc63c32ab 100644 --- a/R/as_measure.R +++ b/R/as_measure.R @@ -10,7 +10,7 @@ #' #' @return [Measure]. 
#' @export -as_measure = function(x, ...) { # nolint +as_measure = function(x, task_type = NULL, ...) { # nolint UseMethod("as_measure") } @@ -23,21 +23,21 @@ as_measure.NULL = function(x, task_type = NULL, ...) { # nolint #' @export #' @rdname as_measure -as_measure.Measure = function(x, clone = FALSE, ...) { # nolint +as_measure.Measure = function(x, task_type = NULL, clone = FALSE, ...) { # nolint assert_empty_ellipsis(...) if (isTRUE(clone)) x$clone() else x } #' @export #' @rdname as_measure -as_measures = function(x, ...) { # nolint +as_measures = function(x, task_type = NULL, ...) { # nolint UseMethod("as_measures") } #' @export #' @rdname as_measure -as_measures.default = function(x, ...) { # nolint - list(as_measure(x, ...)) +as_measures.default = function(x, task_type = NULL, ...) { # nolint + list(as_measure(x, task_type = task_type, ...)) } #' @export @@ -48,6 +48,6 @@ as_measures.NULL = function(x, task_type = NULL, ...) { # nolint #' @export #' @rdname as_measure -as_measures.list = function(x, ...) { # nolint - lapply(x, as_measure, ...) +as_measures.list = function(x, task_type = NULL, ...) { # nolint + lapply(x, as_measure, task_type = NULL, ...) } diff --git a/man/as_measure.Rd b/man/as_measure.Rd index 79d52a565..ad17bc4a8 100644 --- a/man/as_measure.Rd +++ b/man/as_measure.Rd @@ -10,31 +10,31 @@ \alias{as_measures.list} \title{Convert to a Measure} \usage{ -as_measure(x, ...) +as_measure(x, task_type = NULL, ...) \method{as_measure}{`NULL`}(x, task_type = NULL, ...) -\method{as_measure}{Measure}(x, clone = FALSE, ...) +\method{as_measure}{Measure}(x, task_type = NULL, clone = FALSE, ...) -as_measures(x, ...) +as_measures(x, task_type = NULL, ...) -\method{as_measures}{default}(x, ...) +\method{as_measures}{default}(x, task_type = NULL, ...) \method{as_measures}{`NULL`}(x, task_type = NULL, ...) -\method{as_measures}{list}(x, ...) +\method{as_measures}{list}(x, task_type = NULL, ...) 
} \arguments{ \item{x}{(any)\cr Object to convert.} -\item{...}{(any)\cr -Additional arguments.} - \item{task_type}{(\code{character(1)})\cr Used if \code{x} is \code{NULL} to construct a default measure for the respective task type. The default measures are stored in \code{\link[=mlr_reflections]{mlr_reflections$default_measures}}.} +\item{...}{(any)\cr +Additional arguments.} + \item{clone}{(\code{logical(1)})\cr If \code{TRUE}, ensures that the returned object is not the same as the input \code{x}.} } diff --git a/man/mlr_learners_classif.featureless.Rd b/man/mlr_learners_classif.featureless.Rd index cf9724c9a..e701c7369 100644 --- a/man/mlr_learners_classif.featureless.Rd +++ b/man/mlr_learners_classif.featureless.Rd @@ -36,7 +36,7 @@ lrn("classif.featureless") \itemize{ \item Task type: \dQuote{classif} \item Predict Types: \dQuote{response}, \dQuote{prob} -\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{character}, \dQuote{factor}, \dQuote{ordered}, \dQuote{POSIXct} +\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{character}, \dQuote{factor}, \dQuote{ordered}, \dQuote{POSIXct}, \dQuote{Date} \item Required Packages: \CRANpkg{mlr3} } } diff --git a/man/mlr_learners_regr.featureless.Rd b/man/mlr_learners_regr.featureless.Rd index 4cd8de85e..64a69f4cb 100644 --- a/man/mlr_learners_regr.featureless.Rd +++ b/man/mlr_learners_regr.featureless.Rd @@ -25,7 +25,7 @@ lrn("regr.featureless") \itemize{ \item Task type: \dQuote{regr} \item Predict Types: \dQuote{response}, \dQuote{se}, \dQuote{quantiles} -\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{character}, \dQuote{factor}, \dQuote{ordered}, \dQuote{POSIXct} +\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{character}, \dQuote{factor}, \dQuote{ordered}, \dQuote{POSIXct}, \dQuote{Date} \item Required Packages: \CRANpkg{mlr3}, 'stats' } } From 
a54679e30ce533c5f5bf6af23c3333d59eca6196 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 10 Jan 2025 10:14:27 +0100 Subject: [PATCH 10/10] fix(learner): column info and type conversion in predict_newdata (#1243) --- NEWS.md | 2 +- R/Learner.R | 33 +++++++++++++++++------------ man/Learner.Rd | 6 +++--- tests/testthat/test_Learner.R | 39 +++++++++++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 17 deletions(-) diff --git a/NEWS.md b/NEWS.md index 441f1976e..89892f152 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,6 @@ # mlr3 (development version) -* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions if the input is a `data.frame` (#685) +* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions (#685) * BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning. * Column names with UTF-8 characters are now allowed by default. The option `mlr3.allow_utf8_names` is removed. diff --git a/R/Learner.R b/R/Learner.R index de23f51e2..7e9657085 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -370,6 +370,8 @@ Learner = R6Class("Learner", #' of the training task stored in the learner. #' If the learner has been fitted via [resample()] or [benchmark()], you need to pass the corresponding task stored #' in the [ResampleResult] or [BenchmarkResult], respectively. + #' Further, [`auto_convert`] is used for type-conversions to ensure compatibility + #' of features between `$train()` and `$predict()`. #' #' @param newdata (any object supported by [as_data_backend()])\cr #' New data to predict on. @@ -377,8 +379,6 @@ Learner = R6Class("Learner", #' `data.frame()` or [DataBackend]. #' If a [DataBackend] is provided as `newdata`, the row ids are preserved, #' otherwise they are set to to the sequence `1:nrow(newdata)`.
- #' If the input is a `data.frame`, [`auto_convert`] is used for type-conversions to ensure compatability - #' of features between `$train()` and `$predict()`. #' #' @param task ([Task]). #' @@ -395,31 +395,38 @@ Learner = R6Class("Learner", task = task_rm_backend(task) } - if (is.data.frame(newdata)) { - keep_cols = intersect(names(newdata), task$col_info$id) - ci = task$col_info[list(keep_cols), on = "id"] - newdata = do.call(data.table, Map(auto_convert, - value = as.list(newdata)[ci$id], - id = ci$id, type = ci$type, levels = ci$levels)) - } - newdata = as_data_backend(newdata) assert_names(newdata$colnames, must.include = task$feature_names) # the following columns are automatically set to NA if missing impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weight")], use.names = FALSE) impute = setdiff(impute, newdata$colnames) - if (length(impute)) { + tab1 = if (length(impute)) { # create list with correct NA types and cbind it to the backend ci = insert_named(task$col_info[list(impute), c("id", "type", "levels"), on = "id", with = FALSE], list(value = NA)) na_cols = set_names(pmap(ci, function(..., nrow) rep(auto_convert(...), nrow), nrow = newdata$nrow), ci$id) - tab = invoke(data.table, .args = insert_named(na_cols, set_names(list(newdata$rownames), newdata$primary_key))) + invoke(data.table, .args = insert_named(na_cols, set_names(list(newdata$rownames), newdata$primary_key))) + } + + # Perform type conversion where necessary + keep_cols = intersect(newdata$colnames, task$col_info$id) + ci = task$col_info[list(keep_cols), ][ + get("type") != col_info(newdata)[list(keep_cols), on = "id"]$type] + tab2 = do.call(data.table, Map(auto_convert, + value = as.list(newdata$data(rows = newdata$rownames, cols = ci$id)), + id = ci$id, type = ci$type, levels = ci$levels)) + + tab = cbind(tab1, tab2) + if (ncol(tab)) { + tab[[newdata$primary_key]] = newdata$rownames newdata = DataBackendCbind$new(newdata, DataBackendDataTable$new(tab,
primary_key = newdata$primary_key)) } - # do some type conversions if necessary + prevci = task$col_info task$backend = newdata task$col_info = col_info(task$backend) + task$col_info[, c("label", "fix_factor_levels")] = prevci[list(task$col_info$id), on = "id", c("label", "fix_factor_levels")] + task$col_info$fix_factor_levels[is.na(task$col_info$fix_factor_levels)] = FALSE task$row_roles$use = task$backend$rownames self$predict(task) }, diff --git a/man/Learner.Rd b/man/Learner.Rd index c57dcbf14..ca54c1a11 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -506,6 +506,8 @@ If the learner's \verb{$train()} method has been called, there is a (size reduce of the training task stored in the learner. If the learner has been fitted via \code{\link[=resample]{resample()}} or \code{\link[=benchmark]{benchmark()}}, you need to pass the corresponding task stored in the \link{ResampleResult} or \link{BenchmarkResult}, respectively. +Further, \code{\link{auto_convert}} is used for type-conversions to ensure compatibility +of features between \verb{$train()} and \verb{$predict()}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{Learner$predict_newdata(newdata, task = NULL)}\if{html}{\out{
}} } @@ -518,9 +520,7 @@ New data to predict on. All data formats convertible by \code{\link[=as_data_backend]{as_data_backend()}} are supported, e.g. \code{data.frame()} or \link{DataBackend}. If a \link{DataBackend} is provided as \code{newdata}, the row ids are preserved, -otherwise they are set to to the sequence \code{1:nrow(newdata)}. -If the input is a \code{data.frame}, \code{\link{auto_convert}} is used for type-conversions to ensure compatability -of features between \verb{$train()} and \verb{$predict()}.} +otherwise they are set to to the sequence \code{1:nrow(newdata)}.} \item{\code{task}}{(\link{Task}).} } diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 5ed9d803a..086a06298 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -677,3 +677,42 @@ test_that("predict_newdata auto conversion (#685)", { expect_equal(p1, p2) }) + +test_that("predict_newdata creates column info correctly", { + + learner = lrn("classif.debug", save_tasks = TRUE) + task = tsk("iris") + task$col_info$label = letters[1:6] + task$col_info$fix_factor_levels = c(TRUE, TRUE, FALSE, TRUE, FALSE, TRUE) + learner$train(task) + + ## data.frame is passed without task + p1 = learner$predict_newdata(iris[10:11, ]) + expect_equal(learner$model$task_predict$col_info, task$col_info) + expect_equal(p1$row_ids, 1:2) + + ## backend is passed without task + d = iris[10:11, ] + d$..row_id = 10:11 + b = as_data_backend(d, primary_key = "..row_id") + p2 = learner$predict_newdata(b) + expect_equal(p2$row_ids, 10:11) + expect_equal(learner$model$task_predict$col_info, task$col_info) + + ## data.frame is passed with task + task2 = tsk("iris") + learner$predict_newdata(iris[10:11, ], task2) + expect_equal(learner$model$task_predict$col_info, task2$col_info) + + ## backend is passed with task + learner$predict_newdata(b, task2) + expect_equal(learner$model$task_predict$col_info, task2$col_info) + + ## backend with different name for primary key + 
d2 = iris[10:11, ] + d2$row_id = 10:11 + b2 = as_data_backend(d2, primary_key = "row_id") + p3 = learner$predict_newdata(b2, task2) + expect_equal(p3$row_ids, 10:11) + expect_true("row_id" %in% learner$model$task_predict$col_info$id) +})