Skip to content

Commit

Permalink
some changes for cmdcheck
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Nov 6, 2023
1 parent 0e6be27 commit 798eaa4
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 20 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ importFrom(magrittr,"%>%")
importFrom(marginaleffects,get_coef)
importFrom(marginaleffects,get_predict)
importFrom(marginaleffects,get_vcov)
importFrom(marginaleffects,plot_predictions)
importFrom(marginaleffects,set_coef)
importFrom(methods,new)
importFrom(mgcv,bam)
Expand Down
6 changes: 4 additions & 2 deletions R/conditional_effects.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#' Display conditional effects of one or more numeric and/or categorical
#' predictors in `mvgam` models, including two-way interaction effects.
#' @importFrom brms conditional_effects
#' @importFrom marginaleffects plot_predictions
#' @importFrom graphics plot
#' @importFrom grDevices devAskNewPage
#' @inheritParams brms::conditional_effects.brmsfit
Expand All @@ -17,7 +18,7 @@
#' @param ... other arguments to pass to \code{\link[marginaleffects]{plot_predictions}}
#' @return `conditional_effects` returns an object of class
#' \code{mvgam_conditional_effects} which is a
#' named list with one slot per effect containing a \code{\link{ggplot}} object,
#' named list with one slot per effect containing a \code{\link[ggplot2]{ggplot}} object,
#' which can be further customized using the \pkg{ggplot2} package.
#' The corresponding `plot` method will draw these plots in the active graphic device
#'
Expand Down Expand Up @@ -152,7 +153,8 @@ conditional_effects.mvgam = function(x,
plot.mvgam_conditional_effects = function(x,
plot = TRUE,
ask = TRUE,
theme = NULL){
theme = NULL,
...){
out <- x
for(i in seq_along(out)){
if(plot){
Expand Down
66 changes: 60 additions & 6 deletions R/gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,60 @@ sim_hilbert_gp = function(alpha_gp,
as.vector((diag_SPD * b_gp) %*% t(eigenfunctions))
}

#' @noRd
seq_cols <- function(x) {
seq_len(NCOL(x))
}

#' compute the mth eigen function of an approximate GP
#' Credit to Paul Burkner from brms: https://github.com/paul-buerkner/brms/R/formula-gp.R#L289
#' @noRd
eigen_fun_cov_exp_quad <- function(x, m, L) {
x <- as.matrix(x)
D <- ncol(x)
stopifnot(length(m) == D, length(L) == D)
out <- vector("list", D)
for (i in seq_cols(x)) {
out[[i]] <- 1 / sqrt(L[i]) *
sin((m[i] * pi) / (2 * L[i]) * (x[, i] + L[i]))
}
Reduce("*", out)
}

#' compute squared differences
#' Credit to Paul Burkner from brms: https://github.com/paul-buerkner/brms/R/formula-gp.R#L241
#' @param x vector or matrix
#' @param x_new optional vector of matrix with the same ncol as x
#' @return an nrow(x) times nrow(x_new) matrix
#' @details if matrices are passed results are summed over the columns
#' @noRd
diff_quad <- function(x, x_new = NULL) {
x <- as.matrix(x)
if (is.null(x_new)) {
x_new <- x
} else {
x_new <- as.matrix(x_new)
}
.diff_quad <- function(x1, x2) (x1 - x2)^2
out <- 0
for (i in seq_cols(x)) {
out <- out + outer(x[, i], x_new[, i], .diff_quad)
}
out
}

#' extended range of input data for which predictions should be made
#' Credit to Paul Burkner from brms: https://github.com/paul-buerkner/brms/R/formula-gp.R#L301
#' @noRd
choose_L <- function(x, c) {
if (!length(x)) {
range <- 1
} else {
range <- max(1, max(x, na.rm = TRUE) - min(x, na.rm = TRUE))
}
c * range
}

#' Mean-center and scale the particular covariate of interest
#' so that the maximum Euclidean distance between any two points is 1
#' @noRd
Expand All @@ -330,7 +384,7 @@ scale_cov <- function(data, covariate, by, level,

# Compute max Euclidean distance if not supplied
if(is.na(max_dist)){
Xgp_max_dist <- sqrt(max(brms:::diff_quad(Xgp)))
Xgp_max_dist <- sqrt(max(diff_quad(Xgp)))
} else {
Xgp_max_dist <- max_dist
}
Expand Down Expand Up @@ -381,13 +435,13 @@ prep_eigenfunctions = function(data,
eigenfunctions <- matrix(NA, nrow = length(covariate_cent),
ncol = k)
if(missing(L)){
L <- brms:::choose_L(covariate_cent, boundary)
L <- choose_L(covariate_cent, boundary)
}

for(m in 1:k){
eigenfunctions[, m] <- brms:::eigen_fun_cov_exp_quad(x = matrix(covariate_cent),
m = m,
L = L)
eigenfunctions[, m] <- eigen_fun_cov_exp_quad(x = matrix(covariate_cent),
m = m,
L = L)
}

# Multiply eigenfunctions by the 'by' variable if one is supplied
Expand Down Expand Up @@ -473,7 +527,7 @@ prep_gp_covariate = function(data,
# same eigenvalues are always used in prediction, so we only need to
# create them when prepping the data. They will need to be included in
# the Stan data list
L <- brms:::choose_L(covariate_cent, boundary)
L <- choose_L(covariate_cent, boundary)
eigenvalues <- vector()
for(m in 1:k){
eigenvalues[m] <- sqrt(lambda(boundary = L,
Expand Down
10 changes: 9 additions & 1 deletion R/stan_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -3224,12 +3224,20 @@ repair_stanfit <- function(x) {
read_csv_as_stanfit <- function(files, variables = NULL,
sampler_diagnostics = NULL) {

# Code borrowed from brms: https://github.com/paul-buerkner/brms/R/backends.R#L603
repair_names <- function(x) {
x <- sub("\\.", "[", x)
x <- gsub("\\.", ",", x)
x[grep("\\[", x)] <- paste0(x[grep("\\[", x)], "]")
x
}

if(!is.null(variables)){
# ensure that only relevant variables are read from CSV
metadata <- cmdstanr::read_cmdstan_csv(
files = files, variables = "", sampler_diagnostics = "")

all_vars <- brms:::repair_variable_names(metadata$metadata$variables)
all_vars <- repair_names(metadata$metadata$variables)
all_vars <- unique(sub("\\[.+", "", all_vars))
variables <- variables[variables %in% all_vars]
}
Expand Down
Binary file modified R/sysdata.rda
Binary file not shown.
4 changes: 2 additions & 2 deletions man/conditional_effects.mvgam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified src/mvgam.dll
Binary file not shown.
18 changes: 9 additions & 9 deletions tests/mvgam_examples.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,44 @@
library(mvgam)
set.seed(1234)
mvgam_examp_dat <- sim_mvgam(family = gaussian(),
T = 50)
T = 40)

# Univariate process without trend_formula
mvgam_example1 <- mvgam(y ~ s(season, k = 7),
mvgam_example1 <- mvgam(y ~ s(season, k = 5),
trend_model = 'RW',
family = gaussian(),
data = mvgam_examp_dat$data_train,
burnin = 300,
samples = 40,
samples = 30,
chains = 1)

# Univariate process with trend_formula and correlated process errors
mvgam_example2 <- mvgam(y ~ 1,
trend_formula = ~ s(season, k = 7),
trend_formula = ~ s(season, k = 5),
trend_model = RW(cor = TRUE),
family = gaussian(),
data = mvgam_examp_dat$data_train,
burnin = 300,
samples = 40,
samples = 30,
chains = 1)

# Multivariate process without trend_formula
mvgam_example3 <- mvgam(y ~ s(season, k = 7),
mvgam_example3 <- mvgam(y ~ s(season, k = 5),
trend_model = 'VAR1cor',
family = gaussian(),
data = mvgam_examp_dat$data_train,
burnin = 300,
samples = 40,
samples = 30,
chains = 1)

# Multivariate process with trend_formula and moving averages
mvgam_example4 <- mvgam(y ~ 1,
trend_formula = ~ s(season, k = 7),
trend_formula = ~ s(season, k = 5),
trend_model = VAR(ma = TRUE, cor = TRUE),
family = gaussian(),
data = mvgam_examp_dat$data_train,
burnin = 300,
samples = 40,
samples = 30,
chains = 1)

# Save examples as internal data
Expand Down
Binary file modified tests/testthat/Rplots.pdf
Binary file not shown.

0 comments on commit 798eaa4

Please sign in to comment.