Skip to content

Commit

Permalink
Merge branch 'main' into fix/predict-newdata
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc authored Jan 6, 2025
2 parents aff322f + 9c95317 commit c4703bf
Show file tree
Hide file tree
Showing 47 changed files with 166 additions and 145 deletions.
10 changes: 6 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# mlr3 (development version)

* fix: the `$predict_newdata()` method of `Learner` now automatically conducts
type conversions if the input is a `data.frame` (#685)
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is
now an error and not a warning.
* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions if the input is a `data.frame` (#685)
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* Column names with UTF-8 characters are now allowed by default.
The option `mlr3.allow_utf8_names` is removed.
* BREAKING CHANGE: `Learner$predict_types` is read-only now.
* docs: Clear up behavior of `Learner$predict_type` after training.

# mlr3 0.22.1

Expand Down
4 changes: 2 additions & 2 deletions R/DataBackendRename.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ DataBackendRename = R6Class("DataBackendRename", inherit = DataBackend, cloneabl
assert_character(old, any.missing = FALSE, unique = TRUE)
assert_subset(old, b$colnames)
assert_character(new, any.missing = FALSE, len = length(old))
assert_names(new, if (allow_utf8_names()) "unique" else "strict")
assert_names(new, "unique")

ii = old != new
old = old[ii]
new = new[ii]

if (self$primary_key %in% old) {
if (self$primary_key %chin% old) {
stopf("Renaming the primary key is not supported")
}

Expand Down
4 changes: 2 additions & 2 deletions R/HotstartStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ calculate_cost = function(start_learner, learner, hotstart_id) {
cost = learner$param_set$values[[hotstart_id]] - start_learner$param_set$values[[hotstart_id]]
if (cost == 0) return(-1)

if ("hotstart_backward" %in% learner$properties && "hotstart_forward" %in% learner$properties) {
if ("hotstart_backward" %chin% learner$properties && "hotstart_forward" %chin% learner$properties) {
if (cost < 0) 0 else cost
} else if ("hotstart_backward" %in% learner$properties) {
} else if ("hotstart_backward" %chin% learner$properties) {
if (cost < 0) 0 else NA_real_
} else {
if (cost > 0) cost else NA_real_
Expand Down
25 changes: 15 additions & 10 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@
#' All information about hyperparameters is stored in the slot `param_set` which is a [paradox::ParamSet].
#' The printer gives an overview about the ids of available hyperparameters, their storage type, lower and upper bounds,
#' possible levels (for factors), default values and assigned values.
#' To set hyperparameters, assign a named list to the subslot `values`:
#' To set hyperparameters, call the `set_values()` method on the `param_set`:
#' ```
#' lrn = lrn("classif.rpart")
#' lrn$param_set$values = list(minsplit = 3, cp = 0.01)
#' lrn$param_set$set_values(minsplit = 3, cp = 0.01)
#' ```
#' Note that this operation replaces all previously set hyperparameter values.
#' If you only intend to change one specific hyperparameter value and leave the others as-is, you can use the helper function [mlr3misc::insert_named()]:
Expand Down Expand Up @@ -157,11 +157,6 @@ Learner = R6Class("Learner",
#' @template field_task_type
task_type = NULL,

#' @field predict_types (`character()`)\cr
#' Stores the possible predict types the learner is capable of.
#' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections].
predict_types = NULL,

#' @field feature_types (`character()`)\cr
#' Stores the feature types the learner can handle, e.g. `"logical"`, `"numeric"`, or `"factor"`.
#' A complete list of candidate feature types, grouped by task type, is stored in [`mlr_reflections$task_feature_types`][mlr_reflections].
Expand Down Expand Up @@ -214,7 +209,7 @@ Learner = R6Class("Learner",
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.param_set = assert_param_set(param_set)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
private$.predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
private$.predict_type = predict_types[1L]
self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
Expand Down Expand Up @@ -638,6 +633,8 @@ Learner = R6Class("Learner",
#' @field predict_type (`character(1)`)\cr
#' Stores the currently active predict type, e.g. `"response"`.
#' Must be an element of `$predict_types`.
#' A few learners already use the predict type during training.
#' Therefore, there is no guarantee that changing the predict type after training will have any effect or will not lead to errors.
predict_type = function(rhs) {
if (missing(rhs)) {
return(private$.predict_type)
Expand All @@ -659,8 +656,6 @@ Learner = R6Class("Learner",
private$.param_set
},



#' @field fallback ([Learner])\cr
#' Returns the fallback learner set with `$encapsulate()`.
fallback = function(rhs) {
Expand All @@ -683,13 +678,23 @@ Learner = R6Class("Learner",
}
assert_r6(rhs, "HotstartStack", null.ok = TRUE)
private$.hotstart_stack = rhs
},

#' @field predict_types (`character()`)\cr
#' Stores the possible predict types the learner is capable of.
#' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections].
#' This field is read-only.
predict_types = function(rhs) {
assert_ro_binding(rhs)
return(private$.predict_types)
}
),

private = list(
.encapsulation = c(train = "none", predict = "none"),
.fallback = NULL,
.predict_type = NULL,
.predict_types = NULL,
.param_set = NULL,
.hotstart_stack = NULL,

Expand Down
10 changes: 5 additions & 5 deletions R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#' @export
#' @examples
#' learner = lrn("classif.debug")
#' learner$param_set$values = list(message_train = 1, save_tasks = TRUE)
#' learner$param_set$set_values(message_train = 1, save_tasks = TRUE)
#'
#' # this should signal a message
#' task = tsk("penguins")
Expand Down Expand Up @@ -163,7 +163,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
pv = self$param_set$get_values(tags = "train")
pv$count_marshaling = pv$count_marshaling %??% FALSE
roll = function(name) {
name %in% names(pv) && pv[[name]] > runif(1L)
name %chin% names(pv) && pv[[name]] > runif(1L)
}

if (!is.null(pv$sleep_train)) {
Expand Down Expand Up @@ -248,7 +248,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
n = task$nrow
pv = self$param_set$get_values(tags = "predict")
roll = function(name) {
name %in% names(pv) && pv[[name]] > runif(1L)
name %chin% names(pv) && pv[[name]] > runif(1L)
}

if (!is.null(pv$sleep_predict)) {
Expand Down Expand Up @@ -281,7 +281,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
response = prob = NULL
missing_type = pv$predict_missing_type %??% "na"

if ("response" %in% self$predict_type) {
if ("response" %chin% self$predict_type) {
response = rep.int(unclass(model$response), n)
if (!is.null(pv$predict_missing)) {
ii = sample.int(n, n * pv$predict_missing)
Expand All @@ -292,7 +292,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
}
}

if ("prob" %in% self$predict_type) {
if ("prob" %chin% self$predict_type) {
cl = task$class_names
prob = matrix(runif(n * length(cl)), nrow = n)
prob = prob / rowSums(prob)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ LearnerClassifFeatureless = R6Class("LearnerClassifFeatureless", inherit = Learn
ps = ps(
method = p_fct(c("mode", "sample", "weighted.sample"), default = "mode", tags = "predict")
)
ps$values = list(method = "mode")
ps$set_values(method = "mode")
super$initialize(
id = "classif.featureless",
feature_types = mlr_reflections$task_feature_types,
Expand Down
8 changes: 4 additions & 4 deletions R/LearnerClassifRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
)
ps$values = list(xval = 0L)
ps$set_values(xval = 0L)

super$initialize(
id = "classif.rpart",
Expand Down Expand Up @@ -77,7 +77,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
if ("weights" %chin% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

Expand All @@ -89,11 +89,11 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
newdata = task$data(cols = task$feature_names)
response = prob = NULL

if ("response" %in% self$predict_type) {
if ("response" %chin% self$predict_type) {
response = invoke(predict, self$model, newdata = newdata, type = "class",
.opts = allow_partial_matching, .args = pv)
response = unname(response)
} else if ("prob" %in% self$predict_type) {
} else if ("prob" %chin% self$predict_type) {
prob = invoke(predict, self$model, newdata = newdata, type = "prob",
.opts = allow_partial_matching, .args = pv)
rownames(prob) = NULL
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegrFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ LearnerRegrFeatureless = R6Class("LearnerRegrFeatureless", inherit = LearnerRegr
ps = ps(
robust = p_lgl(default = TRUE, tags = "train")
)
ps$values = list(robust = FALSE)
ps$set_values(robust = FALSE)

super$initialize(
id = "regr.featureless",
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerRegrRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
)
ps$values = list(xval = 0L)
ps$set_values(xval = 0L)

super$initialize(
id = "regr.rpart",
Expand Down Expand Up @@ -77,7 +77,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
if ("weights" %chin% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

Expand Down
10 changes: 5 additions & 5 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,19 +200,19 @@ Measure = R6Class("Measure",
# check should be added to assert_measure()
# except when the checks are superfluous for rr$score() and bmr$score()
# these checks should be added below
if ("requires_task" %in% self$properties && is.null(task)) {
if ("requires_task" %chin% self$properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}

if ("requires_learner" %in% self$properties && is.null(learner)) {
if ("requires_learner" %chin% self$properties && is.null(learner)) {
stopf("Measure '%s' requires a learner", self$id)
}

if (!is_scalar_na(self$task_type) && self$task_type != prediction$task_type) {
stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
if ("requires_train_set" %chin% self$properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
}

Expand Down Expand Up @@ -258,7 +258,7 @@ Measure = R6Class("Measure",
#' @template field_predict_sets
predict_sets = function(rhs) {
if (!missing(rhs)) {
private$.predict_sets = assert_subset(rhs, mlr_reflections$predict_sets, empty.ok = "requires_no_prediction" %in% self$properties)
private$.predict_sets = assert_subset(rhs, mlr_reflections$predict_sets, empty.ok = "requires_no_prediction" %chin% self$properties)
}
private$.predict_sets
},
Expand Down Expand Up @@ -385,7 +385,7 @@ score_single_measure = function(measure, task, learner, train_set, prediction) {
#' @noRd
score_measures = function(obj, measures, reassemble = TRUE, view = NULL, iters = NULL) {
reassemble_learners = reassemble ||
some(measures, function(m) any(c("requires_learner", "requires_model") %in% m$properties))
some(measures, function(m) any(c("requires_learner", "requires_model") %chin% m$properties))
tab = get_private(obj)$.data$as_data_table(view = view, reassemble_learners = reassemble_learners, convert_predictions = FALSE)

if (!is.null(iters)) {
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureClassifCosts.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(normalize = p_lgl(tags = "required"))
param_set$values = list(normalize = TRUE)
param_set$set_values(normalize = TRUE)

super$initialize(
id = "classif.costs",
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ MeasureDebugClassif = R6Class("MeasureDebugClassif",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(na_ratio = p_dbl(0, 1, tags = "required"))
param_set$values = list(na_ratio = 0)
param_set$set_values(na_ratio = 0)
super$initialize(
id = "debug_classif",
param_set = param_set,
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSelectedFeatures.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ MeasureSelectedFeatures = R6Class("MeasureSelectedFeatures",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(normalize = p_lgl(tags = "required"))
param_set$values = list(normalize = FALSE)
param_set$set_values(normalize = FALSE)

super$initialize(
id = "selected_features",
Expand Down
2 changes: 1 addition & 1 deletion R/PredictionClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ PredictionClassif = R6Class("PredictionClassif", inherit = Prediction,
as.data.table.PredictionClassif = function(x, ...) { # nolint
tab = as.data.table(x$data[c("row_ids", "truth", "response")])

if ("prob" %in% x$predict_types) {
if ("prob" %chin% x$predict_types) {
prob = as.data.table(x$data$prob)
setnames(prob, names(prob), paste0("prob.", names(prob)))
tab = rcbind(tab, prob)
Expand Down
4 changes: 2 additions & 2 deletions R/PredictionDataClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ create_empty_prediction_data.TaskClassif = function(task, learner) {
truth = factor(character(), levels = cn)
)

if ("response" %in% predict_types) {
if ("response" %chin% predict_types) {
pdata$response = pdata$truth
}

if ("prob" %in% predict_types) {
if ("prob" %chin% predict_types) {
pdata$prob = matrix(numeric(), nrow = 0L, ncol = length(cn), dimnames = list(NULL, cn))
}

Expand Down
8 changes: 4 additions & 4 deletions R/PredictionDataRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ c.PredictionDataRegr = function(..., keep_duplicates = TRUE) { # nolint
result = as.list(tab)
result$quantiles = quantiles

if ("distr" %in% predict_types[[1L]]) {
if ("distr" %chin% predict_types[[1L]]) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
result$distr = do.call(c, map(dots, "distr"))
}
Expand Down Expand Up @@ -137,15 +137,15 @@ create_empty_prediction_data.TaskRegr = function(task, learner) {
truth = numeric()
)

if ("response" %in% predict_types) {
if ("response" %chin% predict_types) {
pdata$response = pdata$truth
}

if ("se" %in% predict_types) {
if ("se" %chin% predict_types) {
pdata$se = pdata$truth
}

if ("distr" %in% predict_types) {
if ("distr" %chin% predict_types) {
pdata$distr = list()
}

Expand Down
8 changes: 4 additions & 4 deletions R/PredictionRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
self$data = pdata
predict_types = intersect(names(mlr_reflections$learner_predict_types[["regr"]]), names(pdata))
# response is saved in the quantiles matrix
if ("quantiles" %in% predict_types) predict_types = union(predict_types, "response")
if ("quantiles" %chin% predict_types) predict_types = union(predict_types, "response")
self$predict_types = predict_types
if (is.null(pdata$response)) private$.quantile_response = attr(quantiles, "response")
}
Expand Down Expand Up @@ -94,7 +94,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
#' Access the stored vector distribution.
#' Requires package `distr6` (in repository \url{https://raphaels1.r-universe.dev}).
distr = function() {
if ("distr" %in% self$predict_types) {
if ("distr" %chin% self$predict_types) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
}
return(self$data$distr)
Expand All @@ -111,12 +111,12 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
as.data.table.PredictionRegr = function(x, ...) { # nolint
tab = as.data.table(x$data[c("row_ids", "truth", "response", "se")])

if ("quantiles" %in% x$predict_types) {
if ("quantiles" %chin% x$predict_types) {
tab = rcbind(tab, as.data.table(x$data$quantiles))
set(tab, j = "response", value = x$response)
}

if ("distr" %in% x$predict_types) {
if ("distr" %chin% x$predict_types) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
tab$distr = list(x$distr)
}
Expand Down
2 changes: 1 addition & 1 deletion R/Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
#' r$param_set$values
#'
#' # Do only 3 repeats on 10% of the data
#' r$param_set$values = list(ratio = 0.1, repeats = 3)
#' r$param_set$set_values(ratio = 0.1, repeats = 3)
#' r$param_set$values
#'
#' # Instantiate on penguins task
Expand Down
Loading

0 comments on commit c4703bf

Please sign in to comment.