Skip to content

Commit

Permalink
feat: add rosranger, smote, cv, rcv, gridsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
sgibb committed Apr 17, 2022
1 parent 6e1ee26 commit 975a667
Show file tree
Hide file tree
Showing 21 changed files with 804 additions and 78 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: rusranger
Title:
Modified ranger implementation to support random-under-sampling
Version: 0.0.2
Date: 2022-03-24
Version: 0.0.3
Date: 2022-04-17
Description:
The random forest implementation of the ranger package is modified to
support random-under-sampling. Additional helper functions for
Expand Down
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
# Generated by roxygen2: do not edit by hand

export(cv)
export(cv_rusranger)
export(gridsearch)
export(gs_rusranger)
export(nested_gridsearch)
export(nrcv_rusranger)
export(rcv)
export(rcv_rusranger)
export(rosranger)
export(rusranger)
export(smote)
import(future)
import(ranger)
importFrom(ROCR,performance)
importFrom(ROCR,prediction)
importFrom(future.apply,future_lapply)
importFrom(stats,dist)
importFrom(stats,median)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,runif)
importFrom(stats,setNames)
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# rusranger 0.0

## Changes in 0.0.3

- Add `rosranger`, `smote`, `cv`, `rcv`, `gridsearch`.

## Changes in 0.0.2

- Fix `.caseweights` for 0/1 binary class coding (instead of 1/2).
Expand Down
75 changes: 75 additions & 0 deletions R/cv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#' Cross Validation
#'
#' Runs a cross validation for a user-supplied train-and-evaluate function.
#'
#' @inheritParams rusranger
#' @param FUN `function` (or the name of a function) to optimize. It is called
#' once per fold and has to return a single numeric performance value.
#' @param nfolds `integer(1)` number of cross validation folds.
#' @param \ldots further arguments passed to `FUN`.
#'
#' @return `double(1)` median of the values returned by `FUN` across all cross
#' validation splits (e.g. the median AUC).
#'
#' @note
#' The function to optimize has to accept five arguments: xtrain, ytrain, xtest,
#' ytest and \ldots.
#'
#' @import future
#' @importFrom future.apply future_lapply
#' @importFrom ROCR performance prediction
#' @importFrom stats median
#' @export
#' @examples
#' .rusranger <- function(xtrain, ytrain, xtest, ytest, ...) {
#'     rngr <- rusranger(x = xtrain, y = ytrain, ...)
#'     pred <- as.numeric(predict(rngr, xtest)$predictions[, 2L])
#'     performance(prediction(pred, ytest), measure = "auc")@y.values[[1L]]
#' }
#' cv(iris[-5], as.numeric(iris$Species == "versicolor"), .rusranger, nfolds = 3)
cv <- function(x, y, FUN, nfolds = 5, ...) {
  FUN <- match.fun(FUN)
  folds <- .bfolds(y, nfolds = nfolds)

  # Split the row indices rather than `x`/`y` themselves: `split()` on a
  # matrix treats it as a plain vector, and re-combining `y` with `c()` may
  # drop its attributes. Subsetting by index works for data.frames and
  # matrices alike and keeps `y` as is.
  idx <- split(seq_along(y), folds)

  r <- unlist(future.apply::future_lapply(
    seq_len(nfolds),
    function(i) {
      test <- idx[[i]]
      # fold-wise concatenation, i.e. the same row order as binding the
      # per-fold pieces together
      train <- unlist(idx[-i], use.names = FALSE)
      FUN(
        xtrain = x[train, , drop = FALSE], ytrain = y[train],
        xtest = x[test, , drop = FALSE], ytest = y[test],
        ...
      )
    },
    future.seed = TRUE
  ))
  median(r)
}

#' Repeated Cross Validation
#'
#' Calls [`cv()`] `nrepcv` times with the same data and function and
#' summarises the results.
#'
#' @inheritParams cv
#' @param nrepcv `integer(1)` number of repeats.
#' @param \ldots further arguments passed to `FUN`.
#'
#' @return `double(5)`, minimal, 25 % quartile, median, 75 % quartile and
#' maximal results across the repeated cross validations, named
#' `"Min"`, `"Q1"`, `"Median"`, `"Q3"` and `"Max"`.
#' @importFrom stats median predict quantile setNames
#' @export
rcv <- function(x, y, nfolds = 5, nrepcv = 2, FUN, ...) {
  FUN <- match.fun(FUN)

  # one complete cross validation; the repeat index itself is not used
  run_once <- function(i) {
    cv(x = x, y = y, nfolds = nfolds, FUN = FUN, ...)
  }

  res <- future.apply::future_lapply(
    seq_len(nrepcv), run_once,
    future.seed = TRUE
  )

  # five-number summary of the per-repeat results
  qs <- quantile(unlist(res), names = FALSE)
  names(qs) <- c("Min", "Q1", "Median", "Q3", "Max")
  qs
}
171 changes: 148 additions & 23 deletions R/cv_rusranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,7 @@
#' @importFrom ROCR performance prediction
#' @export
cv_rusranger <- function(x, y, nfolds = 5, ...) {
  # Thin wrapper kept for backwards compatibility: the fold handling lives in
  # cv(), and .rusranger() trains rusranger on each training split and returns
  # the AUC on the matching test split. Further arguments are handed on via
  # `...`. The former inline fold loop was removed because its result was
  # computed and then discarded.
  cv(x = x, y = y, FUN = .rusranger, nfolds = nfolds, ...)
}

#' Repeated Cross Validation for rusranger
Expand All @@ -53,12 +37,153 @@ cv_rusranger <- function(x, y, nfolds = 5, ...) {
#' nfolds = 3, nrepcv = 3
#' )
rcv_rusranger <- function(x, y, nfolds = 5, nrepcv = 2, ...) {
  # Thin wrapper kept for backwards compatibility: rcv() repeats the cross
  # validation `nrepcv` times and returns the Min/Q1/Median/Q3/Max summary;
  # .rusranger() scores each split. Further arguments are handed on via `...`.
  # The unterminated remains of the former inline implementation were removed.
  rcv(x = x, y = y, nfolds = nfolds, nrepcv = nrepcv, FUN = .rusranger, ...)
}

#' Grid Search
#'
#' Grid search to optimise the hyperparameters of `rusranger()`. This is a
#' convenience wrapper around [`gridsearch()`] that uses `rusranger()` as
#' the model to train and evaluate.
#'
#' @inheritParams rusranger
#' @inheritParams rcv_rusranger
#' @param searchspace `data.frame`, hyperparameters to tune. Column names have
#' to match the argument names of [`ranger()`]/[`rusranger()`].
#' @param \ldots further arguments passed on to [`gridsearch()`].
#' @return `data.frame` with tested hyperparameters and AUCs
#' @export
#' @examples
#' iris <- subset(iris, Species != "setosa")
#' searchspace <- expand.grid(
#'     mtry = c(2, 3),
#'     num.trees = c(500, 1000)
#' )
#' ## nfolds and nrepcv are too low for real world applications, and are just
#' ## used for demonstration and to keep the run time of the examples low
#' gs_rusranger(
#'     iris[-5], as.numeric(iris$Species == "versicolor"),
#'     searchspace = searchspace, nfolds = 3, nrepcv = 1
#' )
gs_rusranger <- function(x, y, searchspace, nfolds = 5, nrepcv = 2, ...) {
  gridsearch(
    x = x,
    y = y,
    searchspace = searchspace,
    nfolds = nfolds,
    nrepcv = nrepcv,
    FUN = .rusranger,
    ...
  )
}

#' Train-and-evaluate helper used by cv_rusranger/rcv_rusranger/gs_rusranger
#' (kept so that these wrappers stay backwards compatible).
#'
#' Fits `rusranger()` on the training split, predicts the test split and
#' returns the AUC computed from the second column of the predictions.
#'
#' @noRd
.rusranger <- function(xtrain, ytrain, xtest, ytest, ...) {
  fit <- rusranger(x = xtrain, y = ytrain, ...)

  # second prediction column is used as score
  # (presumably the probability of the second class - verify against rusranger)
  predicted <- predict(fit, xtest)$predictions
  scores <- as.numeric(predicted[, 2L])

  auc <- performance(prediction(scores, ytest), measure = "auc")
  auc@y.values[[1L]]
}

#' Nested Cross Validation for Hyperparameter Search
#'
#' Run a grid search in a nested cross validation.
#'
#' @note
#' The reported performance could slightly differ from the median performance
#' in the reported gridsearch. After the gridsearch `rusranger` is trained again
#' with the best hyperparameters which results in a new subsampling.
#'
#' @inheritParams gs_rusranger
#' @param nouterfolds `integer(1)`, number of outer cross validation folds.
#' @param ninnerfolds `integer(1)`, number of inner cross validation folds.
#' @param nrepcv `integer(1)`, number of repeats of the inner cross validations.
#' @param \ldots further arguments passed to [`gs_rusranger()`] and to the
#' final [`rusranger()`] call of each outer fold.
#' @return `list`, with an element per `nouterfolds` containing the following
#' subelements:
#'  * model selected `ranger` model.
#'  * indextrain index of the used training items.
#'  * indextest index of the used test items.
#'  * prediction prediction results.
#'  * truth original labels/classes.
#'  * performance resulting performance (AUC).
#'  * selectedparams selected hyperparameters.
#'  * gridsearch `data.frame`, results of the grid search.
#'  * nouterfolds `integer(1)`.
#'  * ninnerfolds `integer(1)`.
#'  * nrepcv `integer(1)`.
#' @export
#' @examples
#' set.seed(20220324)
#' iris <- subset(iris, Species != "setosa")
#' searchspace <- expand.grid(
#'     mtry = c(2, 3),
#'     num.trees = c(500, 1000)
#' )
#' ## n(outer|inner) folds and nrepcv are too low for real world applications,
#' ## and are just used for demonstration and to keep the run time of the examples
#' ## low
#' nrcv_rusranger(
#'     iris[-5], as.numeric(iris$Species == "versicolor"),
#'     searchspace = searchspace, nouterfolds = 3, ninnerfolds = 3, nrepcv = 1
#' )
nrcv_rusranger <- function(x, y, searchspace,
                           nouterfolds = 5, ninnerfolds = 5, nrepcv = 2,
                           ...) {

  folds <- .bfolds(y, nfolds = nouterfolds)
  # Split the row indices rather than `x`/`y` themselves: `split()` on a
  # matrix treats it as a plain vector, and re-combining `y` with `c()` may
  # drop its attributes.
  indices <- split(seq_along(y), folds)
  # columns of the grid search result that hold performance summaries
  # (everything else is a hyperparameter)
  summarycols <- c("Min", "Q1", "Median", "Q3", "Max")

  nrcv <- future.apply::future_lapply(
    seq_len(nouterfolds),
    function(i) {
      indextrain <- unlist(indices[-i])
      indextest <- unlist(indices[i])

      xtrain <- x[indextrain, , drop = FALSE]
      xtest <- x[indextest, , drop = FALSE]
      ytrain <- y[indextrain]
      ytest <- y[indextest]

      gs <- gs_rusranger(
        xtrain, ytrain, searchspace,
        nfolds = ninnerfolds, nrepcv = nrepcv, ...
      )

      # hyperparameters of the row with the highest median performance
      top <- which.max(gs$Median)
      selparms <- gs[top, !colnames(gs) %in% summarycols, drop = FALSE]

      ## additional call of an already calculated tree, ...
      ## could be avoided if we would store the results of the trees
      ## but this would take a lot of memory
      ## this could slightly change the results because of new
      ## resampling
      rngr <- do.call(
        rusranger,
        c(list(x = xtrain, y = ytrain), list(...), selparms)
      )
      pred <- as.numeric(predict(rngr, xtest)$predictions[, 2L])

      list(
        model = rngr,
        indextrain = indextrain,
        indextest = indextest,
        prediction = pred,
        truth = ytest,
        performance = performance(
          prediction(pred, ytest), measure = "auc"
        )@y.values[[1L]],
        selectedparams = selparms,
        gridsearch = gs,
        nouterfolds = nouterfolds,
        ninnerfolds = ninnerfolds,
        nrepcv = nrepcv
      )
    },
    future.seed = TRUE
  )
  nrcv
}
Loading

0 comments on commit 975a667

Please sign in to comment.