Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve docs for converters and better checks #1231

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ S3method(as_resampling,Resampling)
S3method(as_resamplings,default)
S3method(as_resamplings,list)
S3method(as_task,Task)
S3method(as_task,default)
S3method(as_task_classif,DataBackend)
S3method(as_task_classif,Matrix)
S3method(as_task_classif,TaskClassif)
Expand Down Expand Up @@ -200,6 +201,7 @@ export(as_tasks)
export(as_tasks_unsupervised)
export(assert_backend)
export(assert_benchmark_result)
export(assert_empty_ellipsis)
export(assert_learnable)
export(assert_learner)
export(assert_learners)
Expand Down
18 changes: 15 additions & 3 deletions R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ BenchmarkResult = R6Class("BenchmarkResult",
#'
#' @return [data.table::data.table()].
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) {
measures = as_measures(measures, task_type = self$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
assert_flag(ids)
assert_flag(conditions)
assert_flag(predictions)
Expand Down Expand Up @@ -230,7 +234,11 @@ BenchmarkResult = R6Class("BenchmarkResult",
#' @param predict_sets (`character()`)\cr
#' The predict sets.
obs_loss = function(measures = NULL, predict_sets = "test") {
measures = as_measures(measures, task_type = private$.data$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
map_dtr(self$resample_results$resample_result,
function(rr) {
rr$obs_loss(measures, predict_sets)
Expand Down Expand Up @@ -276,7 +284,11 @@ BenchmarkResult = R6Class("BenchmarkResult",
#'
#' @return [data.table::data.table()].
aggregate = function(measures = NULL, ids = TRUE, uhashes = FALSE, params = FALSE, conditions = FALSE) {
measures = assert_measures(as_measures(measures, task_type = self$task_type))
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
assert_flag(ids)
assert_flag(uhashes)
assert_flag(params)
Expand Down
12 changes: 10 additions & 2 deletions R/Prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ Prediction = R6Class("Prediction",
#'
#' @return [Prediction].
score = function(measures = NULL, task = NULL, learner = NULL, train_set = NULL) {
measures = as_measures(measures, task_type = self$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
scores = map_dbl(measures, function(m) m$score(prediction = self, task = task, learner = learner, train_set = train_set))
set_names(scores, ids(measures))
},
Expand All @@ -105,7 +109,11 @@ Prediction = R6Class("Prediction",
#' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an
#' additional transformation after aggregation, in this example taking the square-root.
obs_loss = function(measures = NULL) {
measures = as_measures(measures, task_type = self$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
get_obs_loss(as.data.table(self), measures)
},

Expand Down
18 changes: 15 additions & 3 deletions R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,11 @@ ResampleResult = R6Class("ResampleResult",
#'
#' @return [data.table::data.table()].
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) {
measures = as_measures(measures, task_type = private$.data$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
assert_flag(ids)
assert_flag(conditions)
assert_flag(predictions)
Expand Down Expand Up @@ -196,7 +200,11 @@ ResampleResult = R6Class("ResampleResult",
#' @param predict_sets (`character()`)\cr
#' The predict sets.
obs_loss = function(measures = NULL, predict_sets = "test") {
measures = as_measures(measures, task_type = self$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
tab = map_dtr(self$predictions(predict_sets), as.data.table, .idcol = "iteration")
get_obs_loss(tab, measures)
},
Expand All @@ -208,7 +216,11 @@ ResampleResult = R6Class("ResampleResult",
#'
#' @return Named `numeric()`.
aggregate = function(measures = NULL) {
measures = as_measures(measures, task_type = private$.data$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
resample_result_aggregate(self, measures)
},

Expand Down
1 change: 1 addition & 0 deletions R/as_learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ as_learner = function(x, ...) { # nolint
#' Whether to discard the state.
#' @rdname as_learner
as_learner.Learner = function(x, clone = FALSE, discard_state = FALSE, ...) { # nolint
assert_empty_ellipsis(...)
if (isTRUE(clone) && isTRUE(discard_state)) {
clone_without(x, "state")
} else if (isTRUE(clone)) {
Expand Down
2 changes: 2 additions & 0 deletions R/as_measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ as_measure = function(x, ...) { # nolint
#' @export
#' @rdname as_measure
as_measure.NULL = function(x, task_type = NULL, ...) { # nolint
  # Fail on stray (e.g. misspelled) arguments before doing anything else.
  assert_empty_ellipsis(...)

  # `NULL` stands for "use the default": take the first default measure
  # that `default_measures()` returns for `task_type`.
  defaults = default_measures(task_type)
  defaults[[1L]]
}

#' @export
#' @rdname as_measure
as_measure.Measure = function(x, clone = FALSE, ...) { # nolint
  # Fail on stray (e.g. misspelled) arguments.
  assert_empty_ellipsis(...)

  # Only an explicit `TRUE` requests a copy; anything else returns `x` as-is.
  if (!isTRUE(clone)) {
    return(x)
  }
  x$clone()
}

Expand Down
3 changes: 2 additions & 1 deletion R/as_resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' @description
#' Convert object to a [Resampling] or a list of [Resampling].
#'
#' For example, this allows converting an [`mlr3oml::OMLTask`] to a [`Resampling`].
#' @inheritParams as_task
#' @export
as_resampling = function(x, ...) { # nolint
Expand All @@ -12,6 +12,7 @@ as_resampling = function(x, ...) { # nolint
#' @export
#' @rdname as_resampling
as_resampling.Resampling = function(x, clone = FALSE, ...) { # nolint
  # Fail on stray (e.g. misspelled) arguments.
  assert_empty_ellipsis(...)

  # Only an explicit `TRUE` requests a copy; anything else returns `x` as-is.
  if (!isTRUE(clone)) {
    return(x)
  }
  x$clone()
}

Expand Down
8 changes: 8 additions & 0 deletions R/as_task.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#'
#' @description
#' Convert object to a [Task] or a list of [Task].
#' For example, this allows converting an [`mlr3oml::OMLTask`] to a [`Task`]; it additionally supports cloning.
#' In order to construct a [Task] from a `data.frame`, use task-specific converters such as [`as_task_classif()`] or [`as_task_regr()`].
#'
#' @param x (any)\cr
#' Object to convert.
Expand All @@ -12,11 +14,17 @@ as_task = function(x, ...) {
UseMethod("as_task")
}

#' @export
as_task.default = function(x, ...) {
  # Fallback for objects without a dedicated method: always raises an error
  # that names the first class of `x` and points to the task-specific
  # converters.
  first_class = class(x)[1L]
  stopf("No method for class '%s'. To create a task from a `data.frame`, use dedicated converters such as `as_task_classif()` or `as_task_regr()`.", first_class)
}

#' @rdname as_task
#' @param clone (`logical(1)`)\cr
#' If `TRUE`, ensures that the returned object is not the same as the input `x`.
#' @export
as_task.Task = function(x, clone = FALSE, ...) { # nolint
  # Fail on stray (e.g. misspelled) arguments.
  assert_empty_ellipsis(...)

  # Only an explicit `TRUE` requests a copy; anything else returns `x` as-is.
  if (!isTRUE(clone)) {
    return(x)
  }
  # The copy is requested as a deep clone.
  x$clone(deep = TRUE)
}

Expand Down
2 changes: 1 addition & 1 deletion R/as_task_classif.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' Convert object to a [TaskClassif].
#' This is a S3 generic. mlr3 ships with methods for the following objects:
#'
#' 1. [TaskClassif]: ensure the identity
#' 1. [TaskClassif]: returns the object as-is, possibly cloned.
#' 2. [`formula`], [data.frame()], [matrix()], [Matrix::Matrix()] and [DataBackend]: provides an alternative to the constructor of [TaskClassif].
#' 3. [TaskRegr]: Calls [convert_task()].
#'
Expand Down
2 changes: 1 addition & 1 deletion R/as_task_regr.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' Convert object to a [TaskRegr].
#' This is a S3 generic. mlr3 ships with methods for the following objects:
#'
#' 1. [TaskRegr]: ensure the identity
#' 1. [TaskRegr]: returns the object as-is, possibly cloned.
#' 2. [`formula`], [data.frame()], [matrix()], [Matrix::Matrix()] and [DataBackend]: provides an alternative to the constructor of [TaskRegr].
#' 3. [TaskClassif]: Calls [convert_task()].
#'
Expand Down
26 changes: 26 additions & 0 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,29 @@ assert_param_values = function(x, n_learners = NULL, .var.name = vname(x)) {
}
invisible(x)
}

#' @title Assert Empty Ellipsis
#' @description
#' Assert that `...` arguments are empty.
#' Use this function in S3-methods to ensure that misspelling of arguments does not go unnoticed.
#' If `...` is not empty, an error is raised that reports how many unnamed arguments
#' were received and lists the names of the named ones.
#' @param ... (any)\cr
#'   Ellipsis arguments to check.
#' @keywords internal
#' @return `NULL`
#' @export
assert_empty_ellipsis = function(...) {
  n_total = ...length()
  if (n_total == 0L) {
    return(NULL)
  }

  # Only the count and the names of `...` are inspected.
  # `...names()` is `NULL` when nothing is named. Entries that are `NA` or
  # empty are treated as unnamed so the comparison below can never yield `NA`.
  dot_names = ...names()
  named = if (is.null(dot_names)) {
    character()
  } else {
    dot_names[!is.na(dot_names) & nzchar(dot_names)]
  }
  n_unnamed = n_total - length(named)

  if (length(named) == 0L) {
    # Pluralize so the message reads correctly for one and for many arguments.
    if (n_unnamed == 1L) {
      stopf("Received %i unnamed argument that was not used.", n_unnamed)
    }
    stopf("Received %i unnamed arguments that were not used.", n_unnamed)
  }

  named_str = paste0(named, collapse = ", ")
  if (n_unnamed == 0L) {
    stopf("Received the following named arguments that were unused: %s.", named_str)
  }
  stopf("Received unused arguments: %i unnamed, as well as named arguments %s.", n_unnamed, named_str)
}
1 change: 1 addition & 0 deletions man/as_resampling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/as_task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/as_task_classif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/as_task_regr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/assert_empty_ellipsis.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -549,8 +549,7 @@ test_that("learner state contains internal valid task information", {
test_that("validation task with 0 observations", {
learner = lrn("classif.debug", validate = "predefined")
task = tsk("iris")
task$internal_valid_task = integer(0)
expect_error({learner$train(task)}, "has 0 observations")
expect_warning({task$internal_valid_task = integer(0)})
})

test_that("column info is compared during predict", {
Expand Down
4 changes: 4 additions & 0 deletions tests/testthat/test_as_learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ test_that("discard_state", {
as_learner(learner3, clone = FALSE, discard_state = TRUE)
expect_null(learner3$state)
})

test_that("error when arguments are misspelled", {
  # `clone2` is a typo for `clone`; the converter must reject it.
  learner = lrn("classif.rpart")
  expect_error(
    as_learner(learner, clone2 = TRUE),
    "Received the following"
  )
})
4 changes: 4 additions & 0 deletions tests/testthat/test_as_measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ test_that("as_measure conversion", {
default = as_measures(NULL, task_type = "classif")
expect_list(default, types = "Measure")
})

test_that("error when arguments are misspelled", {
  # `clone2` is a typo for `clone`; the converter must reject it.
  measure = msr("classif.acc")
  expect_error(
    as_measure(measure, clone2 = TRUE),
    "Received the following"
  )
})
4 changes: 4 additions & 0 deletions tests/testthat/test_as_resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ test_that("as_resampling conversion", {
expect_list(as_resamplings(resampling), types = "Resampling")
expect_list(as_resamplings(list(resampling)), types = "Resampling")
})

test_that("error when arguments are misspelled", {
  # `clone2` is a typo for `clone`; the converter must reject it.
  resampling = rsmp("holdout")
  expect_error(
    as_resampling(resampling, clone2 = TRUE),
    "Received the following"
  )
})
4 changes: 4 additions & 0 deletions tests/testthat/test_as_task.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ test_that("as_task_xx error messages (#944)", {
"subset of"
)
})

test_that("error when arguments are misspelled", {
  # `clone2` is a typo for `clone`; the converter must reject it.
  task = tsk("iris")
  expect_error(
    as_task(task, clone2 = TRUE),
    "Received the following"
  )
})
8 changes: 8 additions & 0 deletions tests/testthat/test_assertions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
test_that("assert_empty_ellipsis works", {
  # Nothing passed: no error, returns NULL.
  expect_null(assert_empty_ellipsis())

  # Only unnamed arguments: the message reports their count.
  expect_error(assert_empty_ellipsis(1), regexp = "Received 1 unnamed argument")
  expect_error(assert_empty_ellipsis(1, 2), regexp = "Received 2 unnamed argument")

  # Only named arguments: the message lists their names.
  expect_error(assert_empty_ellipsis(a = 1), regexp = "that were unused: a")
  expect_error(assert_empty_ellipsis(a = 1, b = 2), regexp = "that were unused: a, b")

  # Mixed: both the unnamed count and the names are reported.
  expect_error(
    assert_empty_ellipsis(a = 1, b = 1, 2),
    regexp = "1 unnamed, as well as named arguments a, b"
  )
})
Loading