diff --git a/NAMESPACE b/NAMESPACE
index cf0c0208..648f5651 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -44,6 +44,7 @@ S3method(plot,mvgam_lfo)
S3method(posterior_epred,mvgam)
S3method(posterior_linpred,mvgam)
S3method(posterior_predict,mvgam)
+S3method(pp_check,mvgam)
S3method(ppc,mvgam)
S3method(predict,mvgam)
S3method(print,mvgam)
@@ -95,6 +96,7 @@ export(plot_mvgam_series)
export(plot_mvgam_smooth)
export(plot_mvgam_trend)
export(plot_mvgam_uncertainty)
+export(pp_check)
export(ppc)
export(roll_eval_mvgam)
export(score)
@@ -109,6 +111,7 @@ importFrom(bayesplot,color_scheme_set)
importFrom(bayesplot,log_posterior)
importFrom(bayesplot,neff_ratio)
importFrom(bayesplot,nuts_params)
+importFrom(bayesplot,pp_check)
importFrom(brms,bernoulli)
importFrom(brms,conditional_effects)
importFrom(brms,dbeta_binomial)
diff --git a/R/RW.R b/R/RW.R
index c6a54ce0..406a4f09 100644
--- a/R/RW.R
+++ b/R/RW.R
@@ -16,6 +16,66 @@
#' @return An object of class \code{mvgam_trend}, which contains a list of
#' arguments to be interpreted by the parsing functions in \code{mvgam}
#' @rdname RW
+#'@examples
+#'\dontrun{
+#'# A short example to illustrate CAR(1) models
+#'# Function to simulate CAR1 data with seasonality
+#'sim_corcar1 = function(n = 120,
+#' phi = 0.5,
+#' sigma = 1,
+#' sigma_obs = 0.75){
+#'# Sample irregularly spaced time intervals
+#'time_dis <- c(0, runif(n - 1, -0.1, 1))
+#'time_dis[time_dis < 0] <- 0; time_dis <- time_dis * 5
+#'
+#'# Set up the latent dynamic process
+#'x <- vector(length = n); x[1] <- -0.3
+#'for(i in 2:n){
+#' # zero-distances will cause problems in sampling, so mvgam uses a
+#' # minimum threshold; this simulation function emulates that process
+#' if(time_dis[i] == 0){
+#' x[i] <- rnorm(1, mean = (phi ^ 1e-12) * x[i - 1], sd = sigma)
+#' } else {
+#' x[i] <- rnorm(1, mean = (phi ^ time_dis[i]) * x[i - 1], sd = sigma)
+#' }
+#' }
+#'
+#'# Add 12-month seasonality
+#' cov1 <- sin(2 * pi * (1 : n) / 12); cov2 <- cos(2 * pi * (1 : n) / 12)
+#' beta1 <- runif(1, 0.3, 0.7); beta2 <- runif(1, 0.2, 0.5)
+#' seasonality <- beta1 * cov1 + beta2 * cov2
+#'
+#'# Take Gaussian observations with error and return
+#' data.frame(y = rnorm(n, mean = x + seasonality, sd = sigma),
+#' season = rep(1:12, 20)[1:n],
+#' time = cumsum(time_dis))
+#'}
+#'
+#'# Sample two time series
+#'dat <- rbind(dplyr::bind_cols(sim_corcar1(phi = 0.65,
+#' sigma_obs = 0.55),
+#' data.frame(series = 'series1')),
+#' dplyr::bind_cols(sim_corcar1(phi = 0.8,
+#' sigma_obs = 0.35),
+#' data.frame(series = 'series2'))) %>%
+#' dplyr::mutate(series = as.factor(series))
+#'
+#'# mvgam with CAR(1) trends and series-level seasonal smooths
+#'mod <- mvgam(formula = y ~ s(season, bs = 'cc',
+#' k = 5, by = series),
+#' trend_model = CAR(),
+#' data = dat,
+#' family = gaussian(),
+#' run_model = TRUE)
+#'
+#'# View usual summaries and plots
+#'summary(mod)
+#'conditional_effects(mod, type = 'expected')
+#'plot(mod, type = 'trend', series = 1)
+#'plot(mod, type = 'trend', series = 2)
+#'plot(mod, type = 'residuals', series = 1)
+#'plot(mod, type = 'residuals', series = 2)
+#'}
#' @export
RW = function(ma = FALSE, cor = FALSE){
out <- structure(list(trend_model = 'RW',
diff --git a/R/compute_edf.R b/R/compute_edf.R
index ea46d0c1..939a19f3 100644
--- a/R/compute_edf.R
+++ b/R/compute_edf.R
@@ -80,7 +80,8 @@ compute_edf = function(mgcv_model, object, rho_names, sigma_raw_names,
# Calculate variance using family's mean-variance relationship
mu_variance <- predict(object,
process_error = FALSE,
- type = 'variance')[best_draw, ]
+ type = 'variance',
+ summary = FALSE)[best_draw, ]
if(length(mu_variance) > 1){
mu_variance <- mu_variance[1:length(eta)]
}
diff --git a/R/families.R b/R/families.R
index 81d4e196..e72b7caf 100644
--- a/R/families.R
+++ b/R/families.R
@@ -1631,7 +1631,7 @@ dsresids_vec = function(object){
# Need to know which series each observation belongs to so we can
# pull out appropriate family-level parameters (overdispersions, shapes, etc...)
all_dat <- data.frame(series = object$obs_data$series,
- time = object$obs_data$time,
+ time = object$obs_data$index..time..index,
y = object$obs_data$y) %>%
dplyr::arrange(time, series)
diff --git a/R/forecast.mvgam.R b/R/forecast.mvgam.R
index ec5d95f4..7ac54315 100644
--- a/R/forecast.mvgam.R
+++ b/R/forecast.mvgam.R
@@ -56,7 +56,9 @@ forecast <- function(object, ...){
#'
#' }
#'@export
-forecast.mvgam = function(object, newdata, data_test,
+forecast.mvgam = function(object,
+ newdata,
+ data_test,
n_cores = 1,
type = 'response',
...){
@@ -112,8 +114,34 @@ forecast.mvgam = function(object, newdata, data_test,
}
if(is.null(object$test_data)){
- data_test %>%
- dplyr::filter(time > max(object$obs_data$time)) -> data_test
+ data_test <- validate_series_time(data_test, name = 'newdata',
+ trend_model = attr(object$model_data, 'trend_model'))
+ data.frame(series = object$obs_data$series,
+ time = object$obs_data$time) %>%
+ dplyr::group_by(series) %>%
+ dplyr::summarise(maxt = max(time)) -> series_max_ts
+
+ data.frame(series = data_test$series,
+ time = data_test$time) %>%
+ dplyr::mutate(orig_rows = dplyr::row_number()) %>%
+ dplyr::left_join(series_max_ts, by = 'series') %>%
+ dplyr::filter(time > maxt) %>%
+ dplyr::pull(orig_rows) -> idx
+
+ if(inherits(data_test, 'list')){
+ data_arranged <- data_test
+ data_arranged <- lapply(data_test, function(x){
+ if(is.matrix(x)){
+ matrix(x[idx,], ncol = NCOL(x))
+ } else {
+ x[idx]
+ }
+ })
+ names(data_arranged) <- names(data_test)
+ data_test <- data_arranged
+ } else {
+ data_test <- data_test[idx, ]
+ }
}
# Only compute forecasts if they don't already exist!
@@ -269,7 +297,7 @@ forecast.mvgam = function(object, newdata, data_test,
series_obs <- lapply(seq_len(n_series), function(series){
s_name <- levels(object$obs_data$series)[series]
data.frame(series = object$obs_data$series,
- time = object$obs_data$time,
+ time = object$obs_data$index..time..index,
y = object$obs_data$y) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time) %>%
@@ -280,7 +308,7 @@ forecast.mvgam = function(object, newdata, data_test,
series_test <- lapply(seq_len(n_series), function(series){
s_name <- levels(object$obs_data$series)[series]
data.frame(series = data_test$series,
- time = data_test$time,
+ time = data_test$index..time..index,
y = data_test$y) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time) %>%
@@ -372,7 +400,7 @@ forecast.mvgam = function(object, newdata, data_test,
names(series_hcs) <- s_name
series_obs <- list(data.frame(series = object$obs_data$series,
- time = object$obs_data$time,
+ time = object$obs_data$index..time..index,
y = object$obs_data$y) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time) %>%
@@ -380,7 +408,7 @@ forecast.mvgam = function(object, newdata, data_test,
names(series_obs) <- s_name
series_test <- list(data.frame(series = data_test$series,
- time = data_test$time,
+ time = data_test$index..time..index,
y = data_test$y) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time) %>%
@@ -391,8 +419,8 @@ forecast.mvgam = function(object, newdata, data_test,
} else {
# If forecasts already exist, simply extract them
data_test <- object$test_data
- last_train <- max(object$obs_data$time) -
- (min(object$obs_data$time) - 1)
+ last_train <- max(object$obs_data$index..time..index) -
+ (min(object$obs_data$index..time..index) - 1)
if(series == 'all'){
data_train <- object$obs_data
@@ -595,7 +623,7 @@ forecast.mvgam = function(object, newdata, data_test,
series_obs <- lapply(seq_len(n_series), function(series){
s_name <- levels(object$obs_data$series)[series]
data.frame(series = object$obs_data$series,
- time = object$obs_data$time,
+ time = object$obs_data$index..time..index,
y = object$obs_data$y) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time) %>%
@@ -606,7 +634,7 @@ forecast.mvgam = function(object, newdata, data_test,
series_test <- lapply(seq_len(n_series), function(series){
s_name <- levels(object$obs_data$series)[series]
data.frame(series = object$test_data$series,
- time = object$test_data$time,
+ time = object$test_data$index..time..index,
y = object$test_data$y) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time) %>%
@@ -709,7 +737,7 @@ forecast.mvgam = function(object, newdata, data_test,
# Training observations
series_obs <- list(data.frame(series = object$obs_data$series,
- time = object$obs_data$time,
+ time = object$obs_data$index..time..index,
y = object$obs_data$y) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time) %>%
@@ -718,7 +746,7 @@ forecast.mvgam = function(object, newdata, data_test,
# Testing observations
series_test <- list(data.frame(series = object$test_data$series,
- time = object$test_data$time,
+ time = object$test_data$index..time..index,
y = object$test_data$y) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time) %>%
@@ -743,9 +771,9 @@ forecast.mvgam = function(object, newdata, data_test,
series_names = factor(unique(data_train$series),
levels = levels(data_train$series)),
train_observations = series_obs,
- train_times = unique(data_train$time),
+ train_times = unique(data_train$index..time..index),
test_observations = series_test,
- test_times = unique(data_test$time),
+ test_times = unique(data_test$index..time..index),
hindcasts = series_hcs,
forecasts = series_fcs),
class = 'mvgam_forecast')
@@ -767,7 +795,8 @@ forecast_draws = function(object,
# Check arguments
validate_pos_integer(n_cores)
- data_test <- validate_series_time(data_test, name = 'newdata')
+ data_test <- validate_series_time(data_test, name = 'newdata',
+ trend_model = attr(object$model_data, 'trend_model'))
n_series <- NCOL(object$ytimes)
use_lv <- object$use_lv
@@ -784,14 +813,14 @@ forecast_draws = function(object,
if(series != 'all'){
obs_keep <- data.frame(y = data_test$y,
series = data_test$series,
- time = data_test$time,
+ time = data_test$index..time..index,
rowid = 1:length(data_test$y)) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time) %>%
dplyr::pull(rowid)
series_test <- data.frame(y = data_test$y,
series = data_test$series,
- time = data_test$time,
+ time = data_test$index..time..index,
rowid = 1:length(data_test$y)) %>%
dplyr::filter(series == s_name) %>%
dplyr::arrange(time)
@@ -803,7 +832,7 @@ forecast_draws = function(object,
if(series != 'all'){
series_test <- data_test %>%
dplyr::filter(series == s_name) %>%
- dplyr::arrange(time)
+ dplyr::arrange(index..time..index)
Xp <- obs_Xp_matrix(newdata = series_test,
mgcv_model = object$mgcv_model)
} else {
@@ -834,9 +863,9 @@ forecast_draws = function(object,
# Ensure the last three values are used, in case the obs_data
# was not supplied in order
- data.frame(time = object$obs_data$time,
+ data.frame(time = object$obs_data$index..time..index,
series = object$obs_data$series,
- row_id = 1:length(object$obs_data$time)) %>%
+ row_id = 1:length(object$obs_data$index..time..index)) %>%
dplyr::arrange(time, series) %>%
dplyr::pull(row_id) -> sorted_inds
@@ -897,7 +926,7 @@ forecast_draws = function(object,
if(series != 'all'){
fc_horizon <- NROW(series_test)
} else {
- fc_horizon <- length(unique(data_test$time))
+ fc_horizon <- length(unique(data_test$index..time..index))
}
# Beta coefficients for GAM observation component
@@ -935,7 +964,7 @@ forecast_draws = function(object,
resp_terms <- resp_terms[-grepl('cbind', resp_terms)]
trial_name <- resp_terms[2]
trial_df <- data.frame(series = data_test$series,
- time = data_test$time,
+ time = data_test$index..time..index,
trial = data_test[[trial_name]])
trials <- matrix(NA, nrow = fc_horizon, ncol = n_series)
for(i in 1:n_series){
@@ -951,6 +980,18 @@ forecast_draws = function(object,
# Trend model
trend_model <- attr(object$model_data, 'trend_model')
+ # Calculate time_dis if this is a CAR1 model
+ if(trend_model == 'CAR1'){
+ data_test$index..time..index <- data_test$index..time..index +
+ max(object$obs_data$index..time..index)
+ time_dis <- add_corcar(model_data = list(),
+ data_train = object$obs_data,
+ data_test = data_test)[[1]]
+ time_dis <- time_dis[-c(1:max(object$obs_data$index..time..index)),]
+ } else {
+ time_dis <- NULL
+ }
+
# Trend-specific parameters
if(missing(ending_time)){
trend_pars <- extract_trend_pars(object = object,
@@ -964,13 +1005,19 @@ forecast_draws = function(object,
# Any model in which an autoregressive process was included should be
# considered as VAR1 for forecasting purposes as this will make use of the
# faster c++ functions
- if('Sigma' %in% names(trend_pars) |
- 'sigma' %in% names(trend_pars) |
- 'tau' %in% names(trend_pars)){
- trend_model <- 'VAR1'
+ if(trend_model == 'CAR1'){
if(!'last_lvs' %in% names(trend_pars)){
trend_pars$last_lvs <- trend_pars$last_trends
}
+ } else {
+ if('Sigma' %in% names(trend_pars) |
+ 'sigma' %in% names(trend_pars) |
+ 'tau' %in% names(trend_pars)){
+ trend_model <- 'VAR1'
+ if(!'last_lvs' %in% names(trend_pars)){
+ trend_pars$last_lvs <- trend_pars$last_trends
+ }
+ }
}
# Set up parallel environment for looping across posterior draws
@@ -995,7 +1042,8 @@ forecast_draws = function(object,
'fc_horizon',
'b_uncertainty',
'trend_uncertainty',
- 'obs_uncertainty'),
+ 'obs_uncertainty',
+ 'time_dis'),
envir = environment())
parallel::clusterExport(cl = cl,
unclass(lsf.str(envir = asNamespace("mvgam"),
@@ -1035,7 +1083,7 @@ forecast_draws = function(object,
samp_index = 1)
}
- if(use_lv || trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic')){
+ if(use_lv || trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic', 'CAR1')){
if(trend_model == 'PWlogistic'){
if(!(exists('cap', where = data_test))) {
stop('Capacities must also be supplied in "newdata" for logistic growth predictions',
@@ -1046,7 +1094,7 @@ forecast_draws = function(object,
family_links <- Gamma(link = 'log')
}
cap <- data.frame(series = data_test$series,
- time = data_test$time,
+ time = data_test$index..time..index,
cap = suppressWarnings(linkfun(data_test$cap,
link = family_links$link)))
@@ -1068,8 +1116,10 @@ forecast_draws = function(object,
h = fc_horizon,
betas_trend = betas_trend,
Xp_trend = Xp_trend,
- time = unique(data_test$time - min(object$obs_data$time) + 1),
- cap = cap)
+ time = unique(data_test$index..time..index -
+ min(object$obs_data$index..time..index) + 1),
+ cap = cap,
+ time_dis = time_dis)
}
# Loop across series and produce the next trend estimate
@@ -1081,11 +1131,11 @@ forecast_draws = function(object,
trend_pars = trend_pars,
use_lv = use_lv)
- if(use_lv || trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic')){
+ if(use_lv || trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic', 'CAR1')){
if(use_lv){
# Multiply lv states with loadings to generate the series' forecast trend state
out <- as.numeric(trends %*% trend_extracts$lv_coefs)
- } else if(trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic')){
+ } else if(trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic', 'CAR1')){
out <- trends[,series]
}
@@ -1097,7 +1147,8 @@ forecast_draws = function(object,
h = fc_horizon,
betas_trend = betas_trend,
Xp_trend = Xp_trend,
- time = sort(unique(data_test$time)))
+ time = sort(unique(data_test$index..time..index)),
+ time_dis = NULL)
}
out
})
@@ -1176,7 +1227,7 @@ forecast_draws = function(object,
h = fc_horizon,
betas_trend = betas_trend,
Xp_trend = Xp_trend,
- time = sort(unique(series_test$time)))
+ time = sort(unique(series_test$index..time..index)))
if(use_lv){
# Multiply lv states with loadings to generate the series' forecast trend state
diff --git a/R/get_linear_predictors.R b/R/get_linear_predictors.R
index ba7e664a..06dd9895 100644
--- a/R/get_linear_predictors.R
+++ b/R/get_linear_predictors.R
@@ -81,7 +81,7 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',
mgcv_model){
trend_test <- newdata
- trend_indicators <- vector(length = length(trend_test$time))
+ trend_indicators <- vector(length = length(trend_test[[1]]))
for(i in 1:length(trend_test$series)){
trend_indicators[i] <- trend_map$trend[which(as.character(trend_map$series) ==
as.character(trend_test$series[i]))]
diff --git a/R/get_mvgam_priors.R b/R/get_mvgam_priors.R
index b05b1315..aaec611d 100644
--- a/R/get_mvgam_priors.R
+++ b/R/get_mvgam_priors.R
@@ -256,6 +256,7 @@ get_mvgam_priors = function(formula,
}
trend_train <- data_train
+ trend_train$time <- trend_train$index..time..index
trend_train$trend_y <- rnorm(length(trend_train$time))
# Add indicators of trend names as factor levels using the trend_map
diff --git a/R/marginaleffects.mvgam.R b/R/marginaleffects.mvgam.R
index 20587d40..6c0481b9 100644
--- a/R/marginaleffects.mvgam.R
+++ b/R/marginaleffects.mvgam.R
@@ -80,6 +80,7 @@ get_predict.mvgam <- function(model, newdata,
newdata = newdata,
type = type,
process_error = process_error,
+ summary = FALSE,
...)
out <- data.frame(
rowid = seq_len(NCOL(preds)),
@@ -113,7 +114,7 @@ get_data.mvgam = function (x, source = "environment", verbose = TRUE, ...) {
# Original series, time and outcomes
orig_dat <- data.frame(series = x$obs_data$series,
- time = x$obs_data$time,
+ time = x$obs_data$index..time..index,
y = x$obs_data$y)
# Add indicators of trend names as factor levels using the trend_map
@@ -129,7 +130,7 @@ get_data.mvgam = function (x, source = "environment", verbose = TRUE, ...) {
series = orig_dat$series,
time = orig_dat$time,
y = orig_dat$y,
- row_num = 1:length(x$obs_data$time))
+ row_num = 1:length(x$obs_data$index..time..index))
# # We only kept one time observation per trend
trend_level_data %>%
@@ -208,7 +209,7 @@ get_data.mvgam_prefit = function (x, source = "environment", verbose = TRUE, ...
# Original series, time and outcomes
orig_dat <- data.frame(series = x$obs_data$series,
- time = x$obs_data$time,
+ time = x$obs_data$index..time..index,
y = x$obs_data$y)
# Add indicators of trend names as factor levels using the trend_map
@@ -224,7 +225,7 @@ get_data.mvgam_prefit = function (x, source = "environment", verbose = TRUE, ...
series = orig_dat$series,
time = orig_dat$time,
y = orig_dat$y,
- row_num = 1:length(x$obs_data$time))
+ row_num = 1:length(x$obs_data$index..time..index))
# # We only kept one time observation per trend
trend_level_data %>%
diff --git a/R/model.frame.mvgam.R b/R/model.frame.mvgam.R
index ffa8fa23..e030e668 100644
--- a/R/model.frame.mvgam.R
+++ b/R/model.frame.mvgam.R
@@ -36,7 +36,9 @@ model.frame.mvgam = function(formula, trend_effects = FALSE, ...){
out[,resp] <- formula$obs_data$y
# Ensure 'cap' is included if this is an N-mixture model
- out$cap <- formula$obs_data$cap
+ if(attr(formula$model_data, 'trend_model') == 'nmix'){
+ out$cap <- formula$obs_data$cap
+ }
}
return(out)
}
@@ -72,7 +74,9 @@ model.frame.mvgam_prefit = function(formula, trend_effects = FALSE, ...){
}
# Ensure 'cap' is included if this is an N-mixture model
- out$cap <- formula$obs_data$cap
+ if(attr(formula$model_data, 'trend_model') == 'nmix'){
+ out$cap <- formula$obs_data$cap
+ }
return(out)
}
diff --git a/R/mvgam.R b/R/mvgam.R
index 61b5b710..5889bf14 100644
--- a/R/mvgam.R
+++ b/R/mvgam.R
@@ -42,7 +42,9 @@
#' \item`time` (\code{numeric} or \code{integer} index of the time point for each observation).
#' For most dynamic trend types available in `mvgam` (see argument `trend_model`), time should be
#' measured in discrete, regularly spaced intervals (i.e. `c(1, 2, 3, ...)`). However you can
-#' use irregularly spaced intervals if using `trend_model = CAR(1)`
+#' use irregularly spaced intervals if using `trend_model = CAR(1)`, though note that any
+#'temporal intervals that are exactly `0` will be adjusted to a very small number
+#'(`1e-12`) to prevent sampling errors. See an example of `CAR()` trends in \code{\link{CAR}}
#' }
#'Should also include any other variables to be included in the linear predictor of \code{formula}
#'@param data_train Deprecated. Still works in place of \code{data} but users are recommended to use
@@ -334,7 +336,7 @@
#' model_data <- mod1$model_data
#' library(rstan)
#' fit <- stan(model_code = mod1$model_file,
-#' data = model_data)
+#' data = model_data)
#'
#' # Now using cmdstanr
#' library(cmdstanr)
@@ -384,9 +386,14 @@
#' plot(mod1, type = 'smooths', realisations = TRUE)
#'
#' # Plot conditional response predictions using marginaleffects
-#' plot(conditional_effects(mod1), ask = FALSE)
+#' conditional_effects(mod1)
#' plot_predictions(mod1, condition = 'season', points = 0.5)
#'
+#' # Generate posterior predictive checks through bayesplot
+#' pp_check(mod1)
+#' pp_check(mod, type = "bars_grouped",
+#' group = "series", ndraws = 50)
+#'
#' # Extract observation model beta coefficient draws as a data.frame
#' beta_draws_df <- as.data.frame(mod1, variable = 'betas')
#' head(beta_draws_df)
@@ -1447,7 +1454,7 @@ mvgam = function(formula,
vectorised$model_file <- trend_map_setup$model_file
vectorised$model_data <- trend_map_setup$model_data
- if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3')){
+ if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3', 'CAR1')){
param <- c(param, 'sigma')
}
diff --git a/R/plot_mvgam_resids.R b/R/plot_mvgam_resids.R
index e5213ea4..4cf95eca 100644
--- a/R/plot_mvgam_resids.R
+++ b/R/plot_mvgam_resids.R
@@ -67,7 +67,7 @@ series_residuals <- object$resids[[series]]
# Get indices of training horizon
if(class(data_train)[1] == 'list'){
- data_train_df <- data.frame(time = data_train$time,
+ data_train_df <- data.frame(time = data_train$index..time..index,
y = data_train$y,
series = data_train$series)
obs_length <- length(data_train_df %>%
@@ -79,9 +79,9 @@ if(class(data_train)[1] == 'list'){
} else {
obs_length <- length(data_train %>%
dplyr::filter(series == !!(levels(data_train$series)[series])) %>%
- dplyr::select(time, y) %>%
+ dplyr::select(index..time..index, y) %>%
dplyr::distinct() %>%
- dplyr::arrange(time) %>%
+ dplyr::arrange(index..time..index) %>%
dplyr::pull(y))
}
@@ -143,7 +143,7 @@ if(missing(data_test)){
if(class(data_train)[1] == 'list'){
all_obs <- c(data.frame(y = data_train$y,
series = data_train$series,
- time = data_train$time) %>%
+ time = data_train$index..time..index) %>%
dplyr::filter(series == s_name) %>%
dplyr::select(time, y) %>%
dplyr::distinct() %>%
@@ -160,9 +160,9 @@ if(missing(data_test)){
} else {
all_obs <- c(data_train %>%
dplyr::filter(series == s_name) %>%
- dplyr::select(time, y) %>%
+ dplyr::select(index..time..index, y) %>%
dplyr::distinct() %>%
- dplyr::arrange(time) %>%
+ dplyr::arrange(index..time..index) %>%
dplyr::pull(y),
data_test %>%
dplyr::filter(series == s_name) %>%
diff --git a/R/posterior_epred.mvgam.R b/R/posterior_epred.mvgam.R
index 8101258f..047cbbb6 100644
--- a/R/posterior_epred.mvgam.R
+++ b/R/posterior_epred.mvgam.R
@@ -12,6 +12,8 @@
#' averaged across draws should be very similar.
#' @importFrom rstantools posterior_epred
#' @inheritParams predict.mvgam
+#' @param ndraws Positive `integer` indicating how many posterior draws should be used.
+#' If `NULL` (the default) all draws are used.
#' @param process_error Logical. If \code{TRUE} and \code{newdata} is supplied,
#' expected uncertainty in the process model is accounted for by using draws
#' from any latent trend SD parameters. If \code{FALSE}, uncertainty in the latent
@@ -49,19 +51,32 @@
#'str(expectations)
#'}
#' @export
-posterior_epred.mvgam = function(object, newdata,
+posterior_epred.mvgam = function(object,
+ newdata,
data_test,
+ ndraws = NULL,
process_error = TRUE, ...){
if(missing(newdata) & missing(data_test)){
- .mvgam_fitted(object, type = 'expected')
+ out <- .mvgam_fitted(object, type = 'expected')
} else {
- predict(object,
- newdata = newdata,
- data_test = data_test,
- process_error = process_error,
- type = 'expected')
+ out <- predict(object,
+ newdata = newdata,
+ data_test = data_test,
+ process_error = process_error,
+ type = 'expected',
+ summary = FALSE)
}
+
+ if(!is.null(ndraws)){
+ validate_pos_integer(ndraws)
+ if(ndraws > NROW(out)){
+ } else {
+ idx <- sample(1:NROW(out), ndraws, replace = FALSE)
+ out <- out[idx, ]
+ }
+ }
+ return(out)
}
#' Posterior Draws of the Linear Predictor
@@ -110,6 +125,7 @@ posterior_epred.mvgam = function(object, newdata,
posterior_linpred.mvgam = function(object,
transform = FALSE,
newdata,
+ ndraws = NULL,
data_test,
process_error = TRUE,
...){
@@ -120,11 +136,22 @@ posterior_linpred.mvgam = function(object,
type <- 'link'
}
- predict(object,
- newdata = newdata,
- data_test = data_test,
- process_error = process_error,
- type = type)
+ out <- predict(object,
+ newdata = newdata,
+ data_test = data_test,
+ process_error = process_error,
+ type = type,
+ summary = FALSE)
+
+ if(!is.null(ndraws)){
+ validate_pos_integer(ndraws)
+ if(ndraws > NROW(out)){
+ } else {
+ idx <- sample(1:NROW(out), ndraws, replace = FALSE)
+ out <- out[idx, ]
+ }
+ }
+ return(out)
}
#' Draws from the Posterior Predictive Distribution
@@ -138,6 +165,7 @@ posterior_linpred.mvgam = function(object,
#' both methods averaged across draws should be very similar.
#' @importFrom rstantools posterior_predict
#' @inheritParams predict.mvgam
+#' @inheritParams posterior_epred.mvgam
#' @param process_error Logical. If \code{TRUE} and \code{newdata} is supplied,
#' expected uncertainty in the process model is accounted for by using draws
#' from any latent trend SD parameters. If \code{FALSE}, uncertainty in the latent
@@ -174,15 +202,28 @@ posterior_linpred.mvgam = function(object,
#'str(predictions)
#'}
#' @export
-posterior_predict.mvgam = function(object, newdata,
+posterior_predict.mvgam = function(object,
+ newdata,
data_test,
+ ndraws = NULL,
process_error = TRUE, ...){
- predict(object,
- newdata = newdata,
- data_test = data_test,
- process_error = process_error,
- type = 'response')
+ out <- predict(object,
+ newdata = newdata,
+ data_test = data_test,
+ process_error = process_error,
+ type = 'response',
+ summary = FALSE)
+
+ if(!is.null(ndraws)){
+ validate_pos_integer(ndraws)
+ if(ndraws > NROW(out)){
+ } else {
+ idx <- sample(1:NROW(out), ndraws, replace = FALSE)
+ out <- out[idx, ]
+ }
+ }
+ return(out)
}
#' Expected Values of the Posterior Predictive Distribution
diff --git a/R/ppc.mvgam.R b/R/ppc.mvgam.R
index 6543dbac..b65982f2 100644
--- a/R/ppc.mvgam.R
+++ b/R/ppc.mvgam.R
@@ -743,3 +743,223 @@ ppc.mvgam = function(object, newdata, data_test, series = 1, type = 'hist',
box(bty = 'L', lwd = 2)
}
}
+
+#' Posterior Predictive Checks for \code{mvgam} Objects
+#'
+#' Perform posterior predictive checks with the help
+#' of the \pkg{bayesplot} package.
+#'
+#' @aliases pp_check
+#' @inheritParams brms::pp_check
+#' @param object An object of class \code{mvgam}.
+#' @param newdata Optional \code{dataframe} or \code{list} of test data containing the
+#' variables included in the linear predictor of \code{formula}. If not supplied,
+#' predictions are generated for the original observations used for the model fit.
+#' @param ... Further arguments passed to \code{\link{predict.mvgam}}
+#' as well as to the PPC function specified in \code{type}.
+#' @inheritParams prepare_predictions.brmsfit
+#'
+#' @return A ggplot object that can be further
+#' customized using the \pkg{ggplot2} package.
+#'
+#' @details For a detailed explanation of each of the ppc functions,
+#' see the \code{\link[bayesplot:PPC-overview]{PPC}}
+#' documentation of the \pkg{\link[bayesplot:bayesplot-package]{bayesplot}}
+#' package.
+#'
+#' @examples
+#' \dontrun{
+#'simdat <- sim_mvgam(seasonality = 'hierarchical')
+#'mod <- mvgam(y ~ series +
+#' s(season, bs = 'cc', k = 6) +
+#' s(season, series, bs = 'sz', k = 4),
+#' data = simdat$data_train)
+#'
+#'# Get a list of available plot types
+#'pp_check(mod, type = "xyz")
+#'
+#'# Default is a density overlay for all observations
+#'pp_check(mod)
+#'
+#'# Rootograms particularly useful for count data
+#'pp_check(mod, type = "rootogram")
+#'
+#'# Grouping plots by series is useful
+#'pp_check(mod, type = "bars_grouped",
+#' group = "series", ndraws = 50)
+#'pp_check(mod, type = "ecdf_overlay_grouped",
+#' group = "series", ndraws = 50)
+#'pp_check(mod, type = "stat_freqpoly_grouped",
+#' group = "series", ndraws = 50)
+#'
+#'# Custom functions accepted
+#'prop_zero <- function(x) mean(x == 0)
+#'pp_check(mod, type = "stat", stat = "prop_zero")
+#'pp_check(mod, type = "stat_grouped",
+#' stat = "prop_zero",
+#' group = "series")
+#'
+#'# Some functions accept covariates to set the x-axes
+#'pp_check(mod, x = "season",
+#' type = "ribbon_grouped",
+#' prob = 0.5,
+#' prob_outer = 0.8,
+#' group = "series")
+#'
+#'# Many plots can be made without the observed data
+#'pp_check(mod, prefix = "ppd")
+#' }
+#'
+#' @importFrom bayesplot pp_check
+#' @export pp_check
+#' @export
+pp_check.mvgam <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd"),
+ group = NULL, x = NULL, newdata = NULL, ...) {
+
+ # Set red colour scheme
+ col_scheme <- attr(color_scheme_get(),
+ 'scheme_name')
+ color_scheme_set('red')
+
+ dots <- list(...)
+ if (missing(type)) {
+ type <- "dens_overlay"
+ }
+
+ prefix <- match.arg(prefix)
+ ndraws_given <- "ndraws" %in% names(match.call())
+
+ if(is.null(newdata)){
+ newdata <- object$obs_data
+ }
+
+ if (prefix == "ppc") {
+ # No type checking for prefix 'ppd' yet
+ valid_types <- as.character(bayesplot::available_ppc(""))
+ valid_types <- sub("^ppc_", "", valid_types)
+ if (!type %in% valid_types) {
+ stop("Type '", type, "' is not a valid ppc type. ",
+ "Valid types are:\n",
+ paste0("'", valid_types, "'", collapse = ", "),
+ call. = FALSE)
+ }
+ }
+ ppc_fun <- get(paste0(prefix, "_", type), asNamespace("bayesplot"))
+
+ family <- object$family
+ if (family == 'nmix') {
+ stop("'pp_check' is not implemented for this family.",
+ call. = FALSE)
+ }
+ valid_vars <- names(get_predictors(object))
+ if ("group" %in% names(formals(ppc_fun))) {
+ if (is.null(group)) {
+ stop("Argument 'group' is required for ppc type '", type, "'.",
+ call. = FALSE)
+ }
+ if (!group %in% valid_vars) {
+ stop("Variable '", group, "' could not be found in the data.",
+ call. = FALSE)
+ }
+ }
+ if ("x" %in% names(formals(ppc_fun))) {
+ if (!is.null(x) && !x %in% valid_vars) {
+ stop("Variable '", x, "' could not be found in the data.",
+ call. = FALSE)
+ }
+ }
+ if (type == "error_binned") {
+ method <- "posterior_epred"
+ } else {
+ method <- "posterior_predict"
+ }
+ if (!ndraws_given) {
+ aps_types <- c(
+ "error_scatter_avg", "error_scatter_avg_vs_x",
+ "intervals", "intervals_grouped",
+ "loo_intervals", "loo_pit", "loo_pit_overlay",
+ "loo_pit_qq", "loo_ribbon",
+ 'pit_ecdf', 'pit_ecdf_grouped',
+ "ribbon", "ribbon_grouped",
+ "rootogram", "scatter_avg", "scatter_avg_grouped",
+ "stat", "stat_2d", "stat_freqpoly_grouped", "stat_grouped",
+ "violin_grouped"
+ )
+ if (type %in% aps_types) {
+ ndraws <- NULL
+ message("Using all posterior draws for ppc type '",
+ type, "' by default.")
+ } else {
+ ndraws <- 10
+ message("Using 10 posterior draws for ppc type '",
+ type, "' by default.")
+ }
+ }
+
+ y <- NULL
+ if (prefix == "ppc") {
+ # y is ignored in prefix 'ppd' plots
+ y <- newdata[[terms(formula(object$call))[[2]]]]
+ }
+
+ pred_args <- list(
+ object, newdata = newdata, ndraws = ndraws, ...)
+ yrep <- do_call(method, pred_args)
+
+ if (anyNA(y)) {
+ warning("NA responses are not shown in 'pp_check'.")
+ take <- !is.na(y)
+ y <- y[take]
+ yrep <- yrep[, take, drop = FALSE]
+ }
+
+ # Prepare plotting arguments
+ ppc_args <- list()
+ if (prefix == "ppc") {
+ ppc_args$y <- y
+ ppc_args$yrep <- yrep
+ } else if (prefix == "ppd") {
+ ppc_args$ypred <- yrep
+ }
+ if (!is.null(group)) {
+ if(!exists(group, newdata)) stop(paste0('Variable ', group, ' not in newdata'),
+ call. = FALSE)
+ ppc_args$group <- newdata[[group]]
+ }
+
+ is_like_factor = function(x){
+ is.factor(x) || is.character(x) || is.logical(x)
+ }
+
+ if (!is.null(x)) {
+ ppc_args$x <- newdata[[x]]
+ if (!is_like_factor(ppc_args$x)) {
+ ppc_args$x <- as.numeric(ppc_args$x)
+ }
+ }
+ if ("psis_object" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) {
+ # ppc_args$psis_object <- do_call(
+ # compute_loo, c(pred_args, criterion = "psis")
+ # )
+ # compute_loo() not available yet for mvgam
+ ppc_args$psis_object <- NULL
+ }
+ if ("lw" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) {
+ # ppc_args$lw <- weights(
+ # do_call(compute_loo, c(pred_args, criterion = "psis"))
+ # )
+ # compute_loo() not available yet for mvgam
+ ppc_args$lw <- NULL
+ }
+
+ # Most ... arguments are meant for the prediction function
+ for_pred <- names(dots) %in% names(formals(posterior_predict.mvgam))
+ ppc_args <- c(ppc_args, dots[!for_pred])
+
+ # Generate plot and reset colour scheme
+ out_plot <- do_call(ppc_fun, ppc_args)
+ color_scheme_set(col_scheme)
+
+ # Return the plot
+ return(out_plot)
+}
diff --git a/R/predict.mvgam.R b/R/predict.mvgam.R
index d49017a4..c84716b6 100644
--- a/R/predict.mvgam.R
+++ b/R/predict.mvgam.R
@@ -1,5 +1,6 @@
#'Predict from the GAM component of an mvgam model
#'@importFrom stats predict
+#'@inheritParams brms::fitted.brmsfit
#'@param object \code{list} object returned from \code{mvgam}. See [mvgam()]
#'@param newdata Optional \code{dataframe} or \code{list} of test data containing the
#'variables included in the linear predictor of \code{formula}. If not supplied,
@@ -31,13 +32,28 @@
#'of a \code{mvgam} model, while the forecasting functions
#'\code{\link{plot_mvgam_fc}} and \code{\link{forecast.mvgam}} are better suited to generate h-step ahead forecasts
#'that respect the temporal dynamics of estimated latent trends.
-#'@return A \code{matrix} of dimension \code{n_samples x new_obs}, where \code{n_samples} is the number of
-#'posterior samples from the fitted object and \code{n_obs} is the number of test observations in \code{newdata}
+#' @return Predicted values on the appropriate scale.
+#' If \code{summary = FALSE} the output is a matrix of dimension `n_draw x n_observations`
+#' containing predicted values for each posterior draw in `object`.
+#'
+#' If \code{summary = TRUE} the output is an \code{n_observations} x \code{E}
+#' matrix. The number of summary statistics \code{E} is equal to \code{2 +
+#' length(probs)}: The \code{Estimate} column contains point estimates (either
+#' mean or median depending on argument \code{robust}), while the
+#' \code{Est.Error} column contains uncertainty estimates (either standard
+#' deviation or median absolute deviation depending on argument
+#' \code{robust}). The remaining columns starting with \code{Q} contain
+#' quantile estimates as specified via argument \code{probs}.
#'@export
-predict.mvgam = function(object, newdata,
+predict.mvgam = function(object,
+ newdata,
data_test,
type = 'link',
- process_error = TRUE, ...){
+ process_error = TRUE,
+ summary = TRUE,
+ robust = FALSE,
+ probs = c(0.025, 0.975),
+ ...){
# Argument checks
if(!missing("data_test")){
@@ -84,7 +100,7 @@ predict.mvgam = function(object, newdata,
mgcv_model = object$trend_mgcv_model)
# Extract process error estimates
- if(attr(object$model_data, 'trend_model') %in% c('RW','AR1','AR2','AR3')){
+ if(attr(object$model_data, 'trend_model') %in% c('RW','AR1','AR2','AR3','CAR1')){
if(object$family == 'nmix'){
family_pars <- list(sigma_obs = .Machine$double.eps)
} else {
@@ -189,7 +205,7 @@ predict.mvgam = function(object, newdata,
if(!object$use_lv){
if(attr(object$model_data, 'trend_model') %in%
- c('RW','AR1','AR2','AR3','VAR1')){
+ c('RW','AR1','AR2','AR3','VAR1','CAR1')){
family_pars <- list(sigma_obs = mcmc_chains(object$model_output,
'sigma'))
}
@@ -335,7 +351,26 @@ predict.mvgam = function(object, newdata,
family_pars = family_extracts)
# Convert back to matrix
- predictions <- matrix(predictions_vec, nrow = NROW(betas))
- return(predictions)
+ preds <- matrix(predictions_vec, nrow = NROW(betas))
+
+ if(summary){
+ Qupper <- apply(preds, 2, quantile, probs = max(probs), na.rm = TRUE)
+ Qlower <- apply(preds, 2, quantile, probs = min(probs), na.rm = TRUE)
+
+ if(robust){
+ estimates <- apply(preds, 2, median, na.rm = TRUE)
+ errors <- apply(abs(preds - estimates), 2, median, na.rm = TRUE)
+ } else {
+ estimates <- apply(preds, 2, mean, na.rm = TRUE)
+ errors <- apply(preds, 2, sd, na.rm = TRUE)
+ }
+
+ out <- cbind(estimates, errors, Qlower, Qupper)
+ colnames(out) <- c('Estimate', 'Est.Error', paste0('Q', 100*min(probs)),
+ paste0('Q', 100*max(probs)))
+ } else {
+ out <- preds
+ }
+ return(out)
}
diff --git a/R/stan_utils.R b/R/stan_utils.R
index 46035ac4..812711b1 100644
--- a/R/stan_utils.R
+++ b/R/stan_utils.R
@@ -2701,7 +2701,7 @@ if(trend_model != 'VAR1'){
model_file <- readLines(textConnection(model_file), n = -1)
# We can estimate the variance parameters if a trend map is supplied
- if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3')){
+ if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3', 'CAR1')){
model_file <- model_file[-grep('vector[num_basis] b_raw;',
model_file, fixed = TRUE)]
model_file[grep("// raw basis coefficients",
@@ -2902,6 +2902,7 @@ add_trend_predictors = function(trend_formula,
}
trend_train <- data_train
+ trend_train$time <- trend_train$index..time..index
trend_train$trend_y <- rnorm(length(trend_train$time))
# Add indicators of trend names as factor levels using the trend_map
@@ -2940,6 +2941,7 @@ add_trend_predictors = function(trend_formula,
# If newdata supplied, also create a fake design matrix
# for the test data
trend_test <- data_test
+ trend_test$time <- trend_test$index..time..index
trend_test$trend_y <- rnorm(length(trend_test$time))
trend_indicators <- vector(length = length(trend_test$time))
for(i in 1:length(trend_test$time)){
diff --git a/R/sysdata.rda b/R/sysdata.rda
index 47d8da84..6dc11577 100644
Binary files a/R/sysdata.rda and b/R/sysdata.rda differ
diff --git a/R/trends.R b/R/trends.R
index d16ef2ad..a1024d28 100644
--- a/R/trends.R
+++ b/R/trends.R
@@ -15,7 +15,9 @@
#'
#'For most dynamic trend types available in `mvgam` (see argument `trend_model`), time should be
#'measured in discrete, regularly spaced intervals (i.e. `c(1, 2, 3, ...)`). However you can
-#'use irregularly spaced intervals if using `trend_model = CAR(1)`. For all trend types
+#'use irregularly spaced intervals if using `trend_model = CAR(1)`, though note that any
+#'temporal intervals that are exactly `0` will be adjusted to a very small number
+#'(`1e-12`) to prevent sampling errors. For all trend types
#'apart from `GP()`, `PW()`, and `CAR()`, moving average and/or correlated
#'process error terms can also be estimated (for example, `RW(cor = TRUE)` will set up a
#'multivariate Random Walk if `data` contains `>1` series). Character strings can also be supplied
@@ -411,14 +413,99 @@ sim_varma = function(drift,
# Stochastic realisations
varma_recursC(A = A,
- A2 = A2,
- A3 = A3,
- theta = theta,
- drift = drift,
- linpreds = linpreds,
- errors = errors,
- last_trends = last_trends,
- h = h)
+ A2 = A2,
+ A3 = A3,
+ theta = theta,
+ drift = drift,
+ linpreds = linpreds,
+ errors = errors,
+ last_trends = last_trends,
+ h = h)
+}
+
+#' Continuous time AR1 simulation function
+#' @noRd
+sim_corcar1 = function(drift,
+ A,
+ A2,
+ A3,
+ theta,
+ Sigma,
+ last_trends,
+ last_errors,
+ Xp_trend = NULL,
+ betas_trend = NULL,
+ h,
+ time_dis){
+
+ # Validate dimensions
+ validate_equaldims(A, Sigma)
+ validate_equaldims(A2, Sigma)
+ validate_equaldims(A3, Sigma)
+ validate_equaldims(theta, Sigma)
+
+ if(NROW(last_trends) != 3){
+ stop('Last 3 state estimates are required, in matrix form',
+ call. = FALSE)
+ }
+
+ if(NROW(last_errors) != 3){
+ stop('Last 3 error estimates are required, in matrix form',
+ call. = FALSE)
+ }
+
+ if(missing(drift)){
+ drift <- rep(0, NROW(A))
+ }
+
+ if(length(drift) != NROW(A)){
+ stop('Number of drift parameters must match number of rows in VAR coefficient matrix "A"',
+ call. = FALSE)
+ }
+
+ # Linear predictor, if supplied
+ if(!is.null(Xp_trend)){
+ linpreds <- as.vector(((matrix(Xp_trend,
+ ncol = NCOL(Xp_trend)) %*%
+ betas_trend)) +
+ attr(Xp_trend, 'model.offset'))
+ linpreds <- matrix(linpreds, ncol = NROW(A),
+ byrow = TRUE)
+ if(NROW(linpreds) != h + 3){
+ stop('trend linear predictor matrix should be h + 3 rows in dimension',
+ call. = FALSE)
+ }
+ } else {
+ linpreds <- matrix(0, nrow = h + 3, ncol = NROW(A))
+ }
+
+ # Draw forecast errors
+ errors <- rbind(last_errors,
+ mvnfast::rmvn(h,
+ mu = rep(0, NROW(A)),
+ sigma = Sigma))
+
+ # Stochastic realisations (will move to c++ eventually)
+ states <- matrix(NA, nrow = h + 3, ncol = NCOL(A))
+ states[1, ] <- last_trends[1, ]
+ states[2, ] <- last_trends[2, ]
+ states[3, ] <- last_trends[3, ]
+ for(t in 4:NROW(states)){
+ states[t, ] <-
+
+ # autoregressive means
+ (states[t - 1, ] - linpreds[t - 1, ]) %*% (A ^ time_dis[t - 3, ]) +
+
+ # linear predictor contributions
+ linpreds[t, ] +
+
+ # drift terms
+ drift +
+
+ # stochastic errors
+ errors[t, ]
+ }
+ states[4:NROW(states), ]
}
#' Generic function to take outputs from different trend models
@@ -721,13 +808,6 @@ extract_trend_pars = function(object, keep_all_estimates = TRUE,
out <- list()
}
- # time distance calculations for CAR1 trends
- if(attr(object$model_data, 'trend_model') == 'CAR1'){
- out$time_dist <- lapply(seq_along(levels(object$obs_data$series)), function(series){
- t(replicate(NROW(out$ar1[[1]]), object$trend_model$time_dist[,series]))
- })
- }
-
# delta params for piecewise trends
if(attr(object$model_data, 'trend_model') %in%
c('PWlinear', 'PWlogistic')){
@@ -969,10 +1049,10 @@ extract_general_trend_pars = function(samp_index, trend_pars){
if(names(trend_pars)[x] %in% c('last_lvs', 'lv_coefs', 'last_trends',
'A', 'Sigma', 'theta', 'b_gp', 'error',
- 'delta_trend', 'cap', 'time_dist')){
+ 'delta_trend', 'cap')){
if(names(trend_pars)[x] %in% c('last_lvs', 'lv_coefs', 'last_trends',
- 'b_gp', 'delta_trend', 'cap', 'time_dist')){
+ 'b_gp', 'delta_trend', 'cap')){
out <- unname(lapply(trend_pars[[x]], `[`, samp_index, ))
}
@@ -1053,69 +1133,45 @@ extract_series_trend_pars = function(series, samp_index, trend_pars,
forecast_trend = function(trend_model, use_lv, trend_pars,
Xp_trend = NULL, betas_trend = NULL,
h = 1, time = NULL, cap = NULL,
- time_dist = NULL){
+ time_dis = NULL){
# Propagate dynamic factors forward
if(use_lv){
n_lv <- length(trend_pars$last_lvs)
- if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3', 'CAR1')){
- next_lvs <- do.call(cbind, lapply(seq_len(n_lv), function(lv){
-
- ar1 <- ifelse('ar1' %in% names(trend_pars),
- trend_pars$ar1[lv],
- 0)
- if(trend_model == 'RW'){
- ar1 <- 1
- }
-
- if(!is.null(Xp_trend)){
- inds_keep <- seq(lv, NROW(Xp_trend), by = n_lv)
- Xp_trend_sub = Xp_trend[inds_keep, ]
- attr(Xp_trend_sub, 'model.offset') <- attr(Xp_trend, 'model.offset')[inds_keep]
- attr(Xp_trend_sub, 'model.offset')[is.na(attr(Xp_trend_sub, 'model.offset'))] <- 0
- } else {
- Xp_trend_sub <- NULL
- }
+ if(trend_model == 'CAR1'){
+ ar1 <- trend_pars$ar1
+ Sigma <- rlang::missing_arg()
+ if('drift' %in% names(trend_pars)){
+ drift <- trend_pars$drift
+ } else {
+ drift <- rep(0, length(ar1))
+ }
+ varma_params <- prep_varma_params(drift = drift,
+ ar1 = ar1,
+ ar2 = 0,
+ ar3 = 0,
+ theta = 0,
+ Sigma = Sigma,
+ tau = trend_pars$tau,
+ Xp_trend = Xp_trend,
+ betas_trend = betas_trend,
+ last_trends = do.call(cbind,(lapply(trend_pars$last_lvs,
+ function(x) tail(x, 3)))),
+ h = h)
- # Prep VARMA parameters
- if('Sigma' %in% names(trend_pars)){
- Sigma <- trend_pars$Sigma
- } else {
- Sigma <- rlang::missing_arg()
- }
- varma_params <- prep_varma_params(drift = ifelse('drift' %in% names(trend_pars),
- trend_pars$drift[lv],
- 0),
- ar1 = ar1,
- ar2 = ifelse('ar2' %in% names(trend_pars),
- trend_pars$ar2[lv],
- 0),
- ar3 = ifelse('ar3' %in% names(trend_pars),
- trend_pars$ar3[lv],
- 0),
- theta = ifelse('theta' %in% names(trend_pars),
- trend_pars$theta[lv],
- 0),
- tau = trend_pars$tau[lv],
- Sigma = Sigma,
- Xp_trend = Xp_trend_sub,
- betas_trend = betas_trend,
- last_trends = tail(trend_pars$last_lvs[[lv]], 3),
- h = h)
-
- # Propagate forward
- sim_varma(A = varma_params$A,
- A2 = varma_params$A2,
- A3 = varma_params$A3,
- drift = varma_params$drift,
- theta = varma_params$theta,
- Sigma = varma_params$Sigma,
- last_trends = varma_params$last_trends,
- last_errors = varma_params$last_errors,
- Xp_trend = varma_params$Xp_trend,
- betas_trend = varma_params$betas_trend,
- h = varma_params$h)
- }))
+ # Propagate forward
+ next_lvs <- sim_corcar1(A = varma_params$A,
+ A2 = varma_params$A2,
+ A3 = varma_params$A3,
+ drift = varma_params$drift,
+ theta = varma_params$theta,
+ Sigma = varma_params$Sigma,
+ last_trends = varma_params$last_trends,
+ last_errors = varma_params$last_errors,
+ Xp_trend = varma_params$Xp_trend,
+ betas_trend = varma_params$betas_trend,
+ h = varma_params$h,
+ time_dis = time_dis)
}
if(trend_model == 'GP'){
@@ -1216,49 +1272,34 @@ forecast_trend = function(trend_model, use_lv, trend_pars,
betas_trend = varma_params$betas_trend,
h = varma_params$h)
}
-
trend_fc <- next_lvs
}
# Simpler if not using dynamic factors
if(!use_lv){
-
- if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3', 'CAR1')){
-
- ar1 <- ifelse('ar1' %in% names(trend_pars),
- trend_pars$ar1, 0)
- if(trend_model == 'RW'){
- ar1 <- 1
- }
-
- # Construct VARMA parameters
- if('Sigma' %in% names(trend_pars)){
- Sigma <- trend_pars$Sigma
+ if(trend_model == 'CAR1'){
+ ar1 <- trend_pars$ar1
+ Sigma <- rlang::missing_arg()
+ if('drift' %in% names(trend_pars)){
+ drift <- trend_pars$drift
} else {
- Sigma <- rlang::missing_arg()
+ drift <- rep(0, length(ar1))
}
- varma_params <- prep_varma_params(drift = ifelse('drift' %in% names(trend_pars),
- trend_pars$drift,
- 0),
+ varma_params <- prep_varma_params(drift = drift,
ar1 = ar1,
- ar2 = ifelse('ar2' %in% names(trend_pars),
- trend_pars$ar2,
- 0),
- ar3 = ifelse('ar3' %in% names(trend_pars),
- trend_pars$ar3,
- 0),
- theta = ifelse('theta' %in% names(trend_pars),
- trend_pars$theta,
- 0),
+ ar2 = 0,
+ ar3 = 0,
+ theta = 0,
Sigma = Sigma,
tau = trend_pars$tau,
Xp_trend = Xp_trend,
betas_trend = betas_trend,
- last_trends = tail(trend_pars$last_trends, 3),
+ last_trends = do.call(cbind,(lapply(trend_pars$last_lvs,
+ function(x) tail(x, 3)))),
h = h)
# Propagate forward
- trend_fc <- sim_varma(A = varma_params$A,
+ trend_fc <- sim_corcar1(A = varma_params$A,
A2 = varma_params$A2,
A3 = varma_params$A3,
drift = varma_params$drift,
@@ -1268,7 +1309,8 @@ forecast_trend = function(trend_model, use_lv, trend_pars,
last_errors = varma_params$last_errors,
Xp_trend = varma_params$Xp_trend,
betas_trend = varma_params$betas_trend,
- h = varma_params$h)
+ h = varma_params$h,
+ time_dis = time_dis)
}
if(trend_model == 'GP'){
diff --git a/R/validations.R b/R/validations.R
index c62534b1..383f7c23 100644
--- a/R/validations.R
+++ b/R/validations.R
@@ -119,12 +119,12 @@ sort_data = function(data, series_time = FALSE){
if(inherits(data, 'list')){
data_arranged <- data
if(series_time){
- temp_dat = data.frame(time = data$time,
+ temp_dat = data.frame(time = data$index..time..index,
series = data$series) %>%
dplyr::mutate(index = dplyr::row_number()) %>%
dplyr::arrange(series, time)
} else {
- temp_dat = data.frame(time = data$time,
+ temp_dat = data.frame(time = data$index..time..index,
series = data$series) %>%
dplyr::mutate(index = dplyr::row_number()) %>%
dplyr::arrange(time, series)
@@ -141,10 +141,10 @@ sort_data = function(data, series_time = FALSE){
} else {
if(series_time){
data_arranged <- data %>%
- dplyr::arrange(series, time)
+ dplyr::arrange(series, index..time..index)
} else {
data_arranged <- data %>%
- dplyr::arrange(time, series)
+ dplyr::arrange(index..time..index, series)
}
}
@@ -555,8 +555,8 @@ validate_trend_restrictions = function(trend_model,
trend = 1:length(unique(data_train$series)))
}
- if(!trend_model %in% c('RW', 'AR1', 'AR2', 'AR3', 'VAR1')){
- stop('only RW, AR1, AR2, AR3 and VAR trends currently supported for trend predictor models',
+ if(!trend_model %in% c('RW', 'AR1', 'AR2', 'AR3', 'VAR1', 'CAR1')){
+ stop('only RW, AR1, AR2, AR3, CAR1 and VAR trends currently supported for trend predictor models',
call. = FALSE)
}
}
diff --git a/README.Rmd b/README.Rmd
index 9f386cf7..f6662952 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -148,19 +148,19 @@ and for the latent trend parameters
mcmc_plot(lynx_mvgam, variable = 'trend_params', regex = TRUE, type = 'trace')
```
-Use posterior predictive checks to see if the model can simulate data that looks realistic and unbiased. First, examine histograms for posterior retrodictions (`yhat`) and compare to the histogram of the observations (`y`)
+Use posterior predictive checks, which capitalize on the extensive functionality of the `bayesplot` package, to see if the model can simulate data that looks realistic and unbiased. First, examine histograms for posterior retrodictions (`yhat`) and compare to the histogram of the observations (`y`)
```{r, fig.alt = "Posterior predictive checks for discrete time series in R"}
-ppc(lynx_mvgam, series = 1, type = 'hist')
+pp_check(lynx_mvgam, type = "hist", ndraws = 5)
```
Next examine simulated empirical Cumulative Distribution Functions (CDF) for posterior predictions
```{r, fig.alt = "Posterior predictive checks for discrete time series in R"}
-ppc(lynx_mvgam, series = 1, type = 'cdf')
+pp_check(lynx_mvgam, type = "ecdf_overlay", ndraws = 25)
```
-Rootograms are [popular graphical tools for checking a discrete model's ability to capture dispersion properties of the response variable](https://arxiv.org/pdf/1605.01311.pdf){target="_blank"}. Posterior predictive hanging rootograms can be displayed using the `ppc()` function. In the plot below, we bin the unique observed values into `25` bins to prevent overplotting and help with interpretation. This plot compares the frequencies of observed vs predicted values for each bin. For example, if the gray bars (representing observed frequencies) tend to stretch below zero, this suggests the model's simulations predict the values in that particular bin less frequently than they are observed in the data. A well-fitting model that can generate realistic simulated data will provide a rootogram in which the lower boundaries of the grey bars are generally near zero
+Rootograms are [popular graphical tools for checking a discrete model's ability to capture dispersion properties of the response variable](https://arxiv.org/pdf/1605.01311.pdf){target="_blank"}. Posterior predictive hanging rootograms can be displayed using the `ppc()` function. In the plot below, we bin the unique observed values into `25` bins to prevent overplotting and help with interpretation. This plot compares the frequencies of observed vs predicted values for each bin. For example, if the gray bars (representing observed frequencies) tend to stretch below zero, this suggests the model's simulations predict the values in that particular bin less frequently than they are observed in the data. A well-fitting model that can generate realistic simulated data will provide a rootogram in which the lower boundaries of the grey bars are generally near zero. For this plot we use the `S3` function `ppc.mvgam()`, which is not as versatile as `pp_check()` but allows us to bin rootograms to avoid overplotting
```{r, fig.alt = "Posterior predictive rootograms for discrete time series in R"}
-ppc(lynx_mvgam, series = 1, type = 'rootogram', n_bins = 25)
+ppc(lynx_mvgam, type = "rootogram", n_bins = 25)
```
All plots indicate the model is well calibrated against the training data. Inspect the estimated cyclic smooth, which is shown as a ribbon plot of posterior empirical quantiles. We can also overlay posterior quantiles of partial residuals (shown in red), which represent the leftover variation that the model expects would remain if this smooth term was dropped but all other parameters remained unchanged. A strong pattern in the partial residuals suggests there would be strong patterns left unexplained in the model *if* we were to drop this term, giving us further confidence that this function is important in the model
diff --git a/README.md b/README.md
index a9273108..b0c5ad27 100644
--- a/README.md
+++ b/README.md
@@ -327,29 +327,29 @@ summary(lynx_mvgam)
#>
#>
#> GAM coefficient (beta) estimates:
-#> 2.5% 50% 97.5% Rhat n_eff
-#> (Intercept) 6.1000 6.60000 7.000 1.01 758
-#> s(season).1 -0.6400 0.02600 0.690 1.00 810
-#> s(season).2 -0.2400 0.81000 1.800 1.02 227
-#> s(season).3 -0.0084 1.20000 2.400 1.02 219
-#> s(season).4 -0.5000 0.44000 1.300 1.00 680
-#> s(season).5 -1.2000 -0.10000 0.950 1.02 444
-#> s(season).6 -1.1000 0.00088 1.100 1.01 564
-#> s(season).7 -0.7300 0.36000 1.500 1.00 673
-#> s(season).8 -0.9800 0.24000 1.800 1.02 337
-#> s(season).9 -1.2000 -0.29000 0.680 1.02 450
-#> s(season).10 -1.4000 -0.66000 -0.025 1.01 451
+#> 2.5% 50% 97.5% Rhat n_eff
+#> (Intercept) 6.100 6.600 7.000 1.01 657
+#> s(season).1 -0.610 0.029 0.690 1.01 859
+#> s(season).2 -0.250 0.830 1.800 1.00 609
+#> s(season).3 -0.088 1.200 2.500 1.01 436
+#> s(season).4 -0.560 0.440 1.400 1.01 753
+#> s(season).5 -1.100 -0.180 0.850 1.01 573
+#> s(season).6 -1.100 -0.012 1.100 1.00 682
+#> s(season).7 -0.670 0.390 1.500 1.00 677
+#> s(season).8 -0.920 0.330 1.800 1.01 352
+#> s(season).9 -1.100 -0.260 0.700 1.01 476
+#> s(season).10 -1.400 -0.690 0.011 1.00 797
#>
#> Approximate significance of GAM smooths:
#> edf Ref.df Chi.sq p-value
-#> s(season) 9.97 10 19379 0.24
+#> s(season) 9.91 10 19445 0.23
#>
#> Latent trend AR parameter estimates:
#> 2.5% 50% 97.5% Rhat n_eff
-#> ar1[1] 0.75 1.10 1.40 1.01 703
-#> ar2[1] -0.85 -0.40 0.04 1.00 1839
-#> ar3[1] -0.47 -0.13 0.31 1.01 393
-#> sigma[1] 0.41 0.51 0.64 1.00 1027
+#> ar1[1] 0.73 1.10 1.400 1.00 637
+#> ar2[1] -0.82 -0.39 0.036 1.00 1604
+#> ar3[1] -0.47 -0.11 0.290 1.01 619
+#> sigma[1] 0.40 0.50 0.630 1.00 1115
#>
#> Stan MCMC diagnostics:
#> n_eff / iter looks reasonable for all parameters
@@ -358,7 +358,7 @@ summary(lynx_mvgam)
#> 0 of 2000 iterations saturated the maximum tree depth of 12 (0%)
#> E-FMI indicated no pathological behavior
#>
-#> Samples were drawn using NUTS(diag_e) at Thu Mar 28 8:44:38 PM 2024.
+#> Samples were drawn using NUTS(diag_e) at Fri Apr 12 8:58:47 PM 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split MCMC chains
#> (at convergence, Rhat = 1)
@@ -384,13 +384,15 @@ mcmc_plot(lynx_mvgam, variable = 'trend_params', regex = TRUE, type = 'trace')
-Use posterior predictive checks to see if the model can simulate data
-that looks realistic and unbiased. First, examine histograms for
-posterior retrodictions (`yhat`) and compare to the histogram of the
-observations (`y`)
+Use posterior predictive checks, which capitalize on the extensive
+functionality of the `bayesplot` package, to see if the model can
+simulate data that looks realistic and unbiased. First, examine
+histograms for posterior retrodictions (`yhat`) and compare to the
+histogram of the observations (`y`)
``` r
-ppc(lynx_mvgam, series = 1, type = 'hist')
+pp_check(lynx_mvgam, type = "hist", ndraws = 5)
+#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
```
@@ -399,7 +401,7 @@ Next examine simulated empirical Cumulative Distribution Functions (CDF)
for posterior predictions
``` r
-ppc(lynx_mvgam, series = 1, type = 'cdf')
+pp_check(lynx_mvgam, type = "ecdf_overlay", ndraws = 25)
```
@@ -417,10 +419,12 @@ below zero, this suggests the model’s simulations predict the values in
that particular bin less frequently than they are observed in the data.
A well-fitting model that can generate realistic simulated data will
provide a rootogram in which the lower boundaries of the grey bars are
-generally near zero
+generally near zero. For this plot we use the `S3` function
+`ppc.mvgam()`, which is not as versatile as `pp_check()` but allows us
+to bin rootograms to avoid overplotting
``` r
-ppc(lynx_mvgam, series = 1, type = 'rootogram', n_bins = 25)
+ppc(lynx_mvgam, type = "rootogram", n_bins = 25)
```
@@ -480,7 +484,7 @@ plot(lynx_mvgam, type = 'forecast', newdata = lynx_test)
#> Out of sample CRPS:
- #> [1] 2942.292
+ #> [1] 2776.972
And the estimated latent trend component, again using the more flexible
`plot_mvgam_...()` option to show first derivatives of the estimated
@@ -641,7 +645,7 @@ summary(mod, include_betas = FALSE)
#> 0 of 2000 iterations saturated the maximum tree depth of 12 (0%)
#> E-FMI indicated no pathological behavior
#>
-#> Samples were drawn using NUTS(diag_e) at Thu Mar 28 8:46:19 PM 2024.
+#> Samples were drawn using NUTS(diag_e) at Fri Apr 12 9:00:27 PM 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split MCMC chains
#> (at convergence, Rhat = 1)
diff --git a/man/RW.Rd b/man/RW.Rd
index 6dd797d4..56402ca4 100644
--- a/man/RW.Rd
+++ b/man/RW.Rd
@@ -38,3 +38,64 @@ in \code{mvgam}. These functions do not evaluate their arguments –
they exist purely to help set up a model with particular autoregressive
trend models.
}
+\examples{
+\dontrun{
+# A short example to illustrate CAR(1) models
+# Function to simulate CAR1 data with seasonality
+sim_corcar1 = function(n = 120,
+ phi = 0.5,
+ sigma = 1,
+ sigma_obs = 0.75){
+# Sample irregularly spaced time intervals
+time_dis <- c(0, runif(n - 1, -0.1, 1))
+time_dis[time_dis < 0] <- 0; time_dis <- time_dis * 5
+
+# Set up the latent dynamic process
+x <- vector(length = n); x[1] <- -0.3
+for(i in 2:n){
+ # zero-distances will cause problems in sampling, so mvgam uses a
+ # minimum threshold; this simulation function emulates that process
+ if(time_dis[i] == 0){
+ x[i] <- rnorm(1, mean = (phi ^ 1e-12) * x[i - 1], sd = sigma)
+ } else {
+ x[i] <- rnorm(1, mean = (phi ^ time_dis[i]) * x[i - 1], sd = sigma)
+ }
+}
+
+# Add 12-month seasonality
+cov1 <- sin(2 * pi * (1 : n) / 12); cov2 <- cos(2 * pi * (1 : n) / 12)
+beta1 <- runif(1, 0.3, 0.7); beta2 <- runif(1, 0.2, 0.5)
+seasonality <- beta1 * cov1 + beta2 * cov2
+
+# Take Gaussian observations with error and return
+data.frame(y = rnorm(n, mean = x + seasonality, sd = sigma),
+ season = rep(1:12, 20)[1:n],
+ time = cumsum(time_dis))
+}
+
+# Sample two time series
+dat <- rbind(dplyr::bind_cols(sim_corcar1(phi = 0.65,
+ sigma_obs = 0.55),
+ data.frame(series = 'series1')),
+ dplyr::bind_cols(sim_corcar1(phi = 0.8,
+ sigma_obs = 0.35),
+ data.frame(series = 'series2'))) \%>\%
+ dplyr::mutate(series = as.factor(series))
+
+# mvgam with CAR(1) trends and series-level seasonal smooths
+mod <- mvgam(formula = y ~ s(season, bs = 'cc',
+ k = 5, by = series),
+ trend_model = CAR(),
+ data = dat,
+ family = gaussian(),
+ run_model = TRUE)
+
+# View usual summaries and plots
+summary(mod)
+conditional_effects(mod, type = 'expected')
+plot(mod, type = 'trend', series = 1)
+plot(mod, type = 'trend', series = 2)
+plot(mod, type = 'residuals', series = 1)
+plot(mod, type = 'residuals', series = 2)
+}
+}
diff --git a/man/figures/README-unnamed-chunk-13-1.png b/man/figures/README-unnamed-chunk-13-1.png
index 09f75cdb..96de31e1 100644
Binary files a/man/figures/README-unnamed-chunk-13-1.png and b/man/figures/README-unnamed-chunk-13-1.png differ
diff --git a/man/figures/README-unnamed-chunk-14-1.png b/man/figures/README-unnamed-chunk-14-1.png
index 36bc2232..e2583aee 100644
Binary files a/man/figures/README-unnamed-chunk-14-1.png and b/man/figures/README-unnamed-chunk-14-1.png differ
diff --git a/man/figures/README-unnamed-chunk-15-1.png b/man/figures/README-unnamed-chunk-15-1.png
index 0aef7e7b..1aeab5e2 100644
Binary files a/man/figures/README-unnamed-chunk-15-1.png and b/man/figures/README-unnamed-chunk-15-1.png differ
diff --git a/man/figures/README-unnamed-chunk-16-1.png b/man/figures/README-unnamed-chunk-16-1.png
index 78dc3d14..b2cd1a3a 100644
Binary files a/man/figures/README-unnamed-chunk-16-1.png and b/man/figures/README-unnamed-chunk-16-1.png differ
diff --git a/man/figures/README-unnamed-chunk-17-1.png b/man/figures/README-unnamed-chunk-17-1.png
index 6afa071d..1b049f45 100644
Binary files a/man/figures/README-unnamed-chunk-17-1.png and b/man/figures/README-unnamed-chunk-17-1.png differ
diff --git a/man/figures/README-unnamed-chunk-18-1.png b/man/figures/README-unnamed-chunk-18-1.png
index c799fe95..373182b5 100644
Binary files a/man/figures/README-unnamed-chunk-18-1.png and b/man/figures/README-unnamed-chunk-18-1.png differ
diff --git a/man/figures/README-unnamed-chunk-19-1.png b/man/figures/README-unnamed-chunk-19-1.png
index a66e6d6d..7704bee0 100644
Binary files a/man/figures/README-unnamed-chunk-19-1.png and b/man/figures/README-unnamed-chunk-19-1.png differ
diff --git a/man/figures/README-unnamed-chunk-20-1.png b/man/figures/README-unnamed-chunk-20-1.png
index 5b25c223..94716e7d 100644
Binary files a/man/figures/README-unnamed-chunk-20-1.png and b/man/figures/README-unnamed-chunk-20-1.png differ
diff --git a/man/figures/README-unnamed-chunk-21-1.png b/man/figures/README-unnamed-chunk-21-1.png
index 40dd2644..02c29214 100644
Binary files a/man/figures/README-unnamed-chunk-21-1.png and b/man/figures/README-unnamed-chunk-21-1.png differ
diff --git a/man/figures/README-unnamed-chunk-22-1.png b/man/figures/README-unnamed-chunk-22-1.png
index 59250c51..84397385 100644
Binary files a/man/figures/README-unnamed-chunk-22-1.png and b/man/figures/README-unnamed-chunk-22-1.png differ
diff --git a/man/figures/README-unnamed-chunk-23-1.png b/man/figures/README-unnamed-chunk-23-1.png
index 24340f18..c54bc834 100644
Binary files a/man/figures/README-unnamed-chunk-23-1.png and b/man/figures/README-unnamed-chunk-23-1.png differ
diff --git a/man/figures/README-unnamed-chunk-24-1.png b/man/figures/README-unnamed-chunk-24-1.png
index 9ba85217..6622ea9c 100644
Binary files a/man/figures/README-unnamed-chunk-24-1.png and b/man/figures/README-unnamed-chunk-24-1.png differ
diff --git a/man/get_mvgam_priors.Rd b/man/get_mvgam_priors.Rd
index 48275de3..d5b1b697 100644
--- a/man/get_mvgam_priors.Rd
+++ b/man/get_mvgam_priors.Rd
@@ -41,10 +41,17 @@ latent abundance}
\item{data}{A \code{dataframe} or \code{list} containing the model response variable and covariates
required by the GAM \code{formula} and optional \code{trend_formula}. Should include columns:
-\code{series} (a \code{factor} index of the series IDs;the number of levels should be identical
+#'\itemize{
+\item\code{series} (a \code{factor} index of the series IDs; the number of levels should be identical
to the number of unique series labels (i.e. \code{n_series = length(levels(data$series))}))
-\code{time} (\code{numeric} or \code{integer} index of the time point for each observation).
-Any other variables to be included in the linear predictor of \code{formula} must also be present}
+\item\code{time} (\code{numeric} or \code{integer} index of the time point for each observation).
+For most dynamic trend types available in \code{mvgam} (see argument \code{trend_model}), time should be
+measured in discrete, regularly spaced intervals (i.e. \code{c(1, 2, 3, ...)}). However you can
+use irregularly spaced intervals if using \code{trend_model = CAR(1)}, though note that any
+temporal intervals that are exactly \code{0} will be adjusted to a very small number
+(\code{1e-12}) to prevent sampling errors. See an example of \code{CAR()} trends in \code{\link{CAR}}
+}
+Should also include any other variables to be included in the linear predictor of \code{formula}}
\item{data_train}{Deprecated. Still works in place of \code{data} but users are recommended to use
\code{data} instead for more seamless integration into \code{R} workflows}
@@ -98,12 +105,13 @@ and the observation process is the only source of error; similarly to what is es
\item \code{'AR1'} or \code{AR(p = 1)}
\item \code{'AR2'} or \code{AR(p = 2)}
\item \code{'AR3'} or \code{AR(p = 3)}
+\item \code{'CAR1'} or \code{CAR(p = 1)}
\item \code{'VAR1'} or \code{VAR()}(only available in \code{Stan})
\item \verb{'PWlogistic}, \code{'PWlinear'} or \code{PW()} (only available in \code{Stan})
\item \code{'GP'} or \code{GP()} (Gaussian Process with squared exponential kernel;
only available in \code{Stan})}
-For all trend types apart from \code{GP()} and \code{PW()}, moving average and/or correlated
+For all trend types apart from \code{GP()}, \code{CAR()} and \code{PW()}, moving average and/or correlated
process error terms can also be estimated (for example, \code{RW(cor = TRUE)} will set up a
multivariate Random Walk if \code{n_series > 1}). See \link{mvgam_trends} for more details}
diff --git a/man/mvgam.Rd b/man/mvgam.Rd
index 9d9d7024..41bfff35 100644
--- a/man/mvgam.Rd
+++ b/man/mvgam.Rd
@@ -74,10 +74,17 @@ functions within the \code{trend_formula}}
\item{data}{A \code{dataframe} or \code{list} containing the model response variable and covariates
required by the GAM \code{formula} and optional \code{trend_formula}. Should include columns:
-\code{series} (a \code{factor} index of the series IDs;the number of levels should be identical
+#'\itemize{
+\item\code{series} (a \code{factor} index of the series IDs; the number of levels should be identical
to the number of unique series labels (i.e. \code{n_series = length(levels(data$series))}))
-\code{time} (\code{numeric} or \code{integer} index of the time point for each observation).
-Any other variables to be included in the linear predictor of \code{formula} must also be present}
+\item\code{time} (\code{numeric} or \code{integer} index of the time point for each observation).
+For most dynamic trend types available in \code{mvgam} (see argument \code{trend_model}), time should be
+measured in discrete, regularly spaced intervals (i.e. \code{c(1, 2, 3, ...)}). However you can
+use irregularly spaced intervals if using \code{trend_model = CAR(1)}, though note that any
+temporal intervals that are exactly \code{0} will be adjusted to a very small number
+(\code{1e-12}) to prevent sampling errors. See an example of \code{CAR()} trends in \code{\link{CAR}}
+}
+Should also include any other variables to be included in the linear predictor of \code{formula}}
\item{data_train}{Deprecated. Still works in place of \code{data} but users are recommended to use
\code{data} instead for more seamless integration into \code{R} workflows}
@@ -99,7 +106,8 @@ simulations from prior distributions are returned}
\item{return_model_data}{\code{logical}. If \code{TRUE}, the list of data that is needed to fit the
model is returned, along with the initial values for smooth and AR parameters, once the model is fitted.
This will be helpful if users wish to modify the model file to add
-other stochastic elements that are not currently avaiable in \code{mvgam}. Default is \code{FALSE} to reduce
+other stochastic elements that are not currently available in \code{mvgam}.
+Default is \code{FALSE} to reduce
the size of the returned object, unless \code{run_model == FALSE}}
\item{family}{\code{family} specifying the exponential observation family for the series. Currently supported
@@ -157,12 +165,13 @@ and the observation process is the only source of error; similarly to what is es
\item \code{'AR1'} or \code{AR(p = 1)}
\item \code{'AR2'} or \code{AR(p = 2)}
\item \code{'AR3'} or \code{AR(p = 3)}
+\item \code{'CAR1'} or \code{CAR(p = 1)}
\item \code{'VAR1'} or \code{VAR()}(only available in \code{Stan})
\item \verb{'PWlogistic}, \code{'PWlinear'} or \code{PW()} (only available in \code{Stan})
\item \code{'GP'} or \code{GP()} (Gaussian Process with squared exponential kernel;
only available in \code{Stan})}
-For all trend types apart from \code{GP()} and \code{PW()}, moving average and/or correlated
+For all trend types apart from \code{GP()}, \code{CAR()} and \code{PW()}, moving average and/or correlated
process error terms can also be estimated (for example, \code{RW(cor = TRUE)} will set up a
multivariate Random Walk if \code{n_series > 1}). See \link{mvgam_trends} for more details}
@@ -309,7 +318,7 @@ are non-identifiable (as in piecewise trends, see examples below)
\code{\link{mvgam_families}}.
\cr
\cr
-\emph{Trend models}: Details of latent trend models supported by \pkg{mvgam} can be found in
+\emph{Trend models}: Details of latent trend dynamic models supported by \pkg{mvgam} can be found in
\code{\link{mvgam_trends}}.
\cr
\cr
@@ -405,7 +414,7 @@ str(mod1$model_data)
model_data <- mod1$model_data
library(rstan)
fit <- stan(model_code = mod1$model_file,
- data = model_data)
+ data = model_data)
# Now using cmdstanr
library(cmdstanr)
@@ -455,9 +464,14 @@ plot(mod1, type = 'smooths', residuals = TRUE)
plot(mod1, type = 'smooths', realisations = TRUE)
# Plot conditional response predictions using marginaleffects
-plot(conditional_effects(mod1), ask = FALSE)
+conditional_effects(mod1)
plot_predictions(mod1, condition = 'season', points = 0.5)
+# Generate posterior predictive checks through bayesplot
+pp_check(mod1)
+pp_check(mod, type = "bars_grouped",
+ group = "series", ndraws = 50)
+
# Extract observation model beta coefficient draws as a data.frame
beta_draws_df <- as.data.frame(mod1, variable = 'betas')
head(beta_draws_df)
diff --git a/man/mvgam_trends.Rd b/man/mvgam_trends.Rd
index db252188..cdf12366 100644
--- a/man/mvgam_trends.Rd
+++ b/man/mvgam_trends.Rd
@@ -19,7 +19,12 @@ and the observation process is the only source of error; similarly to what is es
\item \code{GP()} (Gaussian Process with squared exponential kernel;
only available in \code{Stan})}
-For all types apart from \code{GP()}, \code{PW()}, and \code{CAR()}, moving average and/or correlated
+For most dynamic trend types available in \code{mvgam} (see argument \code{trend_model}), time should be
+measured in discrete, regularly spaced intervals (i.e. \code{c(1, 2, 3, ...)}). However you can
+use irregularly spaced intervals if using \code{trend_model = CAR(1)}, though note that any
+temporal intervals that are exactly \code{0} will be adjusted to a very small number
+(\code{1e-12}) to prevent sampling errors. For all trend types
+apart from \code{GP()}, \code{PW()}, and \code{CAR()}, moving average and/or correlated
process error terms can also be estimated (for example, \code{RW(cor = TRUE)} will set up a
multivariate Random Walk if \code{data} contains \verb{>1} series). Character strings can also be supplied
instead of the various trend functions. The full list of possible models that are
diff --git a/man/posterior_epred.mvgam.Rd b/man/posterior_epred.mvgam.Rd
index 1911f035..a1cb4358 100644
--- a/man/posterior_epred.mvgam.Rd
+++ b/man/posterior_epred.mvgam.Rd
@@ -5,7 +5,14 @@
\alias{posterior_epred}
\title{Draws from the Expected Value of the Posterior Predictive Distribution}
\usage{
-\method{posterior_epred}{mvgam}(object, newdata, data_test, process_error = TRUE, ...)
+\method{posterior_epred}{mvgam}(
+ object,
+ newdata,
+ data_test,
+ ndraws = NULL,
+ process_error = TRUE,
+ ...
+)
}
\arguments{
\item{object}{\code{list} object returned from \code{mvgam}. See \code{\link[=mvgam]{mvgam()}}}
@@ -17,6 +24,9 @@ predictions are generated for the original observations used for the model fit.}
\item{data_test}{Deprecated. Still works in place of \code{newdata} but users are recommended to use
\code{newdata} instead for more seamless integration into \code{R} workflows}
+\item{ndraws}{Positive \code{integer} indicating how many posterior draws should be used.
+If \code{NULL} (the default) all draws are used.}
+
\item{process_error}{Logical. If \code{TRUE} and \code{newdata} is supplied,
expected uncertainty in the process model is accounted for by using draws
from any latent trend SD parameters. If \code{FALSE}, uncertainty in the latent
diff --git a/man/posterior_linpred.mvgam.Rd b/man/posterior_linpred.mvgam.Rd
index dd125dcb..02fa72b5 100644
--- a/man/posterior_linpred.mvgam.Rd
+++ b/man/posterior_linpred.mvgam.Rd
@@ -8,6 +8,7 @@
object,
transform = FALSE,
newdata,
+ ndraws = NULL,
data_test,
process_error = TRUE,
...
@@ -25,6 +26,9 @@ i.e. the conditional expectation, are returned.}
variables included in the linear predictor of \code{formula}. If not supplied,
predictions are generated for the original observations used for the model fit.}
+\item{ndraws}{Positive \code{integer} indicating how many posterior draws should be used.
+If \code{NULL} (the default) all draws are used.}
+
\item{data_test}{Deprecated. Still works in place of \code{newdata} but users are recommended to use
\code{newdata} instead for more seamless integration into \code{R} workflows}
diff --git a/man/posterior_predict.mvgam.Rd b/man/posterior_predict.mvgam.Rd
index 0c437971..0b346ad3 100644
--- a/man/posterior_predict.mvgam.Rd
+++ b/man/posterior_predict.mvgam.Rd
@@ -4,7 +4,14 @@
\alias{posterior_predict.mvgam}
\title{Draws from the Posterior Predictive Distribution}
\usage{
-\method{posterior_predict}{mvgam}(object, newdata, data_test, process_error = TRUE, ...)
+\method{posterior_predict}{mvgam}(
+ object,
+ newdata,
+ data_test,
+ ndraws = NULL,
+ process_error = TRUE,
+ ...
+)
}
\arguments{
\item{object}{\code{list} object returned from \code{mvgam}. See \code{\link[=mvgam]{mvgam()}}}
@@ -16,6 +23,9 @@ predictions are generated for the original observations used for the model fit.}
\item{data_test}{Deprecated. Still works in place of \code{newdata} but users are recommended to use
\code{newdata} instead for more seamless integration into \code{R} workflows}
+\item{ndraws}{Positive \code{integer} indicating how many posterior draws should be used.
+If \code{NULL} (the default) all draws are used.}
+
\item{process_error}{Logical. If \code{TRUE} and \code{newdata} is supplied,
expected uncertainty in the process model is accounted for by using draws
from any latent trend SD parameters. If \code{FALSE}, uncertainty in the latent
diff --git a/man/pp_check.mvgam.Rd b/man/pp_check.mvgam.Rd
new file mode 100644
index 00000000..623f2490
--- /dev/null
+++ b/man/pp_check.mvgam.Rd
@@ -0,0 +1,111 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/ppc.mvgam.R
+\name{pp_check.mvgam}
+\alias{pp_check.mvgam}
+\alias{pp_check}
+\title{Posterior Predictive Checks for \code{mvgam} Objects}
+\usage{
+\method{pp_check}{mvgam}(
+ object,
+ type,
+ ndraws = NULL,
+ prefix = c("ppc", "ppd"),
+ group = NULL,
+ x = NULL,
+ newdata = NULL,
+ ...
+)
+}
+\arguments{
+\item{object}{An object of class \code{mvgam}.}
+
+\item{type}{Type of the ppc plot as given by a character string.
+See \code{\link[bayesplot:PPC-overview]{PPC}} for an overview
+of currently supported types. You may also use an invalid
+type (e.g. \code{type = "xyz"}) to get a list of supported
+types in the resulting error message.}
+
+\item{ndraws}{Positive integer indicating how many
+posterior draws should be used.
+If \code{NULL} all draws are used. If not specified,
+the number of posterior draws is chosen automatically.
+Ignored if \code{draw_ids} is not \code{NULL}.}
+
+\item{prefix}{The prefix of the \pkg{bayesplot} function to be applied.
+Either `"ppc"` (posterior predictive check; the default)
+or `"ppd"` (posterior predictive distribution), the latter being the same
+as the former except that the observed data is not shown for `"ppd"`.}
+
+\item{group}{Optional name of a factor variable in the model
+by which to stratify the ppc plot. This argument is required for
+ppc \code{*_grouped} types and ignored otherwise.}
+
+\item{x}{Optional name of a variable in the model.
+Only used for ppc types having an \code{x} argument
+and ignored otherwise.}
+
+\item{newdata}{Optional \code{dataframe} or \code{list} of test data containing the
+variables included in the linear predictor of \code{formula}. If not supplied,
+predictions are generated for the original observations used for the model fit.}
+
+\item{...}{Further arguments passed to \code{\link{predict.mvgam}}
+as well as to the PPC function specified in \code{type}.}
+}
+\value{
+A ggplot object that can be further
+customized using the \pkg{ggplot2} package.
+}
+\description{
+Perform posterior predictive checks with the help
+of the \pkg{bayesplot} package.
+}
+\details{
+For a detailed explanation of each of the ppc functions,
+see the \code{\link[bayesplot:PPC-overview]{PPC}}
+documentation of the \pkg{\link[bayesplot:bayesplot-package]{bayesplot}}
+package.
+}
+\examples{
+\dontrun{
+simdat <- sim_mvgam(seasonality = 'hierarchical')
+mod <- mvgam(y ~ series +
+ s(season, bs = 'cc', k = 6) +
+ s(season, series, bs = 'sz', k = 4),
+ data = simdat$data_train)
+
+# Get a list of available plot types
+pp_check(mod, type = "xyz")
+
+# Default is a density overlay for all observations
+pp_check(mod)
+
+# Rootograms particularly useful for count data
+pp_check(mod, type = "rootogram")
+
+# Grouping plots by series is useful
+pp_check(mod, type = "bars_grouped",
+ group = "series", ndraws = 50)
+pp_check(mod, type = "ecdf_overlay_grouped",
+ group = "series", ndraws = 50)
+pp_check(mod, type = "stat_freqpoly_grouped",
+ group = "series", ndraws = 50)
+
+# Custom functions accepted
+prop_zero <- function(x) mean(x == 0)
+pp_check(mod, type = "stat", stat = "prop_zero")
+pp_check(mod, type = "stat_grouped",
+ stat = "prop_zero",
+ group = "series")
+
+# Some functions accept covariates to set the x-axes
+pp_check(mod, x = "season",
+ type = "ribbon_grouped",
+ prob = 0.5,
+ prob_outer = 0.8,
+ group = "series")
+
+# Many plots can be made without the observed data
+pp_check(mod, prefix = "ppd")
+}
+
+}
diff --git a/man/predict.mvgam.Rd b/man/predict.mvgam.Rd
index 2c9d8832..66b8d1f7 100644
--- a/man/predict.mvgam.Rd
+++ b/man/predict.mvgam.Rd
@@ -4,7 +4,17 @@
\alias{predict.mvgam}
\title{Predict from the GAM component of an mvgam model}
\usage{
-\method{predict}{mvgam}(object, newdata, data_test, type = "link", process_error = TRUE, ...)
+\method{predict}{mvgam}(
+ object,
+ newdata,
+ data_test,
+ type = "link",
+ process_error = TRUE,
+ summary = TRUE,
+ robust = FALSE,
+ probs = c(0.025, 0.975),
+ ...
+)
}
\arguments{
\item{object}{\code{list} object returned from \code{mvgam}. See \code{\link[=mvgam]{mvgam()}}}
@@ -30,11 +40,33 @@ expected uncertainty in the process model is accounted for by using draws
from the latent trend SD parameters. If \code{FALSE}, uncertainty in the latent trend
component is ignored when calculating predictions}
+\item{summary}{Should summary statistics be returned
+instead of the raw values? Default is \code{TRUE}..}
+
+\item{robust}{If \code{FALSE} (the default) the mean is used as
+the measure of central tendency and the standard deviation as
+the measure of variability. If \code{TRUE}, the median and the
+median absolute deviation (MAD) are applied instead.
+Only used if \code{summary} is \code{TRUE}.}
+
+\item{probs}{The percentiles to be computed by the \code{quantile}
+function. Only used if \code{summary} is \code{TRUE}.}
+
\item{...}{Ignored}
}
\value{
-A \code{matrix} of dimension \code{n_samples x new_obs}, where \code{n_samples} is the number of
-posterior samples from the fitted object and \code{n_obs} is the number of test observations in \code{newdata}
+Predicted values on the appropriate scale.
+If \code{summary = FALSE} the output is a matrix of dimension \verb{n_draw x n_observations}
+containing predicted values for each posterior draw in \code{object}.
+
+If \code{summary = TRUE} the output is an \code{n_observations} x \code{E}
+matrix. The number of summary statistics \code{E} is equal to \code{2 +
+ length(probs)}: The \code{Estimate} column contains point estimates (either
+mean or median depending on argument \code{robust}), while the
+\code{Est.Error} column contains uncertainty estimates (either standard
+deviation or median absolute deviation depending on argument
+\code{robust}). The remaining columns starting with \code{Q} contain
+quantile estimates as specified via argument \code{probs}.
}
\description{
Predict from the GAM component of an mvgam model
diff --git a/man/update.mvgam.Rd b/man/update.mvgam.Rd
index 7e85fe90..d52bcae2 100644
--- a/man/update.mvgam.Rd
+++ b/man/update.mvgam.Rd
@@ -45,10 +45,17 @@ latent abundance}
\item{data}{A \code{dataframe} or \code{list} containing the model response variable and covariates
required by the GAM \code{formula} and optional \code{trend_formula}. Should include columns:
-\code{series} (a \code{factor} index of the series IDs;the number of levels should be identical
+#'\itemize{
+\item\code{series} (a \code{factor} index of the series IDs; the number of levels should be identical
to the number of unique series labels (i.e. \code{n_series = length(levels(data$series))}))
-\code{time} (\code{numeric} or \code{integer} index of the time point for each observation).
-Any other variables to be included in the linear predictor of \code{formula} must also be present}
+\item\code{time} (\code{numeric} or \code{integer} index of the time point for each observation).
+For most dynamic trend types available in \code{mvgam} (see argument \code{trend_model}), time should be
+measured in discrete, regularly spaced intervals (i.e. \code{c(1, 2, 3, ...)}). However you can
+use irregularly spaced intervals if using \code{trend_model = CAR(1)}, though note that any
+temporal intervals that are exactly \code{0} will be adjusted to a very small number
+(\code{1e-12}) to prevent sampling errors. See an example of \code{CAR()} trends in \code{\link{CAR}}
+}
+Should also include any other variables to be included in the linear predictor of \code{formula}}
\item{newdata}{Optional \code{dataframe} or \code{list} of test data containing at least \code{series} and \code{time}
in addition to any other variables included in the linear predictor of \code{formula}. If included, the
@@ -63,12 +70,13 @@ and the observation process is the only source of error; similarly to what is es
\item \code{'AR1'} or \code{AR(p = 1)}
\item \code{'AR2'} or \code{AR(p = 2)}
\item \code{'AR3'} or \code{AR(p = 3)}
+\item \code{'CAR1'} or \code{CAR(p = 1)}
\item \code{'VAR1'} or \code{VAR()}(only available in \code{Stan})
\item \verb{'PWlogistic}, \code{'PWlinear'} or \code{PW()} (only available in \code{Stan})
\item \code{'GP'} or \code{GP()} (Gaussian Process with squared exponential kernel;
only available in \code{Stan})}
-For all trend types apart from \code{GP()} and \code{PW()}, moving average and/or correlated
+For all trend types apart from \code{GP()}, \code{CAR()} and \code{PW()}, moving average and/or correlated
process error terms can also be estimated (for example, \code{RW(cor = TRUE)} will set up a
multivariate Random Walk if \code{n_series > 1}). See \link{mvgam_trends} for more details}
diff --git a/src/mvgam.dll b/src/mvgam.dll
index a2379794..6af28a38 100644
Binary files a/src/mvgam.dll and b/src/mvgam.dll differ
diff --git a/tests/testthat/Rplots.pdf b/tests/testthat/Rplots.pdf
index f14a73f0..3c9b6206 100644
Binary files a/tests/testthat/Rplots.pdf and b/tests/testthat/Rplots.pdf differ
diff --git a/tests/testthat/test-mvgam-methods.R b/tests/testthat/test-mvgam-methods.R
index 33302eca..30ff2a58 100644
--- a/tests/testthat/test-mvgam-methods.R
+++ b/tests/testthat/test-mvgam-methods.R
@@ -4,6 +4,9 @@ expect_ggplot <- function(object, ...) {
testthat::expect_true(is(object, "ggplot"), ...)
}
+SM <- suppressMessages
+SW <- suppressWarnings
+
test_that("conditional_effects works properly", {
effects <- conditional_effects(mvgam:::mvgam_example1)
lapply(effects, expect_ggplot)
@@ -12,6 +15,36 @@ test_that("conditional_effects works properly", {
lapply(effects, expect_ggplot)
})
+test_that("mcmc_plot works properly", {
+ expect_ggplot(mcmc_plot(mvgam:::mvgam_example1, type = "dens"))
+ expect_ggplot(mcmc_plot(mvgam:::mvgam_example1,
+ type = "scatter",
+ variable = variables(mvgam:::mvgam_example1)$observation_betas[2:3, 1]))
+ expect_error(mcmc_plot(mvgam:::mvgam_example1, type = "density"), "Invalid plot type")
+ expect_ggplot(SW(mcmc_plot(mvgam:::mvgam_example2, type = "neff")))
+ expect_ggplot(mcmc_plot(mvgam:::mvgam_example3, type = "acf"))
+ expect_silent(p <- mcmc_plot(mvgam:::mvgam_example3, type = "areas"))
+ expect_error(mcmc_plot(mvgam:::mvgam_example3, type = "hex"),
+ "Exactly 2 parameters must be selected")
+ expect_ggplot(mcmc_plot(mvgam:::mvgam_example4))
+})
+
+test_that("pp_check works properly", {
+ expect_ggplot(SM(pp_check(mvgam:::mvgam_example1)))
+ expect_ggplot(SM(pp_check(mvgam:::mvgam_example1,
+ newdata = mvgam:::mvgam_example1$obs_data[1:10, ])))
+ expect_ggplot(pp_check(mvgam:::mvgam_example2, "stat", ndraws = 5))
+ expect_ggplot(SM(pp_check(mvgam:::mvgam_example3, "error_binned")))
+ pp <- pp_check(mvgam:::mvgam_example4,
+ type = "ribbon",
+ x = "season")
+ expect_ggplot(pp)
+ pp <- pp_check(mvgam:::mvgam_example2, type = "violin_grouped",
+ group = "season",
+ newdata = mvgam:::mvgam_example2$obs_data[1:10, ])
+ expect_ggplot(pp)
+ expect_ggplot(pp_check(mvgam:::mvgam_example4, prefix = "ppd"))
+})
test_that("model.frame gives expected output structure", {
mod_data <- model.frame(gaus_ar1)
@@ -66,11 +99,11 @@ test_that("logLik has reasonable ouputs", {
})
test_that("predict has reasonable outputs", {
- gaus_preds <- predict(gaus_ar1, type = 'link')
+ gaus_preds <- predict(gaus_ar1, type = 'link', summary = FALSE)
expect_equal(dim(gaus_preds),
c(1200, NROW(gaus_data$data_train)))
- beta_preds <- predict(beta_gp, type = 'response')
+ beta_preds <- predict(beta_gp, type = 'response', summary = FALSE)
expect_equal(dim(beta_preds),
c(300, NROW(beta_data$data_train)))
expect_lt(max(beta_preds), 1.00000001)
@@ -82,7 +115,8 @@ test_that("predict has reasonable outputs", {
})
test_that("get_predict has reasonable outputs", {
-gaus_preds <- predict(gaus_ar1, type = 'link', process_error = FALSE)
+gaus_preds <- predict(gaus_ar1, type = 'link', process_error = FALSE,
+ summary = FALSE)
meffects_preds <- get_predict(gaus_ar1, type = 'link')
expect_true(NROW(meffects_preds) == NCOL(gaus_preds))
expect_true(identical(meffects_preds$estimate,