From 7a77ab4ecc9a697f34aa0f09920e758c75e12d15 Mon Sep 17 00:00:00 2001 From: "Joshua F. Wiley" Date: Sun, 24 Mar 2024 09:32:13 +1100 Subject: [PATCH] added functions to plot random effects from brmsfit object --- DESCRIPTION | 7 +- R/ranef.r | 315 +++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 241 insertions(+), 81 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 41ec4f3..9061fb4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: multilevelTools Title: Multilevel and Mixed Effects Model Diagnostics and Effect Sizes -Version: 0.1.3 -Date: 2023-02-15 +Version: 0.2.0 +Date: 2024-03-24 Authors@R: person(given = "Joshua F.", family = "Wiley", @@ -27,7 +27,8 @@ Imports: cowplot, ggplot2, lavaan, - zoo + zoo, + brms Suggests: testthat (>= 3.1.0), covr, diff --git a/R/ranef.r b/R/ranef.r index c6ea6d3..edce2c3 100644 --- a/R/ranef.r +++ b/R/ranef.r @@ -1,3 +1,231 @@ +#' Convert ranef() output to long format +#' @rdname ranef2long +#' +#' @param x A \code{brmsfit} object +#' @param idvar A character string specifying the grouping variable name for the +#' random effects. +#' @importFrom nlme ranef +#' @return A data.table object with the random effects in long format. +#' @keywords internal +ranef2long <- function(x, idvar) { + x <- ranef(x, summary = FALSE)[[idvar]] + for (i in seq_along(dimnames(x)[[3]])) { + if (i == 1) { + out <- .re.data(x, i, idvar) + } else { + tmp <- .re.data(x, i, idvar) + out <- merge(out, tmp, by = c(idvar, "iteration")) + } + } + return(out) +} + +#' @rdname ranef2long +#' @param d A \code{ranef} object +#' @param i an integer, which random effect to pull out +.re.data <- function(d, i, idvar) { + xw <- as.data.table(t(d[, , i])) + xw[, (idvar) := dimnames(d)[[2]]] + xlong <- melt(xw, id.vars = idvar, + value.name = dimnames(d)[[3]][i], + variable.name = "iteration") + xlong[, iteration := as.integer(iteration)] + return(xlong) +} + +#' @name Convert re.data() output to a dataset for plotting +#' +#' @param data A data.table object, long format random effects +#' @param var A character string, the name of the random effect to plot +#' @return A data.table object with the random effect in a format suitable for plotting +#' @keywords internal +.make.replotdat <- function(data, var) { + out <- data[, .( + Estimate = mean(get(var)), + Q2.5 = quantile(get(var), .025), + Q97.5 = quantile(get(var), .975)), + by = ID][order(Estimate)] + out[, ID := factor(ID, levels = ID)] + return(out) +} + +#' Create data and plots for brms random effect models +#' +#' @param object a \code{brmsfit} objectx +#' @param usevars a character vector of random effects to plot +#' @param newdata a data.table object with the data used to generate the random effects, this is used as an anchor for the random intercepts so they have a meaningful 0 point +#' @param idvar a character string specifying the grouping variable name for the random effects +#' @return a list with the following components: +#' \itemize{ +#' \item{plot}{a list of ggplot objects} +#' \item{plotdat}{a list of data.table objects with the data used to generate the plots} +#' \item{relong}{a data.table object with the random effects in long format} +#' \item{yhat}{a list of data.table objects with the expected values for the random effects} +#' \item{usevars}{a character vector of the random effects to plot} +#' \item{idvar}{a character string specifying the grouping variable name for the random effects} +#' } +#' @importFrom nlme ranef fixef +#' @importFrom brms posterior_epred +#' @importFrom testthat expect_true +#' @importFrom extraoperators %ain% +#' @importFrom ggplot2 ggplot aes annotate geom_hex geom_pointrange geom_hline ggtitle +#' @importFrom ggplot2 stat_smooth scale_fill_continuous theme_classic theme axis.text.y coord_flip +#' @export +ranefdata <- function(object, usevars, newdata, idvar) { + intercept <- grepl("Intercept", usevars) + + dpars <- family(object)$dpars + ## only count some as "sigma" params if sigma is in dpars for the model + ## family, this is to prevent catching variables named sigma + if (isTRUE("sigma" %in% dpars)) { + sigma <- grepl("^sigma.*$", usevars) + } else { + sigma <- rep(FALSE, length(usevars)) + } + + fes <- fixef(object) + fes <- cbind(vars = rownames(fes), as.data.table(fes)) + setkey(fes, vars) + + expect_true(usevars[!intercept] %ain% fes[, vars]) + + res <- ranef(object, summary = FALSE)[[idvar]] + relong <- ranef2long(object, idvar) + + expect_true(usevars %ain% names(relong)) ## all usevars in the randome effects + + yhat <- vector("list", length(usevars)) + names(yhat) <- usevars + + for (i in seq_along(usevars)) { + if (isTRUE(intercept[i])) { + if (isTRUE(sigma[i])) { + yhat[[i]] <- as.data.table(posterior_summary(posterior_epred( + object, + newdata = newdata, re_formula = NA, + dpar = "sigma" + ))) + relong[, (usevars[i]) := exp( + get(usevars[i]) + + log(yhat[[i]][1, Estimate]) + )] + } else if (isFALSE(sigma[i])) { + yhat[[i]] <- as.data.table(posterior_summary(posterior_epred( + object, + newdata = newdata, re_formula = NA, + dpar = NULL + ))) + relong[, (usevars[i]) := ( + get(usevars[i]) + + (yhat[[i]][1, Estimate]))] + } + } else { + yhat[[i]] <- fes[vars == usevars[i]] + yhat[[i]][, vars := NULL] + relong[, (usevars[i]) := ( + get(usevars[i]) + + (yhat[[i]][1, Estimate]))] + } + } + + plot <- plotdat <- vector("list", length(usevars)) + names(plot) <- names(plotdat) <- usevars + + for (i in seq_along(usevars)) { + plotdat[[i]] <- .make.replotdat(relong, usevars[i]) + tmpplot <- ggplot(plotdat[[i]], aes(ID, Estimate, ymin = Q2.5, ymax = Q97.5)) + + annotate("rect", + xmin = -Inf, xmax = Inf, + ymin = yhat[[i]][1, Q2.5], ymax = yhat[[i]][1, Q97.5], + fill = "grey80" + ) + + geom_hline(yintercept = yhat[[i]][1, Estimate], linetype = "dashed") + + geom_pointrange() + + if (sigma[i] & intercept[i]) { + tmpplot <- tmpplot + + scale_y_continuous(trans = log_trans(), + breaks = breaks_log(n = 10, base = exp(1)), + labels = label_math(e^.x, format = log)) + } + tmpplot <- tmpplot + + xlab("ID") + + ylab(usevars[i]) + + theme_classic() + + theme(axis.text.y = element_blank()) + + coord_flip() + + plot[[i]] <- tmpplot + } + + + ## make scatter plots for the random effects + ## we only want these for all unique pairs of random effects + + ## make all pairs of usevars + tmp <- as.data.table(expand.grid( + A = usevars, B = usevars, + stringsAsFactors = FALSE)) + + ## remove any pairs that are duplicated or only reversed + tmp <- tmp[duplicated(apply(tmp, 1, function(x) { + paste(sort(x), collapse = "") + }))][A != B] + + tmp[, interceptA := grepl("Intercept", A)] + tmp[, interceptB := grepl("Intercept", B)] + + ## only count some as "sigma" params if sigma is in dpars for the model + ## family, this is to prevent catching variables named sigma + if (isTRUE("sigma" %in% dpars)) { + tmp[, sigmaA := grepl("^sigma.*$", A)] + tmp[, sigmaB := grepl("^sigma.*$", B)] + } else { + tmp[, sigmaA := rep(FALSE, .N)] + tmp[, sigmaB := rep(FALSE, .N)] + } + + scatterplot <- vector("list", nrow(tmp)) + + for (i in seq_len(nrow(tmp))) { + tmpdat <- copy(relong[, .(x = get(tmp[i, A]), y = get(tmp[i, B]))]) + + tmpplot <- ggplot(tmpdat, aes(x = x, y = y)) + + geom_hex(show.legend = FALSE) + + stat_smooth(method = "lm", formula = y ~ x, se = FALSE, + colour = "white", linewidth = 2) + + scale_fill_continuous(type = "viridis") + + theme_classic() + + xlab(tmp[i, A]) + ylab(tmp[i, B]) + + if (tmp[i, interceptA] & tmp[i, sigmaA]) { + tmpplot <- tmpplot + + scale_x_continuous(trans = log_trans(), + breaks = breaks_log(n = 10, base = exp(1)), + labels = label_math(e^.x, format = log)) + } + if (tmp[i, interceptB] & tmp[i, sigmaB]) { + tmpplot <- tmpplot + + scale_y_continuous(trans = log_trans(), + breaks = breaks_log(n = 10, base = exp(1)), + labels = label_math(e^.x, format = log)) + } + scatterplot[[i]] <- tmpplot + } + + out <- list( + replots = plot, + scatterplots = scatterplot, + replotdat = plotdat, + relong = relong, + yhat = yhat, + usevars = usevars, + idvar = idvar + ) + return(out) +} + + if (FALSE) { library(knitr) library(data.table) @@ -7,6 +235,7 @@ library(mice) library(miceadds) library(ggplot2) library(bayesplot) +library(testthat) dmixed <- withr::with_seed( seed = 12345, code = { @@ -34,6 +263,7 @@ dmixed <- withr::with_seed( copy(dmixed) }) + ls.me <- brm(bf( y ~ 1 + x + (1 + x | p | ID), sigma ~ 1 + x + (1 + x | p | ID)), @@ -42,84 +272,13 @@ ls.me <- brm(bf( silent = 2, refresh = 0, iter = 4000, warmup = 1000, thin = 3, chains = 4L, cores = 4L, backend = "cmdstanr") -.re.data <- function(x, i) { - xw <- as.data.table(t(x[, , i])) - xw[, ID := dimnames(x)[[2]]] - xlong <- melt(xw, id.vars = "ID", - value.name = dimnames(x)[[3]][i], - variable.name = ".imp") - xlong[, .imp := as.integer(.imp)] - return(xlong) -} -re.data <- function(x) { - for (i in seq_along(dimnames(x)[[3]])) { - if (i == 1) { - out <- .re.data(x, i) - } else { - tmp <- .re.data(x, i) - out <- merge(out, tmp, by = c("ID", ".imp")) - } - } - return(out) -} - -x <- ranef(ls.me, summary = FALSE)$ID - -xlong <- re.data(x) - - -ggplot(xlong, aes(x = Intercept, y = x)) + - geom_hex(show.legend = FALSE) + - stat_smooth(method = "lm", formula = y ~ x, se = FALSE, colour = "white", linewidth = 2) + - scale_fill_continuous(type = "viridis") + - theme_minimal() + - xlab("Random Intercept") + ylab("Random 'x' Slope") - -i <- 3 -names(x[1, , i]) -xw <- as.data.table(t(x[, , i])) -xw[, ID := dimnames(x)[[2]]] -xlong <- melt(xw, id.vars = "ID", value.name = dimnames(x)[[3]][i], variable.name = ".imp") - -dmixed2 <- dmixed[, .(MX = sample(x, 1)), by = ID] -dmixed2[, ID := as.character(ID)] - -xlong2 <- merge(xlong, dmixed2, by = "ID") -xlong2[, .imp := as.integer(.imp) - 1] - -xlong3 <- as.mids(xlong2, .imp = ".imp", .id = "ID") - -micombine.cor(xlong3, variables = 1:2) -xlong2[, .(r = cor(Intercept, MX)), by = .imp][, atanh(mean(tanh(r)))] - -dim(t(x[, , 1])) - - - -i <- 4 -yw <- as.data.table(t(x[, , i])) -yw[, ID := dimnames(x)[[2]]] -ylong <- melt(yw, id.vars = "ID", value.name = dimnames(x)[[3]][i], variable.name = ".imp") - -ylong[, .imp := as.integer(.imp) - 1] - -ylong2 <- merge(ylong, xlong2, by = c("ID", ".imp")) - -ggplot(ylong2, aes(x = sigma_x, y = sigma_Intercept)) + - geom_hex() + - stat_smooth(method = "lm", se = FALSE, colour = "white", linewidth = 2) + - theme_minimal() - -ylong2[, .(r = cor(sigma_x, sigma_Intercept))] - -ylong2[, .(r = cor(sigma_x, sigma_Intercept)), by = ID][, tanh(mean(atanh(r)))] - -ylong2[, .(r = cor(x, Intercept)), by = ID][, mean(r)] - +out <- ranefdata( + ls.me, + newdata = data.table(ID = dmixed$ID[1], x = 0), + usevars = c("Intercept", "x", "sigma_Intercept", "sigma_x"), + idvar = "ID") -summary(ls.me) +do.call(ggarrange, c(out$replots, ncol=2,nrow=2)) +do.call(ggarrange, c(out$scatterplots, ncol=2,nrow=3)) -mcmc_hex(ls.me, pars = c("sd_ID__Intercept", "sd_ID__x")) + - xlab("Random Intercept") + - ylab("Random Slope") } \ No newline at end of file