Skip to content

Commit

Permalink
tidy.mvgam() WIP
Browse files Browse the repository at this point in the history
draft method
  • Loading branch information
swpease committed Jan 10, 2025
1 parent bd20133 commit 6fde2fd
Showing 1 changed file with 203 additions and 0 deletions.
203 changes: 203 additions & 0 deletions R/tidier_methods.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,210 @@
#' @importFrom generics tidy
#' @export
generics::tidy

#' @importFrom generics augment
#' @export
generics::augment


#' List of observation families and any extra parameters.
#' Wrapped in a function for testing purposes.
#' @noRd
# extra_params <- function() {
# list(
# 'negative binomial' = c('phi'),
# 'beta_binomial' = c('phi'),
# 'beta' = c('phi'),
# 'tweedie' = c('phi'),
# 'gaussian' = c('sigma_obs'),
# 'student' = c('sigma_obs', 'nu'),
# 'lognormal' = c('sigma_obs'),
# 'Gamma' = c('shape'),
# 'poisson' = c(),
# 'binomial' = c(),
# 'bernoulli' = c(),
# 'nmix' = c()
# )
# }

# TODO: name cols based on `get_mvgam_priors()$param_info`?
#' @export
tidy.mvgam <- function(x, probs = c(0.025, 0.5, 0.975), ...) {
object <- x # For consistency with `summary.mvgam()`
obj_vars <- variables(object)
digits <- 2 # TODO: Let user change?
partialized_mcmc_summary <- purrr::partial(mcmc_summary,
object$model_output,
... =,
ISB = FALSE, # Matches `x[i]`'s rather than `x`.
probs = probs,
digits = digits,
Rhat = FALSE,
n.eff = FALSE)
out <- tibble::tibble()

# Observation family extra parameters

# extra_params <- extra_params()
# for (xp in extra_params[[object$family]]) {
# extra_params_out <- mcmc_summary(object$model_output,
# params = xp,
# digits = digits,
# variational = variational)
# out <- dplyr::bind_rows(out, extra_params_out)
# }

# Alt implementation of Observation family extra parameters
xp_names_all <- obj_vars$observation_pars$orig_name
xp_names <- grep("vec", xp_names_all, value = TRUE, invert = TRUE)
if (!is.null(xp_names)) {
extra_params_out <- partialized_mcmc_summary(params = xp_names)
extra_params_out <- tibble::add_column(extra_params_out,
param_type = "obs_fam_extra_param",
.before = 1)
out <- dplyr::bind_rows(out, extra_params_out)
}
# END Alt implementation
# END Observation family extra parameters


# obs non-smoother betas
if (object$mgcv_model$nsdf > 0) {
obs_beta_name_map <- dplyr::slice_head(obj_vars$observation_betas, n = object$mgcv_model$nsdf) # df("orig_name", "alias")
obs_betas_out <- partialized_mcmc_summary(params = obs_beta_name_map$orig_name)
row.names(obs_betas_out) <- obs_beta_name_map$alias
obs_betas_out <- tibble::add_column(obs_betas_out,
param_type = "observation_beta",
.before = 1)
out <- dplyr::bind_rows(out, obs_betas_out)
}
# END obs non-smoother betas


# random effects
# TODO: include specific s(re).[n] intercepts?
# TODO: random slopes' names? obj$mgcv_model$smooth$label?
re_param_name_map <- obj_vars$observation_re_params
if (!is.null(re_param_name_map)) {
re_params_out <- partialized_mcmc_summary(params = re_param_name_map$orig_name)
row.names(re_params_out) <- re_param_name_map$alias
re_params_out <- tibble::add_column(re_params_out,
param_type = "random effect (group-level)",
.before = 1)
out <- dplyr::bind_rows(out, re_params_out)
}
# END random effects -----------

# GPs
if (!is.null(obj_vars$trend_pars)) {
tm_param_names_all <- obj_vars$trend_pars$orig_name
gp_param_names <- grep("^alpha_gp|^rho_gp", tm_param_names_all, value = TRUE)
if (length(gp_param_names) > 0) {
gp_params_out <- partialized_mcmc_summary(params = gp_param_names)
# where is GP? can be in formula, trend_formula, or trend_model
if (grepl("^(alpha|rho)_gp_trend", gp_param_names[[1]])) {
param_type = "trend_formula_param"
} else if (grepl("^(alpha|rho)_gp_", gp_param_names[[1]])) { # hmph.
param_type = "observation_param"
} else {
param_type = "trend_model_param"
}
gp_params_out <- tibble::add_column(gp_params_out,
param_type = param_type,
.before = 1)
out <- dplyr::bind_rows(out, gp_params_out)
}
}
# END GPs --------------

# RW, AR, CAR, VAR
# TODO: split out Sigma for heircor?
# str vs called obj as arg to mvgam
# TODO: move trend_model_name up?
trend_model_name <- ifelse(inherits(object$trend_model, "mvgam_trend"),
object$trend_model$trend_model,
object$trend_model)
if (grepl("^VAR|^CAR|^AR|^RW|^ZMVN", trend_model_name)) {
# theta = MA terms
# alpha_cor = heirarchical corr term
# A = VAR auto-regressive matrix
# Sigma = correlated errors matrix
# sigma = errors

# setting up the params to extract
if (trend_model_name == "VAR") {
trend_model_params <- c("^A\\[", "^alpha_cor", "^theta", "^Sigma")
} else if (grepl("^CAR|^AR|^RW", trend_model_name)) {
cor = inherits(object$trend_model, "mvgam_trend") && object$trend_model$cor
sigma_name <- ifelse(cor, "^Sigma", "^sigma")
trend_model_params <- c("^ar", "^alpha_cor", "^theta", sigma_name)
} else if (grepl("^ZMVN", trend_model_name)) {
trend_model_params <- c("^alpha_cor", "^Sigma")
}

# extracting the params
trend_model_params <- paste(trend_model_params, collapse = "|")
tm_param_names_all <- obj_vars$trend_pars$orig_name
tm_param_names <- grep(trend_model_params, tm_param_names_all, value = TRUE)
tm_params_out <- partialized_mcmc_summary(params = tm_param_names)
tm_params_out <- tibble::add_column(tm_params_out,
param_type = "trend_model_param",
.before = 1)
out <- dplyr::bind_rows(out, tm_params_out)
}
# END RW, AR, CAR, VAR-----------

# Piecewise
# TODO: potentially lump into AR section, above; how to handle change points?
# to lump in, just add an
# `else if (grepl("^PW", trend_model_name)`, then
# `trend_model_params <- c("^k_trend", "^m_trend", "^delta_trend")`
# and change initial grep(ar car var) call
if (grepl("^PW", trend_model_name)) {
trend_model_params <- "^k_trend|^m_trend|^delta_trend"
tm_param_names_all <- obj_vars$trend_pars$orig_name
tm_param_names <- grep(trend_model_params, tm_param_names_all, value = TRUE)
tm_params_out <- partialized_mcmc_summary(params = tm_param_names)
tm_params_out <- tibble::add_column(tm_params_out,
param_type = "trend_model_param",
.before = 1)
out <- dplyr::bind_rows(out, tm_params_out)
}
# END Piecewise ------------

# Trend formula betas
if (!is.null(object$trend_call) && object$trend_mgcv_model$nsdf > 0) {
trend_beta_name_map <- dplyr::slice_head(obj_vars$trend_betas,
n = object$trend_mgcv_model$nsdf) # df("orig_name", "alias")
trend_betas_out <- partialized_mcmc_summary(params = trend_beta_name_map$orig_name)
row.names(trend_betas_out) <- trend_beta_name_map$alias
trend_betas_out <- tibble::add_column(trend_betas_out,
param_type = "trend_beta",
.before = 1)
out <- dplyr::bind_rows(out, trend_betas_out)
}
# END Trend formula betas ----------

# trend random effects
# TODO: include specific s(re).[n] intercepts?
trend_re_param_name_map <- obj_vars$trend_re_params
if (!is.null(trend_re_param_name_map)) {
trend_re_params_out <- partialized_mcmc_summary(params = trend_re_param_name_map$orig_name)
row.names(trend_re_params_out) <- trend_re_param_name_map$alias
trend_re_params_out <- tibble::add_column(trend_re_params_out,
param_type = "trend random effect (group-level)",
.before = 1)
out <- dplyr::bind_rows(out, trend_re_params_out)
}
# END tremd random effects -----------

# OUTPUT
# TODO: might need to put this prior to every bind_rows to avoid rowname dups.
out <- tibble::rownames_to_column(out, "parameter")
out
}


#' Augment an mvgam object's data
#'
#' Add fits and residuals to the data, implementing the generic `augment` from
Expand Down

0 comments on commit 6fde2fd

Please sign in to comment.