diff --git a/R/update.mvgam.R b/R/update.mvgam.R index 69a9992c..76335bcf 100644 --- a/R/update.mvgam.R +++ b/R/update.mvgam.R @@ -40,9 +40,29 @@ update.mvgam = function(object, family, share_obs_params, priors, + chains, + burnin, + samples, + algorithm, lfo = FALSE, ...){ + if(missing(chains)){ + chains <- object$model_output@sim$chains + } + + if(missing(burnin)){ + burnin <- object$model_output@sim$warmup + } + + if(missing(samples)){ + samples <- object$model_output@sim$iter - burnin + } + + if(missing(algorithm)){ + algorithm <- object$algorithm + } + if(missing(formula)){ formula <- object$call @@ -147,6 +167,10 @@ update.mvgam = function(object, use_stan = ifelse(object$fit_engine == 'stan', TRUE, FALSE), priors = priors, + chains = chains, + burnin = burnin, + samples = samples, + algorithm = algorithm, ...) } else { updated_mod <- mvgam(formula = formula, @@ -163,6 +187,10 @@ update.mvgam = function(object, use_stan = ifelse(object$fit_engine == 'stan', TRUE, FALSE), priors = priors, + chains = chains, + burnin = burnin, + samples = samples, + algorithm = algorithm, ...) } diff --git a/man/update.mvgam.Rd b/man/update.mvgam.Rd index 2ac09e0d..66af77a6 100644 --- a/man/update.mvgam.Rd +++ b/man/update.mvgam.Rd @@ -17,6 +17,10 @@ family, share_obs_params, priors, + chains, + burnin, + samples, + algorithm, lfo = FALSE, ... ) @@ -120,6 +124,30 @@ definitions (in JAGS or Stan syntax). if using Stan, this can also be an object class \code{brmsprior} (see. \code{\link[brms]{prior}} for details). See \link{get_mvgam_priors} and 'Details' for more information on changing default prior distributions} +\item{chains}{\code{integer} specifying the number of parallel chains for the model. Ignored +if \code{algorithm \%in\% c('meanfield', 'fullrank', 'pathfinder', 'laplace')}} + +\item{burnin}{\code{integer} specifying the number of warmup iterations of the Markov chain to run +to tune sampling algorithms. Ignored +if \code{algorithm \%in\% c('meanfield', 'fullrank', 'pathfinder', 'laplace')}} + +\item{samples}{\code{integer} specifying the number of post-warmup iterations of the Markov chain to run for +sampling the posterior distribution} + +\item{algorithm}{Character string naming the estimation approach to use. +Options are \code{"sampling"} for MCMC (the default), \code{"meanfield"} for +variational inference with factorized normal distributions, +\code{"fullrank"} for variational inference with a multivariate normal +distribution, \code{"laplace"} for a Laplace approximation (only available +when using cmdstanr as the backend) or \code{"pathfinder"} for the pathfinder +algorithm (only currently available when using cmdstanr as the backend). +Can be set globally for the current \R session via the +\code{"brms.algorithm"} option (see \code{\link{options}}). Limited testing +suggests that \code{"meanfield"} performs best out of the non-MCMC approximations for +dynamic GAMs, possibly because of the difficulties estimating covariances among the +many spline parameters and latent trend parameters. But rigorous testing has not +been carried out} + \item{lfo}{Logical indicating whether this is part of a call to \link{lfo_cv.mvgam}. Returns a lighter version of the model with no residuals and fewer monitored parameters to speed up post-processing. But other downstream functions will not work properly, so users should always diff --git a/src/mvgam.dll b/src/mvgam.dll index 3dd0a97f..c6f9b13d 100644 Binary files a/src/mvgam.dll and b/src/mvgam.dll differ diff --git a/tests/local/setup_tests_local.R b/tests/local/setup_tests_local.R new file mode 100644 index 00000000..ae3a1f71 --- /dev/null +++ b/tests/local/setup_tests_local.R @@ -0,0 +1,11 @@ +# Setup models for tests locally +library("testthat") +library("mvgam") +set.seed(100) + +expect_match2 <- function(object, regexp) { + any(grepl(regexp, object, fixed = TRUE)) +} + +context("local tests") + diff --git a/tests/local/tests-models1.R b/tests/local/tests-models1.R new file mode 100644 index 00000000..9b082ac7 --- /dev/null +++ b/tests/local/tests-models1.R @@ -0,0 +1,52 @@ +source("setup_tests_local.R") + +test_that("lfo_cv working properly", { + gaus_data <- sim_mvgam(family = gaussian(), + T = 60, + trend_model = 'AR1', + seasonality = 'shared', + mu = c(-1, 0, 1), + prop_trend = 0.5, + prop_missing = 0.2) + gaus_ar1fc <- mvgam(y ~ s(series, bs = 're') + + s(season, bs = 'cc', k = 5) - 1, + trend_model = AR(), + data = gaus_data$data_train, + newdata = gaus_data$data_test, + family = gaussian(), + samples = 300) + + lfcv <- lfo_cv(gaus_ar1fc, min_t = 42) + expect_true(inherits(lfcv, 'mvgam_lfo')) + expect_true(all.equal(lfcv$eval_timepoints, c(43,44))) +}) + +# Beta model with trend_formula (use meanfield to ensure that works) +beta_data <- sim_mvgam(family = betar(), + trend_model = AR(), + prop_trend = 0.5, + T = 60) + +test_that("variational methods working properly", { + beta_gpfc <- mvgam(y ~ series, + trend_formula = ~ s(season, bs = 'cc', k = 5), + trend_model = AR(cor = TRUE), + data = beta_data$data_train, + newdata = beta_data$data_test, + family = betar(), + algorithm = 'meanfield') + expect_true(inherits(beta_gpfc, 'mvgam')) + + beta_gpfc <- mvgam(y ~ series, + trend_formula = ~ s(season, bs = 'cc', k = 5), + trend_model = AR(cor = TRUE), + data = beta_data$data_train, + newdata = beta_data$data_test, + family = betar(), + algorithm = 'fullrank') + expect_true(inherits(beta_gpfc, 'mvgam')) + + loomod <- loo(beta_gpfc) + expect_true(inherits(loomod, 'psis_loo')) +}) + diff --git a/tests/testthat/Rplots.pdf b/tests/testthat/Rplots.pdf index 6c643dcf..0925395a 100644 Binary files a/tests/testthat/Rplots.pdf and b/tests/testthat/Rplots.pdf differ diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 9ce4d93d..ff2dc0af 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -36,7 +36,7 @@ gaus_ar1fc <- mvgam(y ~ s(series, bs = 're') + samples = 300, parallel = FALSE) -# Simple Beta models, using variational bayes to ensure this works as well +# Simple Beta models set.seed(100) beta_data <- sim_mvgam(family = betar(), trend_model = 'GP', @@ -47,8 +47,7 @@ beta_gp <- mvgam(y ~ s(season, bs = 'cc'), data = beta_data$data_train, family = betar(), samples = 300, - backend = 'cmdstanr', - algorithm = 'fullrank') + chains = 1) beta_gpfc <- mvgam(y ~ s(season, bs = 'cc'), trend_model = 'GP', data = beta_data$data_train, diff --git a/tests/testthat/test-gp.R b/tests/testthat/test-gp.R index df5947a3..43b84d8f 100644 --- a/tests/testthat/test-gp.R +++ b/tests/testthat/test-gp.R @@ -2,25 +2,25 @@ context("gp") test_that("gp_to_s is working properly", { # All true gp() terms should be changed to s() with k = k+1 -formula <- y ~ s(series) + gp(banana) + - infect:you + gp(hardcourt) - -expect_equal(attr(terms(mvgam:::gp_to_s(formula), keep.order = TRUE), - 'term.labels'), - attr(terms(formula(y ~ s(series) + - s(banana, k = 11) + - infect:you + - s(hardcourt, k = 11)), - keep.order = TRUE), - 'term.labels')) - -# Characters that match to 'gp' should not be changed -formula <- y ~ gp(starwars) + s(gp) -expect_equal(attr(terms(mvgam:::gp_to_s(formula), keep.order = TRUE), - 'term.labels'), - attr(terms(formula(y ~ s(starwars, k = 11) + s(gp)), - keep.order = TRUE), - 'term.labels')) + formula <- y ~ s(series) + gp(banana) + + infect:you + gp(hardcourt) + + expect_equal(attr(terms(mvgam:::gp_to_s(formula), keep.order = TRUE), + 'term.labels'), + attr(terms(formula(y ~ s(series) + + s(banana, k = 11) + + infect:you + + s(hardcourt, k = 11)), + keep.order = TRUE), + 'term.labels')) + + # Characters that match to 'gp' should not be changed + formula <- y ~ gp(starwars) + s(gp) + expect_equal(attr(terms(mvgam:::gp_to_s(formula), keep.order = TRUE), + 'term.labels'), + attr(terms(formula(y ~ s(starwars, k = 11) + s(gp)), + keep.order = TRUE), + 'term.labels')) }) test_that("gp for observation models working properly", {