From 9c953178059216032e8d4868e06296614a54063b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCcke?= Date: Mon, 6 Jan 2025 17:01:29 +0100 Subject: [PATCH] refactor: use more set_values() (#1239) * refactor: use more set_values() * docs: adjust text to reference set_values for changing param set --- R/Learner.R | 4 ++-- R/LearnerClassifDebug.R | 2 +- R/LearnerClassifFeatureless.R | 2 +- R/LearnerClassifRpart.R | 2 +- R/LearnerRegrFeatureless.R | 2 +- R/LearnerRegrRpart.R | 2 +- R/MeasureClassifCosts.R | 2 +- R/MeasureDebug.R | 2 +- R/MeasureSelectedFeatures.R | 2 +- R/Resampling.R | 2 +- R/ResamplingBootstrap.R | 2 +- R/ResamplingCV.R | 2 +- R/ResamplingHoldout.R | 2 +- R/ResamplingRepeatedCV.R | 2 +- R/ResamplingSubsampling.R | 2 +- R/TaskGeneratorMoons.R | 2 +- man/Learner.Rd | 4 ++-- man/Resampling.Rd | 2 +- man/mlr_learners_classif.debug.Rd | 2 +- 19 files changed, 21 insertions(+), 21 deletions(-) diff --git a/R/Learner.R b/R/Learner.R index 461dd150f..84fd7a075 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -72,10 +72,10 @@ #' All information about hyperparameters is stored in the slot `param_set` which is a [paradox::ParamSet]. #' The printer gives an overview about the ids of available hyperparameters, their storage type, lower and upper bounds, #' possible levels (for factors), default values and assigned values. -#' To set hyperparameters, assign a named list to the subslot `values`: +#' To set hyperparameters, call the `set_values()` method on the `param_set`: #' ``` #' lrn = lrn("classif.rpart") -#' lrn$param_set$values = list(minsplit = 3, cp = 0.01) +#' lrn$param_set$set_values(minsplit = 3, cp = 0.01) #' ``` #' Note that this operation replaces all previously set hyperparameter values. #' If you only intend to change one specific hyperparameter value and leave the others as-is, you can use the helper function [mlr3misc::insert_named()]: diff --git a/R/LearnerClassifDebug.R b/R/LearnerClassifDebug.R index 4f6984f1d..ed7ebfc92 100644 --- a/R/LearnerClassifDebug.R +++ b/R/LearnerClassifDebug.R @@ -37,7 +37,7 @@ #' @export #' @examples #' learner = lrn("classif.debug") -#' learner$param_set$values = list(message_train = 1, save_tasks = TRUE) +#' learner$param_set$set_values(message_train = 1, save_tasks = TRUE) #' #' # this should signal a message #' task = tsk("penguins") diff --git a/R/LearnerClassifFeatureless.R b/R/LearnerClassifFeatureless.R index 35c3aeb17..995839fb0 100644 --- a/R/LearnerClassifFeatureless.R +++ b/R/LearnerClassifFeatureless.R @@ -34,7 +34,7 @@ LearnerClassifFeatureless = R6Class("LearnerClassifFeatureless", inherit = Learn ps = ps( method = p_fct(c("mode", "sample", "weighted.sample"), default = "mode", tags = "predict") ) - ps$values = list(method = "mode") + ps$set_values(method = "mode") super$initialize( id = "classif.featureless", feature_types = mlr_reflections$task_feature_types, diff --git a/R/LearnerClassifRpart.R b/R/LearnerClassifRpart.R index 05342335c..c3fc6896d 100644 --- a/R/LearnerClassifRpart.R +++ b/R/LearnerClassifRpart.R @@ -37,7 +37,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif, usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"), xval = p_int(0L, default = 10L, tags = "train") ) - ps$values = list(xval = 0L) + ps$set_values(xval = 0L) super$initialize( id = "classif.rpart", diff --git a/R/LearnerRegrFeatureless.R b/R/LearnerRegrFeatureless.R index 501c48535..ec349d295 100644 --- a/R/LearnerRegrFeatureless.R +++ b/R/LearnerRegrFeatureless.R @@ -23,7 +23,7 @@ LearnerRegrFeatureless = R6Class("LearnerRegrFeatureless", inherit = LearnerRegr ps = ps( robust = p_lgl(default = TRUE, tags = "train") ) - ps$values = list(robust = FALSE) + ps$set_values(robust = FALSE) super$initialize( id = "regr.featureless", diff --git a/R/LearnerRegrRpart.R b/R/LearnerRegrRpart.R index f78785799..35e2f2f0e 100644 --- a/R/LearnerRegrRpart.R +++ b/R/LearnerRegrRpart.R @@ -37,7 +37,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr, usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"), xval = p_int(0L, default = 10L, tags = "train") ) - ps$values = list(xval = 0L) + ps$set_values(xval = 0L) super$initialize( id = "regr.rpart", diff --git a/R/MeasureClassifCosts.R b/R/MeasureClassifCosts.R index 63456d213..063780a85 100644 --- a/R/MeasureClassifCosts.R +++ b/R/MeasureClassifCosts.R @@ -45,7 +45,7 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts", #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function() { param_set = ps(normalize = p_lgl(tags = "required")) - param_set$values = list(normalize = TRUE) + param_set$set_values(normalize = TRUE) super$initialize( id = "classif.costs", diff --git a/R/MeasureDebug.R b/R/MeasureDebug.R index 1efb64329..8e02c0a72 100644 --- a/R/MeasureDebug.R +++ b/R/MeasureDebug.R @@ -27,7 +27,7 @@ MeasureDebugClassif = R6Class("MeasureDebugClassif", #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function() { param_set = ps(na_ratio = p_dbl(0, 1, tags = "required")) - param_set$values = list(na_ratio = 0) + param_set$set_values(na_ratio = 0) super$initialize( id = "debug_classif", param_set = param_set, diff --git a/R/MeasureSelectedFeatures.R b/R/MeasureSelectedFeatures.R index db4748608..d53ead467 100644 --- a/R/MeasureSelectedFeatures.R +++ b/R/MeasureSelectedFeatures.R @@ -31,7 +31,7 @@ MeasureSelectedFeatures = R6Class("MeasureSelectedFeatures", #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function() { param_set = ps(normalize = p_lgl(tags = "required")) - param_set$values = list(normalize = FALSE) + param_set$set_values(normalize = FALSE) super$initialize( id = "selected_features", diff --git a/R/Resampling.R b/R/Resampling.R index 89aa3e89b..a33544d54 100644 --- a/R/Resampling.R +++ b/R/Resampling.R @@ -56,7 +56,7 @@ #' r$param_set$values #' #' # Do only 3 repeats on 10% of the data -#' r$param_set$values = list(ratio = 0.1, repeats = 3) +#' r$param_set$set_values(ratio = 0.1, repeats = 3) #' r$param_set$values #' #' # Instantiate on penguins task diff --git a/R/ResamplingBootstrap.R b/R/ResamplingBootstrap.R index b4e75e942..54667059d 100644 --- a/R/ResamplingBootstrap.R +++ b/R/ResamplingBootstrap.R @@ -49,7 +49,7 @@ ResamplingBootstrap = R6Class("ResamplingBootstrap", inherit = Resampling, ratio = p_dbl(0, upper = 1, tags = "required"), repeats = p_int(1L, tags = "required") ) - ps$values = list(ratio = 1, repeats = 30L) + ps$set_values(ratio = 1, repeats = 30L) super$initialize(id = "bootstrap", param_set = ps, duplicated_ids = TRUE, label = "Bootstrap", man = "mlr3::mlr_resamplings_bootstrap") diff --git a/R/ResamplingCV.R b/R/ResamplingCV.R index 4b36a71b0..34a0e60fc 100644 --- a/R/ResamplingCV.R +++ b/R/ResamplingCV.R @@ -44,7 +44,7 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling, ps = ps( folds = p_int(2L, tags = "required") ) - ps$values = list(folds = 10L) + ps$set_values(folds = 10L) super$initialize(id = "cv", param_set = ps, label = "Cross-Validation", man = "mlr3::mlr_resamplings_cv") diff --git a/R/ResamplingHoldout.R b/R/ResamplingHoldout.R index 0a5d7ad8d..0c47a12e9 100644 --- a/R/ResamplingHoldout.R +++ b/R/ResamplingHoldout.R @@ -45,7 +45,7 @@ ResamplingHoldout = R6Class("ResamplingHoldout", inherit = Resampling, ps = ps( ratio = p_dbl(0, 1, tags = "required") ) - ps$values = list(ratio = 2 / 3) + ps$set_values(ratio = 2 / 3) super$initialize(id = "holdout", param_set = ps, label = "Holdout", man = "mlr3::mlr_resamplings_holdout") diff --git a/R/ResamplingRepeatedCV.R b/R/ResamplingRepeatedCV.R index 212106eb8..bbf3b492c 100644 --- a/R/ResamplingRepeatedCV.R +++ b/R/ResamplingRepeatedCV.R @@ -56,7 +56,7 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling, folds = p_int(2L, tags = "required"), repeats = p_int(1L) ) - ps$values = list(repeats = 10L, folds = 10L) + ps$set_values(repeats = 10L, folds = 10L) super$initialize(id = "repeated_cv", param_set = ps, label = "Repeated Cross-Validation", man = "mlr3::mlr_resamplings_repeated_cv") }, diff --git a/R/ResamplingSubsampling.R b/R/ResamplingSubsampling.R index 99a447f3c..beeae7f1b 100644 --- a/R/ResamplingSubsampling.R +++ b/R/ResamplingSubsampling.R @@ -48,7 +48,7 @@ ResamplingSubsampling = R6Class("ResamplingSubsampling", inherit = Resampling, ratio = p_dbl(0, 1, tags = "required"), repeats = p_int(1, tags = "required") ) - ps$values = list(repeats = 30L, ratio = 2 / 3) + ps$set_values(repeats = 30L, ratio = 2 / 3) super$initialize(id = "subsampling", param_set = ps, label = "Subsampling", man = "mlr3::mlr_resamplings_subsampling") diff --git a/R/TaskGeneratorMoons.R b/R/TaskGeneratorMoons.R index e2036d2b6..7a1b7be7a 100644 --- a/R/TaskGeneratorMoons.R +++ b/R/TaskGeneratorMoons.R @@ -27,7 +27,7 @@ TaskGeneratorMoons = R6Class("TaskGeneratorMoons", ps = ps( sigma = p_dbl(0, tags = "required") ) - ps$values = list(sigma = 1) + ps$set_values(sigma = 1) super$initialize(id = "moons", task_type = "classif", param_set = ps, label = "Moons Classification", man = "mlr3::mlr_task_generators_moons") diff --git a/man/Learner.Rd b/man/Learner.Rd index 70d7cf8f9..817222e0f 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -61,10 +61,10 @@ If the learner is not trained yet, this returns \code{NULL}. All information about hyperparameters is stored in the slot \code{param_set} which is a \link[paradox:ParamSet]{paradox::ParamSet}. The printer gives an overview about the ids of available hyperparameters, their storage type, lower and upper bounds, possible levels (for factors), default values and assigned values. -To set hyperparameters, assign a named list to the subslot \code{values}: +To set hyperparameters, call the \code{set_values()} method on the \code{param_set}: \if{html}{\out{
}}\preformatted{lrn = lrn("classif.rpart") -lrn$param_set$values = list(minsplit = 3, cp = 0.01) +lrn$param_set$set_values(minsplit = 3, cp = 0.01) }\if{html}{\out{
}} Note that this operation replaces all previously set hyperparameter values. diff --git a/man/Resampling.Rd b/man/Resampling.Rd index 26692dd46..c2e62bb05 100644 --- a/man/Resampling.Rd +++ b/man/Resampling.Rd @@ -53,7 +53,7 @@ r = rsmp("subsampling") r$param_set$values # Do only 3 repeats on 10\% of the data -r$param_set$values = list(ratio = 0.1, repeats = 3) +r$param_set$set_values(ratio = 0.1, repeats = 3) r$param_set$values # Instantiate on penguins task diff --git a/man/mlr_learners_classif.debug.Rd b/man/mlr_learners_classif.debug.Rd index a6e5ca749..e8c89c46a 100644 --- a/man/mlr_learners_classif.debug.Rd +++ b/man/mlr_learners_classif.debug.Rd @@ -77,7 +77,7 @@ lrn("classif.debug") \examples{ learner = lrn("classif.debug") -learner$param_set$values = list(message_train = 1, save_tasks = TRUE) +learner$param_set$set_values(message_train = 1, save_tasks = TRUE) # this should signal a message task = tsk("penguins")