Skip to content

Commit

Permalink
update to automatically generate nmix predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Feb 10, 2024
1 parent 8df3963 commit 9144ffe
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 11 deletions.
67 changes: 56 additions & 11 deletions R/add_nmixture.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ add_nmixture = function(model_file,
}

# Update transformed data block
if(nmix_trendmap){
model_file[grep("transformed data {", model_file, fixed = TRUE)] <-
paste0("transformed data {\n",
"matrix[total_obs, num_basis] X_ordered = X[ytimes_array, : ];\n",
Expand All @@ -113,6 +114,11 @@ add_nmixture = function(model_file,
"Y_max[k] = max(flat_ys[K_inds[k, K_starts[k] : K_stops[k]]]);\n",
"N_max[k] = max(cap[K_inds[k, K_starts[k] : K_stops[k]]]);\n",
"}")
} else {
model_file[grep("transformed data {", model_file, fixed = TRUE)] <-
paste0("transformed data {\n",
"matrix[total_obs, num_basis] X_ordered = X[ytimes_array, : ];")
}
model_file <- readLines(textConnection(model_file), n = -1)

# Update the transformed parameters block
Expand Down Expand Up @@ -471,7 +477,10 @@ add_nmix_functions = function(model_file,
#' Function to add generated quantities for nmixture models, which
#' saves huge computational time
#' @noRd
add_nmix_posterior = function(model_output, obs_data, mgcv_model,
add_nmix_posterior = function(model_output,
obs_data,
test_data,
mgcv_model,
n_lv, Z, K_inds){

# Function to add samples to the 'sim' slot of a stanfit object
Expand Down Expand Up @@ -504,24 +513,57 @@ add_nmix_posterior = function(model_output, obs_data, mgcv_model,
# Construct latent_ypred samples
Xp <- obs_Xp_matrix(newdata = obs_data,
mgcv_model = mgcv_model)
if(!is.null(test_data)){
offset_obs <- attr(Xp, 'model.offset')
Xp_test <- obs_Xp_matrix(newdata = test_data,
mgcv_model = mgcv_model)
offset_test <- attr(Xp_test, 'model.offset')
Xp <- rbind(Xp, Xp_test)
attr(Xp, 'model.offset') <- c(offset_obs,
offset_test)
}

betas <- mcmc_chains(model_output, 'b')
all_linpreds <- as.matrix(as.vector(t(apply(as.matrix(betas), 1,
function(row) Xp %*% row +
attr(Xp, 'model.offset')))))
attr(all_linpreds, 'model.offset') <- 0
cap <- data.frame(time = obs_data$time,
series = obs_data$series,
cap = obs_data$cap) %>%
dplyr::arrange(series, time) %>%
dplyr::pull(cap)
if(!is.null(test_data)){
cap <- rbind(data.frame(time = obs_data$time,
series = obs_data$series,
cap = obs_data$cap),
data.frame(time = test_data$time,
series = test_data$series,
cap = test_data$cap))%>%
dplyr::arrange(series, time) %>%
dplyr::pull(cap)
} else {
cap <- data.frame(time = obs_data$time,
series = obs_data$series,
cap = obs_data$cap) %>%
dplyr::arrange(series, time) %>%
dplyr::pull(cap)
}

cap <- as.vector(t(replicate(NROW(betas), cap)))

# Unconditional latent_N predictions
truth <- data.frame(time = obs_data$time,
series = obs_data$series,
y = obs_data$y) %>%
dplyr::arrange(series, time) %>%
dplyr::pull(y)
if(!is.null(test_data)){
truth <- rbind(data.frame(time = obs_data$time,
series = obs_data$series,
y = obs_data$y),
data.frame(time = test_data$time,
series = test_data$series,
y = test_data$y)) %>%
dplyr::arrange(series, time) %>%
dplyr::pull(y)
} else {
truth <- data.frame(time = obs_data$time,
series = obs_data$series,
y = obs_data$y) %>%
dplyr::arrange(series, time) %>%
dplyr::pull(y)
}

get_min_cap = function(truth, K_inds){
rowgroup = function(x){
Expand All @@ -537,6 +579,9 @@ add_nmix_posterior = function(model_output, obs_data, mgcv_model,
dplyr::mutate(min_cap = max(truth, na.rm = TRUE)) %>%
dplyr::pull(min_cap)
}
if(is.null(K_inds)){
K_inds <- matrix(1:length(truth), ncol = 1)
}
min_cap <- suppressWarnings(get_min_cap(truth, K_inds))
min_cap[!is.finite(min_cap)] <- 0
truth <- as.vector(t(replicate(NROW(betas), truth)))
Expand Down
1 change: 1 addition & 0 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -2232,6 +2232,7 @@ mvgam = function(formula,
if(family_char == 'nmix'){
out_gam_mod <- add_nmix_posterior(model_output = out_gam_mod,
obs_data = data_train,
test_data = data_test,
mgcv_model = ss_gam,
Z = model_data$Z,
n_lv = n_lv,
Expand Down
Binary file modified src/mvgam.dll
Binary file not shown.
Binary file modified tests/testthat/Rplots.pdf
Binary file not shown.

0 comments on commit 9144ffe

Please sign in to comment.