Skip to content

Commit

Permalink
add bernoulli and binomial observation families
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Mar 19, 2024
1 parent dc09f31 commit d4d1c51
Show file tree
Hide file tree
Showing 34 changed files with 1,313 additions and 642 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ importFrom(bayesplot,color_scheme_set)
importFrom(bayesplot,log_posterior)
importFrom(bayesplot,neff_ratio)
importFrom(bayesplot,nuts_params)
importFrom(brms,bernoulli)
importFrom(brms,conditional_effects)
importFrom(brms,dstudent_t)
importFrom(brms,get_prior)
Expand Down
130 changes: 130 additions & 0 deletions R/add_binomial.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#' @noRd
add_binomial = function(formula,
model_file,
model_data,
data_train,
data_test,
family_char){


# Add trial information if necessary
if(family_char %in% c('binomial', 'beta_binomial')){
# Identify which variable in data represents the number of trials
resp_terms <- as.character(terms(formula(formula))[[2]])
resp_terms <- resp_terms[-grepl('cbind', resp_terms)]
trial_name <- resp_terms[2]

# Pull the trials variable from the data and validate
train_trials <- data_train[[trial_name]]

if(any(is.na(train_trials))){
stop(paste0('variable ', trial_name, ' contains missing values'),
call. = FALSE)
}

if(any(is.infinite(train_trials))){
stop(paste0('variable ', trial_name, ' contains infinite values'),
call. = FALSE)
}

# Matrix of trials per series
all_trials <- data.frame(series = as.numeric(data_train$series),
time = data_train$time,
trials = data_train[[trial_name]]) %>%
dplyr::arrange(time, series)

# Same for data_test
if(!is.null(data_test)){
if(!(exists(trial_name, where = data_test))) {
stop('Number of trials must also be supplied in "newdata" for Binomial models',
call. = FALSE)
}

all_trials <- rbind(all_trials,
data.frame(series = as.numeric(data_test$series),
time = data_test$time,
trials = data_test[[trial_name]])) %>%
dplyr::arrange(time, series)

if(any(is.na(all_trials$trial)) | any(is.infinite(all_trials$trial))){
stop(paste0('Missing or infinite values found in ', trial_name, ' variable'),
call. = FALSE)
}
}

# Construct matrix of N-trials in the correct format so it can be
# flattened into one long vector
trials <- matrix(NA, nrow = length(unique(all_trials$time)),
ncol = length(unique(all_trials$series)))
for(i in 1:length(unique(all_trials$series))){
trials[,i] <- all_trials$trials[which(all_trials$series == i)]
}

# Add trial info to the model data
model_data$flat_trials <- as.vector(trials)
model_data$flat_trials_train <- as.vector(trials)[which(as.vector(model_data$y_observed) == 1)]

# Add trial vectors to model block
model_file[grep("int<lower=0> flat_ys[n_nonmissing]; // flattened nonmissing observations",
model_file, fixed = TRUE)] <-
paste0("array[n_nonmissing] int<lower=0> flat_ys; // flattened nonmissing observations\n",
"array[total_obs] int<lower=0> flat_trials; // flattened trial vector\n",
"array[n_nonmissing] int<lower=0> flat_trials_train; // flattened nonmissing trial vector\n")
model_file <- readLines(textConnection(model_file), n = -1)
} else {
trials <- NULL
}

# Update model block
if(family_char == 'binomial'){
if(any(grepl("flat_ys ~ poisson_log_glm(flat_xs,",
model_file, fixed = TRUE))){
model_file[grep("flat_ys ~ poisson_log_glm(flat_xs,",
model_file, fixed = TRUE)] <-
"flat_ys ~ binomial_logit_glm(flat_trials_train, flat_xs,"
}

if(any(grepl("flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends),",
model_file, fixed = TRUE))){
model_file[grep("flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends),",
model_file, fixed = TRUE)] <-
"flat_ys ~ binomial_logit_glm(flat_trials_train, append_col(flat_xs, flat_trends),"
}
}

if(family_char == 'bernoulli'){
if(any(grepl("flat_ys ~ poisson_log_glm(flat_xs,",
model_file, fixed = TRUE))){
model_file[grep("flat_ys ~ poisson_log_glm(flat_xs,",
model_file, fixed = TRUE)] <-
"flat_ys ~ bernoulli_logit_glm(flat_xs,"
}

if(any(grepl("flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends),",
model_file, fixed = TRUE))){
model_file[grep("flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends),",
model_file, fixed = TRUE)] <-
"flat_ys ~ bernoulli_logit_glm(append_col(flat_xs, flat_trends),"
}

}

# Update the generated quantities block
if(family_char == 'binomial'){
model_file[grep("ypred[1:n, s] = poisson_log_rng(mus[1:n, s]);" ,
model_file, fixed = TRUE)] <-
"ypred[1:n, s] = binomial_rng(flat_trials[ytimes[1:n, s]], inv_logit(mus[1:n, s]));"
}

if(family_char == 'bernoulli'){
model_file[grep("ypred[1:n, s] = poisson_log_rng(mus[1:n, s]);" ,
model_file, fixed = TRUE)] <-
"ypred[1:n, s] = bernoulli_logit_rng(mus[1:n, s]);"
}

#### Return ####
return(list(model_file = model_file,
model_data = model_data,
trials = trials))

}
146 changes: 0 additions & 146 deletions R/dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,149 +127,3 @@ dynamic = function(variable, k, rho = 5, stationary = TRUE, scale = TRUE){
return(out)
}

#' Interpret the formula specified to mvgam and replace any dynamic terms
#' with the correct Gaussian Process smooth specification
#' @importFrom stats formula terms as.formula terms.formula
#' @noRd
interpret_mvgam = function(formula, N){

facs <- colnames(attr(terms.formula(formula), 'factors'))

# Check if formula has an intercept
keep_intercept <- attr(terms(formula), 'intercept') == 1

# Re-arrange so that random effects always come last
if(any(grepl('bs = \"re\"', facs, fixed = TRUE))){
newfacs <- facs[!grepl('bs = \"re\"', facs, fixed = TRUE)]
refacs <- facs[grepl('bs = \"re\"', facs, fixed = TRUE)]
int <- attr(terms.formula(formula), 'intercept')

# Preserve offset if included
if(!is.null(attr(terms(formula(formula)), 'offset'))){
newformula <- as.formula(paste(terms.formula(formula)[[2]], '~',
grep('offset', rownames(attr(terms.formula(formula), 'factors')),
value = TRUE),
'+',
paste(paste(newfacs, collapse = '+'),
'+',
paste(refacs, collapse = '+'),
collapse = '+'),
ifelse(int == 0, ' - 1', '')))

} else {
newformula <- as.formula(paste(terms.formula(formula)[[2]], '~',
paste(paste(newfacs, collapse = '+'),
'+',
paste(refacs, collapse = '+'),
collapse = '+'),
ifelse(int == 0, ' - 1', '')))
}

} else {
newformula <- formula
}

attr(newformula, '.Environment') <- attr(formula, '.Environment')

# Check if any terms use the gp wrapper, as mvgam cannot handle
# multivariate GPs yet
response <- terms.formula(newformula)[[2]]
tf <- terms.formula(newformula, specials = c("gp"))
which_gp <- attr(tf,"specials")$gp
if(length(which_gp) != 0L){
gp_details <- vector(length = length(which_gp),
mode = 'list')
for(i in seq_along(which_gp)){
gp_details[[i]] <- eval(parse(text = rownames(attr(tf,
"factors"))[which_gp[i]]))
}
if(any(unlist(lapply(purrr::map(gp_details, 'term'), length)) > 1)){
stop('mvgam cannot yet handle multidimensional gps',
call. = FALSE)
}
}

# Check if any terms use the dynamic wrapper
response <- terms.formula(newformula)[[2]]
tf <- attr(terms.formula(newformula, keep.order = TRUE),
'term.labels')
which_dynamics <- grep('dynamic(', tf, fixed = TRUE)

# Update the formula to the correct Gaussian Process implementation
if(length(which_dynamics) != 0L){
dyn_details <- vector(length = length(which_dynamics),
mode = 'list')
if(length(which_dynamics > 1)){
for(i in seq_along(which_dynamics)){
dyn_details[[i]] <- eval(parse(text = tf[which_dynamics[i]]))
}
}

# k is set based on the number of timepoints available; want to ensure
# it is large enough to capture the expected wiggliness of the latent GP
# (a smaller rho will require more basis functions for accurate approximation)
dyn_to_gpspline = function(term, N){

if(term$rho > N - 1){
stop('Argument "rho" in dynamic() cannot be larger than (max(time) - 1)',
call. = FALSE)
}

k <- term$k
if(is.null(k)){
if(N > 8){
k <- min(50, min(N, max(8, ceiling(N / (term$rho - (term$rho / 10))))))
} else {
k <- N
}
}

paste0("s(time,by=", term$term,
",bs='gp',m=c(",
ifelse(term$stationary, '-', ''),"2,",
term$rho, ",2),",
"k=", k, ")")
}

dyn_to_gphilbert = function(term, N){

k <- term$k
if(is.null(k)){
if(N > 8){
k <- min(40, min(N - 1, max(8, N - 1)))
} else {
k <- N - 1
}
}

paste0("gp(time,by=", term$term,
",c=5/4,",
"k=", k, ",scale=",
term$scale,
")")
}
# Replace dynamic terms with the correct specification
termlabs <- attr(terms(newformula, keep.order = TRUE), 'term.labels')
for(i in seq_along(which_dynamics)){
if(is.null(dyn_details[[i]]$rho)){
termlabs[which_dynamics[i]] <- dyn_to_gphilbert(dyn_details[[i]], N = N)
} else {
termlabs[which_dynamics[i]] <- dyn_to_gpspline(dyn_details[[i]], N = N)
}
}

# Return the updated formula for passing to mgcv
updated_formula <- reformulate(termlabs, rlang::f_lhs(newformula))
attr(updated_formula, '.Environment') <- attr(newformula, '.Environment')

} else {
updated_formula <- newformula
}

if(!keep_intercept){
updated_formula <- update(updated_formula, . ~ . - 1)
attr(updated_formula, '.Environment') <- attr(newformula, '.Environment')
}

return(updated_formula)
}
Loading

0 comments on commit d4d1c51

Please sign in to comment.