Skip to content

Commit

Permalink
array updates for newest Cmdstan; more piecewise tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Jan 19, 2024
1 parent d4e89f2 commit 678d003
Show file tree
Hide file tree
Showing 144 changed files with 931 additions and 555 deletions.
31 changes: 27 additions & 4 deletions R/add_nmixture.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,36 @@ add_nmixture = function(model_file,
stop('Max abundances must be supplied as a variable named "cap" for N-mixture models',
call. = FALSE)
}
cap <- data_train$cap

if(inherits(data_train, 'data.frame')){
cap = data_train %>%
dplyr::arrange(series, time) %>%
dplyr::pull(cap)
} else {
cap = data.frame(series = data_train$series,
cap = data_train$cap,
time = data_train$time)%>%
dplyr::arrange(series, time) %>%
dplyr::pull(cap)
}

if(!is.null(data_test)){
if(!(exists('cap', where = data_test))) {
stop('Max abundances must be supplied in test data as a variable named "cap" for N-mixture models',
call. = FALSE)
}
cap <- c(cap, data_test$cap)
if(inherits(data_test, 'data.frame')){
captest = data_test %>%
dplyr::arrange(series, time) %>%
dplyr::pull(cap)
} else {
captest = data.frame(series = data_test$series,
cap = data_test$cap,
time = data_test$time)%>%
dplyr::arrange(series, time) %>%
dplyr::pull(cap)
}
cap <- c(cap, captest)
}

validate_pos_integers(cap)
Expand All @@ -33,7 +56,7 @@ add_nmixture = function(model_file,

model_data$cap <- as.vector(cap)

if(any(model_data$cap < model_data$y)){
if(any(model_data$cap[model_data$obs_ind] < model_data$flat_ys)){
stop(paste0('Some "cap" terms are < the observed counts. This is not allowed'),
call. = FALSE)
}
Expand Down Expand Up @@ -243,7 +266,7 @@ add_nmixture = function(model_file,
'array[n, n_series] int latent_ypred;\n',
'array[total_obs] int latent_truncpred;\n',
'vector[n_nonmissing] flat_ps;\n',
'int flat_caps[n_nonmissing];',
'int flat_caps[n_nonmissing];\n',
'vector[total_obs] flat_trends;\n',
'vector[n_nonmissing] flat_trends_nonmis;\n',
'vector[total_obs] detprob;\n',
Expand Down
2 changes: 1 addition & 1 deletion R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ get_mvgam_priors = function(formula,
}

# Remove sigma prior if this is an N-mixture with no dynamics
if(add_nmix){
if(add_nmix & trend_model == 'None'){
out <- out[-grep('vector<lower=0>[n_lv] sigma;',
out$param_name,
fixed = TRUE),]
Expand Down
13 changes: 9 additions & 4 deletions R/gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,15 @@ make_gp_additions = function(gp_details, data,

# Add coefficient indices to attribute table and to Stan data
for(covariate in seq_along(gp_att_table)){
# coef_indices <- grep(gp_att_table[[covariate]]$name,
# names(coef(mgcv_model)), fixed = TRUE)
coef_indices <- which(grepl(gp_att_table[[covariate]]$name,
names(coef(mgcv_model)), fixed = TRUE) &
# coef_indices <- which(grepl(gp_att_table[[covariate]]$name,
# names(coef(mgcv_model)), fixed = TRUE) &
# !grepl(paste0(gp_att_table[[covariate]]$name,':'),
# names(coef(mgcv_model)), fixed = TRUE) == TRUE)

coef_indices <- which(grepl(paste0(gsub("([()])","\\\\\\1",
gp_att_table[[covariate]]$name),
'\\.+[0-9]'),
names(coef(mgcv_model)), fixed = FALSE) &
!grepl(paste0(gp_att_table[[covariate]]$name,':'),
names(coef(mgcv_model)), fixed = TRUE) == TRUE)

Expand Down
5 changes: 2 additions & 3 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -1718,8 +1718,7 @@ mvgam = function(formula,
# Auto-format the model file
if(autoformat){
if(requireNamespace('cmdstanr') & cmdstanr::cmdstan_version() >= "2.29.0") {
tmp_file <- cmdstanr::write_stan_file(vectorised$model_file)
vectorised$model_file <- .autoformat(tmp_file,
vectorised$model_file <- .autoformat(vectorised$model_file,
overwrite_file = FALSE)
}
vectorised$model_file <- readLines(textConnection(vectorised$model_file),
Expand Down Expand Up @@ -1908,7 +1907,7 @@ mvgam = function(formula,
cpp_options = list(stan_threads = TRUE))
} else {
cmd_mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(vectorised$model_file),
stanc_options = list('O1'))
stanc_options = list('O1'),)
}

} else {
Expand Down
Loading

0 comments on commit 678d003

Please sign in to comment.