Skip to content

Commit

Permalink
more efficient N-mixture models
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Feb 9, 2024
1 parent d422646 commit 8df3963
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 194 deletions.
124 changes: 48 additions & 76 deletions R/add_nmixture.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,10 @@ add_nmixture = function(model_file,
"matrix[n, n_lv] LV;\n")
}

# Add functions for N-mixtures
# Update functions block
model_file <- add_nmix_functions(model_file,
trend_map,
nmix_trendmap)

# Update the data block
model_file[grep('int<lower=0> n_nonmissing; // number of nonmissing observations', model_file, fixed = TRUE)] <-
paste0("int<lower=0> n_nonmissing; // number of nonmissing observations\n",
Expand Down Expand Up @@ -107,7 +106,13 @@ add_nmixture = function(model_file,
# Update transformed data block
model_file[grep("transformed data {", model_file, fixed = TRUE)] <-
paste0("transformed data {\n",
"matrix[total_obs, num_basis] X_ordered = X[ytimes_array, : ];")
"matrix[total_obs, num_basis] X_ordered = X[ytimes_array, : ];\n",
"array[K_groups] int<lower=0> Y_max;\n",
"array[K_groups] int<lower=0> N_max;\n",
"for ( k in 1 : K_groups ) {\n",
"Y_max[k] = max(flat_ys[K_inds[k, K_starts[k] : K_stops[k]]]);\n",
"N_max[k] = max(cap[K_inds[k, K_starts[k] : K_stops[k]]]);\n",
"}")
model_file <- readLines(textConnection(model_file), n = -1)

# Update the transformed parameters block
Expand Down Expand Up @@ -251,17 +256,33 @@ add_nmix_data = function(model_data,
data.frame(shift_nas(group_mat))
}))

# A second version of K_inds is needed for later generation
# of properly-constrained latent N predictions; for this version,
# all observations must be included (no NAs)
K_inds_all <- dplyr::bind_rows(lapply(seq_len(NCOL(Z)), function(i){
factor_inds <- which(which_factor == i)
group_mat <- matrix(NA, nrow = model_data$n,
ncol = n_replicates[i])
for(j in 1:model_data$n){
group_mat[j, ] <- seq(factor_inds[j],
max(factor_inds),
by = model_data$n)
}
data.frame(group_mat)
}))

# Add starting and ending indices for each group to model_data
model_data$K_starts <- rep(1, NROW(K_inds))
model_data$K_stops <- length_reps(K_inds)

# Change any remaining NAs to 1 so they are integers
K_inds[is.na(K_inds)] <- 1

# Add reamining group information to the model_data
# Add remaining group information to the model_data
model_data$K_reps <- NCOL(K_inds)
model_data$K_groups <- NROW(K_inds)
model_data$K_inds <- as.matrix(K_inds)
model_data$K_inds_all <- as.matrix(K_inds_all)
model_data$ytimes_pred <- matrix(1:model_data$total_obs,
nrow = model_data$n,
byrow = FALSE)
Expand Down Expand Up @@ -326,12 +347,28 @@ add_nmix_model = function(model_file,
model_file <- readLines(textConnection(model_file), n = -1)

model_file[grep('flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends),',
model_file, fixed = TRUE)] <- paste0('for (k in 1:K_groups){\n',
'target += pb_lpmf(flat_ys[K_inds[k, K_starts[k]:K_stops[k]]] |\n',
'cap[K_inds[k, K_starts[k]:K_stops[k]]],\n',
'flat_trends[K_inds[k, K_starts[k]:K_stops[k]]],\n',
'flat_ps[K_inds[k, K_starts[k]:K_stops[k]]]);\n',
'}')
model_file, fixed = TRUE)] <-
paste0('// loop over replicate sampling window (each site*time*species combination)\n',
'for ( k in 1 : K_groups ) {\n',
'// all log_lambdas are identical because they represent site*time\n',
'// covariates; so just use the first measurement\n',
'real log_lambda = flat_trends[K_inds[k, 1]];\n',
'vector[N_max[k] - Y_max[k] + 1] terms;\n',
'int l = 0;\n',
'// marginalize over latent abundance\n',
'for ( Ni in Y_max[k] : N_max[k] ) {\n',
'l = l + 1;\n',
'// factor for poisson prob of latent Ni; compute\n',
'// only once per sampling window\n',
'terms[l] = poisson_log_lpmf( Ni | log_lambda ) +\n',
'// for each replicate observation, binomial prob observed is\n',
'// computed in a vectorized statement\n',
'binomial_logit_lpmf( flat_ys[K_inds[k, K_starts[k] : K_stops[k]]] |\n',
'Ni,\n',
'flat_ps[K_inds[k, K_starts[k] : K_stops[k]]] );\n',
'}\n',
'target += log_sum_exp( terms );\n',
'}')
model_file <- model_file[-grep('0.0,append_row(b, 1.0));',
model_file, fixed = TRUE)]
model_file <- readLines(textConnection(model_file), n = -1)
Expand Down Expand Up @@ -365,72 +402,7 @@ add_nmix_functions = function(model_file,
trend_map,
nmix_trendmap){
if(nmix_trendmap){
# If trend_map supplied, we need array versions of nmixture functions
if(any(grepl('functions {', model_file, fixed = TRUE))){
model_file[grep('functions {', model_file, fixed = TRUE)] <-
paste0('functions {\n',
'/* Functions to return the log probability of a Poisson Binomial Mixture */\n',
'/* see Bollen et al 2023 for details (https://doi.org/10.1002/ece3.10595)*/\n',
'real poisbin_lpmf(array[] int count, int k, array[] real lambda, array[] real p) {\n',
'if (max(count) > k) {\n',
'return negative_infinity();\n',
'}\n',
'return poisson_log_lpmf(k | lambda) + binomial_logit_lupmf(count | k, p);\n',
'}\n',
'vector pb_logp(array[] int count, int max_k,\n',
'array[] real lambda, array[] real p) {\n',
'int c_max = max(count);\n',
'if (max_k < c_max)\n',
'reject("cap variable max_k must be >= observed counts");\n',
'vector[max_k + 1] lp;\n',
'for (k in 0:(c_max - 1))\n',
'lp[k + 1] = negative_infinity();\n',
'for (k in c_max:max_k)\n',
'lp[k + 1] = poisbin_lpmf(count | k, lambda, p);\n',
'return lp;\n',
'}\n',
'real pb_lpmf(array[] int count, array[] int max_k,\n',
'array[] real lambda, array[] real p) {\n',
'// Take maximum of all supplied caps, in case they vary for some reason\n',
'int max_k_max = max(max_k);\n',
'vector[max_k_max + 1] lp;\n',
'lp = pb_logp(count, max_k_max, lambda, p);\n',
'return log_sum_exp(lp);\n',
'}\n')
} else {
model_file[grep('Stan model code', model_file)] <-
paste0('// Stan model code generated by package mvgam\n',
'functions {\n',
'/* Functions to return the log probability of a Poisson Binomial Mixture */\n',
'/* see Bollen et al 2023 for details (https://doi.org/10.1002/ece3.10595)*/\n',
'real poisbin_lpmf(array[] int count, int k, array[] real lambda, array[] real p) {\n',
'if (max(count) > k) {\n',
'return negative_infinity();\n',
'}\n',
'return poisson_log_lpmf(k | lambda) + binomial_logit_lupmf(count | k, p);\n',
'}\n',
'vector pb_logp(array[] int count, int max_k,\n',
'array[] real lambda, array[] real p) {\n',
'int c_max = max(count);\n',
'if (max_k < c_max)\n',
'reject("cap variable max_k must be >= observed counts");\n',
'vector[max_k + 1] lp;\n',
'for (k in 0:(c_max - 1))\n',
'lp[k + 1] = negative_infinity();\n',
'for (k in c_max:max_k)\n',
'lp[k + 1] = poisbin_lpmf(count | k, lambda, p);\n',
'return lp;\n',
'}\n',
'real pb_lpmf(array[] int count, array[] int max_k,\n',
'array[] real lambda, array[] real p) {\n',
'// Take maximum of all supplied caps, in case they vary for some reason\n',
'int max_k_max = max(max_k);\n',
'vector[max_k_max + 1] lp;\n',
'lp = pb_logp(count, max_k_max, lambda, p);\n',
'return log_sum_exp(lp);\n',
'}\n',
'}\n')
}
# If trend_map supplied, no modifications needed
} else {
if(any(grepl('functions {', model_file, fixed = TRUE))){
model_file[grep('functions {', model_file, fixed = TRUE)] <-
Expand Down
27 changes: 26 additions & 1 deletion R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,31 @@ beta_shapes = function(mu, phi) {
shape2 = (1 - mu) * phi))
}

# Calculate all possible Poisson log-densities for N-mixture simulation
#' @noRd
pois_dens = function(min_cap, max_cap, lambdas){
# Identify which indices share the exact same lambda AND
# k value so that we only need to run dpois once for each group
data.frame(lambdas,
min_cap,
max_cap) %>%
dplyr::group_by(lambdas) %>%
dplyr::summarise(min_cap = min(min_cap, na.rm = TRUE),
max_cap = max(max_cap, na.rm = TRUE)) -> group_inds

l <- mapply(`:`, group_inds$min_cap, group_inds$max_cap)

data.frame(k = unlist(l),
lambda = group_inds$lambdas[rep(1:nrow(group_inds),
lengths(l))]) %>%
dplyr::mutate(pois_dens = dpois(k,
lambda,
log = TRUE)) -> all_ks

return(all_ks)
}


#' Generic prediction function
#' @importFrom stats predict
#' @param Xp A `mgcv` linear predictor matrix
Expand Down Expand Up @@ -322,7 +347,7 @@ mvgam_predict = function(Xp,
lik <- exp(dbinom(truth[[i]], size = ks,
prob = p[[i]], log = TRUE) +
dpois(x = ks,
lambda = lambdas[[i]], log = TRUE))
lambda = lambdas[i], log = TRUE))
probs <- lik / sum(lik)
probs[!is.finite(probs)] <- 0
output <- ks[wrswoR::sample_int_ccrank(length(ks),
Expand Down
17 changes: 6 additions & 11 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -1720,8 +1720,8 @@ mvgam = function(formula,
# to produce these in R after sampling has finished
param <- c(param, 'p')
param <- param[!param %in% c('ypred', 'mus', 'theta',
'Sigma', 'detprob', 'latent_ypred',
'lv_coefs', 'sigma', 'error')]
'detprob', 'latent_ypred',
'lv_coefs', 'error')]
}

# Tidy the representation
Expand Down Expand Up @@ -2114,11 +2114,6 @@ mvgam = function(formula,

stan_control <- list(max_treedepth = max_treedepth,
adapt_delta = adapt_delta)
if(save_all_pars){
pars <- NA
} else {
pars <- param
}

if(algorithm == 'sampling'){
if(parallel){
Expand All @@ -2132,7 +2127,7 @@ mvgam = function(formula,
verbose = FALSE,
thin = thin,
control = stan_control,
pars = pars,
pars = NA,
refresh = 100,
...)
} else {
Expand All @@ -2146,7 +2141,7 @@ mvgam = function(formula,
verbose = FALSE,
thin = thin,
control = stan_control,
pars = pars,
pars = NA,
refresh = 100,
...)
}
Expand All @@ -2158,7 +2153,7 @@ mvgam = function(formula,
output_samples = samples,
data = model_data,
algorithm = algorithm,
pars = pars,
pars = NA,
...)
}

Expand Down Expand Up @@ -2240,7 +2235,7 @@ mvgam = function(formula,
mgcv_model = ss_gam,
Z = model_data$Z,
n_lv = n_lv,
K_inds = model_data$K_inds)
K_inds = model_data$K_inds_all)
}

# Get Dunn-Smyth Residual distributions for each series if this
Expand Down
Loading

0 comments on commit 8df3963

Please sign in to comment.