-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add rosranger, smote, cv, rcv, gridsearch
- Loading branch information
Showing
21 changed files
with
804 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,24 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export(cv) | ||
export(cv_rusranger) | ||
export(gridsearch) | ||
export(gs_rusranger) | ||
export(nested_gridsearch) | ||
export(nrcv_rusranger) | ||
export(rcv) | ||
export(rcv_rusranger) | ||
export(rosranger) | ||
export(rusranger) | ||
export(smote) | ||
import(future) | ||
import(ranger) | ||
importFrom(ROCR,performance) | ||
importFrom(ROCR,prediction) | ||
importFrom(future.apply,future_lapply) | ||
importFrom(stats,dist) | ||
importFrom(stats,median) | ||
importFrom(stats,predict) | ||
importFrom(stats,quantile) | ||
importFrom(stats,runif) | ||
importFrom(stats,setNames) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#' Cross Validation | ||
#' | ||
#' Runs a cross validation for train and prediction function. | ||
#' | ||
#' @inheritParams rusranger | ||
#' @param FUN `function` function to optimize. | ||
#' @param nfolds `integer(1)` number of cross validation folds. | ||
#' @param \ldots further arguments passed to `FUN`. | ||
#' | ||
#' @return `double(1)` median AUC across all cross validation splits | ||
#' | ||
#' @note | ||
#' The function to optimize has to accept five arguments: xtrain, ytrain, xtest, | ||
#' ytest and \ldots. | ||
#' | ||
#' @import future | ||
#' @importFrom future.apply future_lapply | ||
#' @importFrom ROCR performance prediction | ||
#' @export | ||
#' @examples | ||
#' .rusranger <- function(xtrain, ytrain, xtest, ytest, ...) { | ||
#' rngr <- rusranger(x = xtrain, y = ytrain, ...) | ||
#' pred <- as.numeric(predict(rngr, xtest)$predictions[, 2L]) | ||
#' performance(prediction(pred, ytest), measure = "auc")@y.values[[1L]] | ||
#' } | ||
#' cv(iris[-5], as.numeric(iris$Species == "versicolor"), .rusranger, nfolds = 3) | ||
cv <- function(x, y, FUN, nfolds = 5, ...) { | ||
folds <- .bfolds(y, nfolds = nfolds) | ||
xl <- split(x, folds) | ||
yl <- split(y, folds) | ||
r <- unlist(future.apply::future_lapply( | ||
seq_len(nfolds), | ||
function(i) { | ||
xtrain <- do.call(rbind, xl[-i]) | ||
xtest <- xl[[i]] | ||
ytrain <- do.call(c, yl[-i]) | ||
ytest <- yl[[i]] | ||
do.call( | ||
FUN, | ||
list( | ||
xtrain = xtrain, ytrain = ytrain, | ||
xtest = xtest, ytest = ytest, | ||
... | ||
) | ||
) | ||
}, | ||
future.seed = TRUE | ||
)) | ||
median(r) | ||
} | ||
|
||
#' Repeated Cross Validation | ||
#' | ||
#' Runs a repeated cross validation for a function. | ||
#' See also [`cv()`]. | ||
#' | ||
#' @inheritParams cv | ||
#' @param nrepcv `integer(1)` number of repeats. | ||
#' @param \ldots further arguments passed to `FUN`. | ||
#' | ||
#' @return `double(5)`, minimal, 25 % quartiel, median, 75 % quartile and | ||
#' maximal results across the repeated cross validations. | ||
#' @importFrom stats median predict quantile setNames | ||
#' @export | ||
rcv <- function(x, y, nfolds = 5, nrepcv = 2, FUN, ...) { | ||
FUN <- match.fun(FUN) | ||
r <- unlist(future.apply::future_lapply( | ||
seq_len(nrepcv), function(i) | ||
cv(x = x, y = y, nfolds = nfolds, FUN = FUN, ...), | ||
future.seed = TRUE | ||
)) | ||
setNames( | ||
quantile(r, names = FALSE), c("Min", "Q1", "Median", "Q3", "Max") | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.