From 9d375b85abc01878756832300f04eb70c74f6f88 Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 11 Oct 2023 15:07:05 +0300 Subject: [PATCH 01/13] add nested R-hat convergence diagnostic default method --- NAMESPACE | 2 + R/convergence.R | 1 + R/nested_rhat.R | 101 +++++++++++++++++++ man-roxygen/ref-margossian-nestedrhat-2023.R | 5 + man/diagnostics.Rd | 1 + man/ess_basic.Rd | 1 + man/ess_bulk.Rd | 1 + man/ess_quantile.Rd | 1 + man/ess_sd.Rd | 1 + man/ess_tail.Rd | 1 + man/mcse_mean.Rd | 1 + man/mcse_quantile.Rd | 1 + man/mcse_sd.Rd | 1 + man/rhat.Rd | 1 + man/rhat_basic.Rd | 1 + man/rstar.Rd | 1 + 16 files changed, 121 insertions(+) create mode 100644 R/nested_rhat.R create mode 100644 man-roxygen/ref-margossian-nestedrhat-2023.R diff --git a/NAMESPACE b/NAMESPACE index ace458d1..ea676bf3 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -273,6 +273,7 @@ S3method(rhat,default) S3method(rhat,rvar) S3method(rhat_basic,default) S3method(rhat_basic,rvar) +S3method(rhat_nested,default) S3method(sd,default) S3method(sd,rvar) S3method(split_chains,draws) @@ -466,6 +467,7 @@ export(reserved_variables) export(rfun) export(rhat) export(rhat_basic) +export(rhat_nested) export(rstar) export(rvar) export(rvar_all) diff --git a/R/convergence.R b/R/convergence.R index 01c7cd16..fa531193 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -23,6 +23,7 @@ #' | [mcse_sd()] | Monte Carlo standard error for standard deviations | #' | [rhat_basic()] | Basic version of Rhat | #' | [rhat()] | Improved, rank-based version of Rhat | +#' | [rhat_nested()] | Rhat for use with many short chains | #' | [rstar()] | R* diagnostic | #' #' @return diff --git a/R/nested_rhat.R b/R/nested_rhat.R new file mode 100644 index 00000000..c1470e59 --- /dev/null +++ b/R/nested_rhat.R @@ -0,0 +1,101 @@ +.add_superchain_ids <- function(draws, superchain_ids) { + + # determine size of dims + chains_per_superchain <- table(superchain_ids) + num_chains_per_superchain <- max(chains_per_superchain) + num_iterations <- dim(draws)[1] + num_superchains <- max(superchain_ids) + + # create new empty array with correct dims + new_draws <- array( + NA, + dim = c( + num_iterations, + num_chains_per_superchain, + num_superchains) + ) + + # add dim names + dimnames(new_draws) <- list( + iteration = 1:num_iterations, + chain = 1:num_chains_per_superchain, + superchain = 1:num_superchains + ) + + # assign chains to superchains + for (k in 1:num_superchains) { + chains_in_superchain <- which(superchain_ids == k) + new_draws[, , k] <- draws[, chains_in_superchain] + } + + return(new_draws) +} + +#' Nested Rhat convergence diagnostic +#' +#' Compute the Nested Rhat convergence diagnostic for a single variable +#' proposed in Margossian et al. (2023). +#' +#' @family diagnostics +#' @template args-conv +#' @param superchain_ids (numeric) Vector of length nchains specifying +#' which superchain each chain belongs to +#' @template args-methods-dots +#' @template return-conv +#' @template ref-margossian-nestedrhat-2023 +#' +#' @examples +#' mu <- extract_variable_matrix(example_draws(), "mu") +#' rhat_nested(mu, superchain_ids = c(1,1,2,2)) +#' +#' d <- as_draws_rvars(example_draws("multi_normal")) +#' rhat(d$Sigma, superchain_ids = c(1,1,2,2)) +#' +#' @export +rhat_nested <- function(x, superchain_ids, ...) UseMethod("rhat_nested") + +#' @rdname rhat_nested +#' @export +rhat_nested.default <- function(x, superchain_ids, ...) { + + x <- .add_superchain_ids(x, superchain_ids) + .rhat_nested(x) +} + +.rhat_nested <- function(x, ...) { + + array_dims <- dim(x) + ndraws <- array_dims[1] + nchains <- array_dims[2] + nsuperchains <- array_dims[3] + + superchain_mean <- apply(x, 3, mean) + chain_mean <- apply(x, c(2, 3), mean) + chain_var <- apply(x, c(2, 3), var) + + overall_mean <- mean(superchain_mean) + + if (nchains == 1) { + var_between_chain <- 0 + } else { + var_between_chain <- matrixStats::colVars( + chain_mean, + center = superchain_mean + ) + } + + if (ndraws == 1) { + var_within_chain <- 0 + } else { + var_within_chain <- colMeans(chain_var) + } + + var_between_superchain <- matrixStats::colVars( + as.matrix(superchain_mean), + center = overall_mean + ) + + var_within_superchain <- mean(var_within_chain + var_between_chain) + + sqrt(1 + var_between_superchain / var_within_superchain) +} diff --git a/man-roxygen/ref-margossian-nestedrhat-2023.R b/man-roxygen/ref-margossian-nestedrhat-2023.R new file mode 100644 index 00000000..9cdeed04 --- /dev/null +++ b/man-roxygen/ref-margossian-nestedrhat-2023.R @@ -0,0 +1,5 @@ +#' @references +#' Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel +#' Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat: +#' Assessing the convergence of Markov chain Monte Carlo when running +#' many short chains. arxiv:arXiv:2110.13017 diff --git a/man/diagnostics.Rd b/man/diagnostics.Rd index 2d5d7e31..43d54186 100644 --- a/man/diagnostics.Rd +++ b/man/diagnostics.Rd @@ -23,6 +23,7 @@ A list of available diagnostics and links to their individual help pages. \code{\link[=mcse_sd]{mcse_sd()}} \tab Monte Carlo standard error for standard deviations \cr \code{\link[=rhat_basic]{rhat_basic()}} \tab Basic version of Rhat \cr \code{\link[=rhat]{rhat()}} \tab Improved, rank-based version of Rhat \cr + \code{\link[=rhat_nested]{rhat_nested()}} \tab Rhat for use with many short chains \cr \code{\link[=rstar]{rstar()}} \tab R* diagnostic \cr } } diff --git a/man/ess_basic.Rd b/man/ess_basic.Rd index 548623b7..e300ad5e 100755 --- a/man/ess_basic.Rd +++ b/man/ess_basic.Rd @@ -80,6 +80,7 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} } diff --git a/man/ess_bulk.Rd b/man/ess_bulk.Rd index c518f97f..adf3faf8 100755 --- a/man/ess_bulk.Rd +++ b/man/ess_bulk.Rd @@ -73,6 +73,7 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} } diff --git a/man/ess_quantile.Rd b/man/ess_quantile.Rd index 704db402..6bfc3cdf 100755 --- a/man/ess_quantile.Rd +++ b/man/ess_quantile.Rd @@ -82,6 +82,7 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} } diff --git a/man/ess_sd.Rd b/man/ess_sd.Rd index fdff961c..2344211a 100755 --- a/man/ess_sd.Rd +++ b/man/ess_sd.Rd @@ -67,6 +67,7 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} } diff --git a/man/ess_tail.Rd b/man/ess_tail.Rd index 691bc7e7..f211f7aa 100755 --- a/man/ess_tail.Rd +++ b/man/ess_tail.Rd @@ -73,6 +73,7 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} } diff --git a/man/mcse_mean.Rd b/man/mcse_mean.Rd index da52f9d9..9afaa7b7 100755 --- a/man/mcse_mean.Rd +++ b/man/mcse_mean.Rd @@ -64,6 +64,7 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} } diff --git a/man/mcse_quantile.Rd b/man/mcse_quantile.Rd index ad50e233..cc4f9685 100755 --- a/man/mcse_quantile.Rd +++ b/man/mcse_quantile.Rd @@ -79,6 +79,7 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_sd}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} } diff --git a/man/mcse_sd.Rd b/man/mcse_sd.Rd index 571b50ab..7e322864 100755 --- a/man/mcse_sd.Rd +++ b/man/mcse_sd.Rd @@ -69,6 +69,7 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} } diff --git a/man/rhat.Rd b/man/rhat.Rd index 39b9286d..fed2c14a 100755 --- a/man/rhat.Rd +++ b/man/rhat.Rd @@ -68,6 +68,7 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rstar}()} } \concept{diagnostics} diff --git a/man/rhat_basic.Rd b/man/rhat_basic.Rd index 68dac7a3..8a94efb3 100755 --- a/man/rhat_basic.Rd +++ b/man/rhat_basic.Rd @@ -75,6 +75,7 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} } diff --git a/man/rstar.Rd b/man/rstar.Rd index 97ec5aac..87e8e372 100644 --- a/man/rstar.Rd +++ b/man/rstar.Rd @@ -116,6 +116,7 @@ Other diagnostics: \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, \code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, \code{\link{rhat}()} } \concept{diagnostics} From df28355f364215472b9f4ce55f99cb7404c1e85e Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 11 Oct 2023 15:38:30 +0300 Subject: [PATCH 02/13] add nested rhat documentation --- man/rhat_nested.Rd | 72 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 man/rhat_nested.Rd diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd new file mode 100644 index 00000000..875f19fa --- /dev/null +++ b/man/rhat_nested.Rd @@ -0,0 +1,72 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/nested_rhat.R +\name{rhat_nested} +\alias{rhat_nested} +\alias{rhat_nested.default} +\title{Nested Rhat convergence diagnostic} +\usage{ +rhat_nested(x, superchain_ids, ...) + +\method{rhat_nested}{default}(x, superchain_ids, ...) +} +\arguments{ +\item{x}{(multiple options) One of: +\itemize{ +\item A matrix of draws for a single variable (iterations x chains). See +\code{\link[=extract_variable_matrix]{extract_variable_matrix()}}. +\item An \code{\link{rvar}}. +}} + +\item{superchain_ids}{(numeric) Vector of length nchains specifying +which superchain each chain belongs to} + +\item{...}{Arguments passed to individual methods (if applicable).} +} +\value{ +If the input is an array, returns a single numeric value. If any of the draws +is non-finite, that is, \code{NA}, \code{NaN}, \code{Inf}, or \code{-Inf}, the returned output +will be (numeric) \code{NA}. Also, if all draws within any of the chains of a +variable are the same (constant), the returned output will be (numeric) \code{NA} +as well. The reason for the latter is that, for constant draws, we cannot +distinguish between variables that are supposed to be constant (e.g., a +diagonal element of a correlation matrix is always 1) or variables that just +happened to be constant because of a failure of convergence or other problems +in the sampling process. + +If the input is an \code{\link{rvar}}, returns an array of the same dimensions as the +\code{\link{rvar}}, where each element is equal to the value that would be returned by +passing the draws array for that element of the \code{\link{rvar}} to this function. +} +\description{ +Compute the Nested Rhat convergence diagnostic for a single variable +proposed in Margossian et al. (2023). +} +\examples{ +mu <- extract_variable_matrix(example_draws(), "mu") +rhat_nested(mu, superchain_ids = c(1,1,2,2)) + +d <- as_draws_rvars(example_draws("multi_normal")) +rhat(d$Sigma, superchain_ids = c(1,1,2,2)) + +} +\references{ +Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel +Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat: +Assessing the convergence of Markov chain Monte Carlo when running +many short chains. arxiv:arXiv:2110.13017 +} +\seealso{ +Other diagnostics: +\code{\link{ess_basic}()}, +\code{\link{ess_bulk}()}, +\code{\link{ess_quantile}()}, +\code{\link{ess_sd}()}, +\code{\link{ess_tail}()}, +\code{\link{mcse_mean}()}, +\code{\link{mcse_quantile}()}, +\code{\link{mcse_sd}()}, +\code{\link{rhat_basic}()}, +\code{\link{rhat}()}, +\code{\link{rstar}()} +} +\concept{diagnostics} From e02321db4df7da4a0dff066e84613ad1cc8e4e39 Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 13 Oct 2023 12:56:59 +0300 Subject: [PATCH 03/13] improve memory efficiency of rhat_nested by not creating a new array --- R/nested_rhat.R | 69 ++++++++++++++----------------------------------- 1 file changed, 19 insertions(+), 50 deletions(-) diff --git a/R/nested_rhat.R b/R/nested_rhat.R index c1470e59..f2eebff5 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -1,36 +1,3 @@ -.add_superchain_ids <- function(draws, superchain_ids) { - - # determine size of dims - chains_per_superchain <- table(superchain_ids) - num_chains_per_superchain <- max(chains_per_superchain) - num_iterations <- dim(draws)[1] - num_superchains <- max(superchain_ids) - - # create new empty array with correct dims - new_draws <- array( - NA, - dim = c( - num_iterations, - num_chains_per_superchain, - num_superchains) - ) - - # add dim names - dimnames(new_draws) <- list( - iteration = 1:num_iterations, - chain = 1:num_chains_per_superchain, - superchain = 1:num_superchains - ) - - # assign chains to superchains - for (k in 1:num_superchains) { - chains_in_superchain <- which(superchain_ids == k) - new_draws[, , k] <- draws[, chains_in_superchain] - } - - return(new_draws) -} - #' Nested Rhat convergence diagnostic #' #' Compute the Nested Rhat convergence diagnostic for a single variable @@ -57,37 +24,39 @@ rhat_nested <- function(x, superchain_ids, ...) UseMethod("rhat_nested") #' @rdname rhat_nested #' @export rhat_nested.default <- function(x, superchain_ids, ...) { + .rhat_nested(x, superchain_ids = superchain_ids) +} - x <- .add_superchain_ids(x, superchain_ids) - .rhat_nested(x) +#' @rdname rhat_nested +#' @export +rhat_nested.rvar <- function(x, superchain_ids, ...) { + summarise_rvar_by_element_with_chains(x, rhat_nested, superchain_ids = superchain_ids, ...) } -.rhat_nested <- function(x, ...) { + +.rhat_nested <- function(x, superchain_ids, ...) { array_dims <- dim(x) ndraws <- array_dims[1] - nchains <- array_dims[2] - nsuperchains <- array_dims[3] - - superchain_mean <- apply(x, 3, mean) - chain_mean <- apply(x, c(2, 3), mean) - chain_var <- apply(x, c(2, 3), var) + nchains_per_superchain <- max(table(superchain_ids)) + nsuperchains <- length(unique(superchain_ids)) + + superchain_mean <- sapply(unique(superchain_ids), function(k) mean(x[, which(superchain_ids == k)])) + + chain_mean <- matrix(matrixStats::colMeans2(x), nrow = 1) + chain_var <- matrixStats::colVars(x, center=chain_mean) overall_mean <- mean(superchain_mean) - if (nchains == 1) { + if (nchains_per_superchain == 1) { var_between_chain <- 0 } else { - var_between_chain <- matrixStats::colVars( - chain_mean, - center = superchain_mean - ) + var_between_chain <- sapply(unique(superchain_ids), function(k) var(chain_mean[, which(superchain_ids == k)])) } - if (ndraws == 1) { var_within_chain <- 0 } else { - var_within_chain <- colMeans(chain_var) + var_within_chain <- sapply(unique(superchain_ids), function(k) mean(chain_var[which(superchain_ids == k)])) } var_between_superchain <- matrixStats::colVars( @@ -97,5 +66,5 @@ rhat_nested.default <- function(x, superchain_ids, ...) { var_within_superchain <- mean(var_within_chain + var_between_chain) - sqrt(1 + var_between_superchain / var_within_superchain) + sqrt(1 + var_between_superchain / var_within_superchain) } From 3632a2f5ecee69efee7e57735833cc4253b3f08f Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 13 Oct 2023 12:57:59 +0300 Subject: [PATCH 04/13] rhat_nested documentation --- NAMESPACE | 1 + man/rhat_nested.Rd | 3 +++ 2 files changed, 4 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index ea676bf3..2d81d65d 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -274,6 +274,7 @@ S3method(rhat,rvar) S3method(rhat_basic,default) S3method(rhat_basic,rvar) S3method(rhat_nested,default) +S3method(rhat_nested,rvar) S3method(sd,default) S3method(sd,rvar) S3method(split_chains,draws) diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd index 875f19fa..d741cd7c 100644 --- a/man/rhat_nested.Rd +++ b/man/rhat_nested.Rd @@ -3,11 +3,14 @@ \name{rhat_nested} \alias{rhat_nested} \alias{rhat_nested.default} +\alias{rhat_nested.rvar} \title{Nested Rhat convergence diagnostic} \usage{ rhat_nested(x, superchain_ids, ...) \method{rhat_nested}{default}(x, superchain_ids, ...) + +\method{rhat_nested}{rvar}(x, superchain_ids, ...) } \arguments{ \item{x}{(multiple options) One of: From 9e5a44f5de728c42bf58d690f033ce0170d95e86 Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 13 Oct 2023 14:25:12 +0300 Subject: [PATCH 05/13] cleanup nested rhat and add test --- R/nested_rhat.R | 48 +++++++++++++++++++++---------- man/rhat_nested.Rd | 2 +- tests/testthat/test-rhat_nested.R | 11 +++++++ 3 files changed, 45 insertions(+), 16 deletions(-) create mode 100644 tests/testthat/test-rhat_nested.R diff --git a/R/nested_rhat.R b/R/nested_rhat.R index f2eebff5..bfca55c6 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -16,7 +16,7 @@ #' rhat_nested(mu, superchain_ids = c(1,1,2,2)) #' #' d <- as_draws_rvars(example_draws("multi_normal")) -#' rhat(d$Sigma, superchain_ids = c(1,1,2,2)) +#' rhat_nested(d$Sigma, superchain_ids = c(1,1,2,2)) #' #' @export rhat_nested <- function(x, superchain_ids, ...) UseMethod("rhat_nested") @@ -33,38 +33,56 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) { summarise_rvar_by_element_with_chains(x, rhat_nested, superchain_ids = superchain_ids, ...) } - .rhat_nested <- function(x, superchain_ids, ...) { - array_dims <- dim(x) - ndraws <- array_dims[1] + x <- as.matrix(x) + niterations <- NROW(x) nchains_per_superchain <- max(table(superchain_ids)) - nsuperchains <- length(unique(superchain_ids)) + superchains <- unique(superchain_ids) + + # mean and variance of chains calculated as in rhat + chain_mean <- matrixStats::colMeans2(x) + chain_var <- matrixStats::colVars(x, center = chain_mean) - superchain_mean <- sapply(unique(superchain_ids), function(k) mean(x[, which(superchain_ids == k)])) + # mean of superchains calculated by only including specified chains + # (equation 15 in Margossian et al. 2023) + superchain_mean <- sapply( + superchains, + function(k) mean(x[, which(superchain_ids == k)]) + ) - chain_mean <- matrix(matrixStats::colMeans2(x), nrow = 1) - chain_var <- matrixStats::colVars(x, center=chain_mean) - + # overall mean (as defined in equation 16 in Margossian et al. 2023) overall_mean <- mean(superchain_mean) + # between-chain variance estimate (B_k in equation 18 in Margossian et al. 2023) if (nchains_per_superchain == 1) { var_between_chain <- 0 } else { - var_between_chain <- sapply(unique(superchain_ids), function(k) var(chain_mean[, which(superchain_ids == k)])) + var_between_chain <- sapply( + superchains, + function(k) var(chain_mean[which(superchain_ids == k)]) + ) } - if (ndraws == 1) { + + # within-chain variance estimate (W_k in equation 18 in Margossian et al. 2023) + if (niterations == 1) { var_within_chain <- 0 } else { - var_within_chain <- sapply(unique(superchain_ids), function(k) mean(chain_var[which(superchain_ids == k)])) + var_within_chain <- sapply( + superchains, + function(k) mean(chain_var[which(superchain_ids == k)]) + ) } - + + # between-superchain variance (nB in equation 17 in Margossian et al. 2023) var_between_superchain <- matrixStats::colVars( as.matrix(superchain_mean), center = overall_mean ) - + + # within-superchain variance (nW in equation 18 in Margossian et al. 2023) var_within_superchain <- mean(var_within_chain + var_between_chain) - sqrt(1 + var_between_superchain / var_within_superchain) + # nested Rhat (nRhat in equation 19 in Margossian et al. 2023) + sqrt(1 + var_between_superchain / var_within_superchain) } diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd index d741cd7c..5988b91a 100644 --- a/man/rhat_nested.Rd +++ b/man/rhat_nested.Rd @@ -49,7 +49,7 @@ mu <- extract_variable_matrix(example_draws(), "mu") rhat_nested(mu, superchain_ids = c(1,1,2,2)) d <- as_draws_rvars(example_draws("multi_normal")) -rhat(d$Sigma, superchain_ids = c(1,1,2,2)) +rhat_nested(d$Sigma, superchain_ids = c(1,1,2,2)) } \references{ diff --git a/tests/testthat/test-rhat_nested.R b/tests/testthat/test-rhat_nested.R new file mode 100644 index 00000000..1d71ea9e --- /dev/null +++ b/tests/testthat/test-rhat_nested.R @@ -0,0 +1,11 @@ +test_that("rhat_nested returns reasonable values", { + tau <- extract_variable_matrix(example_draws(), "tau") + + rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) + expect_true(rhat > 0.99 & rhat < 1.05) + + rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) + expect_true(rhat > 0.99 & rhat < 1.05) +}) + + From 32cd38b5e5f151cbd52a513671d093ffcc4c8c2c Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 25 Oct 2023 13:04:07 +0300 Subject: [PATCH 06/13] add input checks and tests for rhat_nested --- R/nested_rhat.R | 24 +++++++++++++++++++++++- tests/testthat/test-rhat_nested.R | 20 ++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/R/nested_rhat.R b/R/nested_rhat.R index bfca55c6..e62533ae 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -34,12 +34,34 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) { } .rhat_nested <- function(x, superchain_ids, ...) { + if (should_return_NA(x)) { + return(NA_real_) + } x <- as.matrix(x) niterations <- NROW(x) - nchains_per_superchain <- max(table(superchain_ids)) + nchains <- NCOL(x) + + + # check that all chains are assigned a superchain + if (length(superchain_ids) != nchains) { + warning_no_call("Length of superchain_ids not equal to number of chains, returning NA.") + return(NA_real_) + } + + + # check that superchains are equal length + superchain_id_table <- table(superchain_ids) + nchains_per_superchain <- max(superchain_id_table) + + if (nchains_per_superchain != min(superchain_id_table)) { + warning_no_call("Number of chains per superchain is not the same for each superchain, returning NA.") + return(NA_real_) + } + superchains <- unique(superchain_ids) + # mean and variance of chains calculated as in rhat chain_mean <- matrixStats::colMeans2(x) chain_var <- matrixStats::colVars(x, center = chain_mean) diff --git a/tests/testthat/test-rhat_nested.R b/tests/testthat/test-rhat_nested.R index 1d71ea9e..15b3a929 100644 --- a/tests/testthat/test-rhat_nested.R +++ b/tests/testthat/test-rhat_nested.R @@ -9,3 +9,23 @@ test_that("rhat_nested returns reasonable values", { }) +test_that("rhat_nested handles special cases correctly", { + set.seed(1234) + x <- c(rnorm(10), NA) + expect_true(is.na(rhat_nested(x, superchain_ids = c(1)))) + + x <- c(rnorm(10), Inf) + expect_true(is.na(rhat_nested(x, superchain_ids = c(1,2,1,2)))) + + tau <- extract_variable_matrix(example_draws(), "tau") + expect_warning( + rhat_nested(tau, superchain_ids = c(1,1,1,3)), + "Number of chains per superchain is not the same for each superchain, returning NA." + ) + + tau <- extract_variable_matrix(example_draws(), "tau") + expect_warning( + rhat_nested(tau, superchain_ids = c(1,2)), + "Length of superchain_ids not equal to number of chains, returning NA." + ) +}) From 04f30ab47afdefcfdcc18a67fabfab88d1a8e837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Wed, 25 Oct 2023 12:20:45 +0200 Subject: [PATCH 07/13] minor cleaning --- R/nested_rhat.R | 22 +++++++++++----------- man/rhat_nested.Rd | 10 +++++----- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/R/nested_rhat.R b/R/nested_rhat.R index e62533ae..72d8280d 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -13,13 +13,13 @@ #' #' @examples #' mu <- extract_variable_matrix(example_draws(), "mu") -#' rhat_nested(mu, superchain_ids = c(1,1,2,2)) +#' rhat_nested(mu, superchain_ids = c(1, 1, 2, 2)) #' #' d <- as_draws_rvars(example_draws("multi_normal")) -#' rhat_nested(d$Sigma, superchain_ids = c(1,1,2,2)) +#' rhat_nested(d$Sigma, superchain_ids = c(1, 1, 2, 2)) #' #' @export -rhat_nested <- function(x, superchain_ids, ...) UseMethod("rhat_nested") +rhat_nested <- function(x, ...) UseMethod("rhat_nested") #' @rdname rhat_nested #' @export @@ -30,7 +30,9 @@ rhat_nested.default <- function(x, superchain_ids, ...) { #' @rdname rhat_nested #' @export rhat_nested.rvar <- function(x, superchain_ids, ...) { - summarise_rvar_by_element_with_chains(x, rhat_nested, superchain_ids = superchain_ids, ...) + summarise_rvar_by_element_with_chains( + x, rhat_nested, superchain_ids = superchain_ids, ... + ) } .rhat_nested <- function(x, superchain_ids, ...) { @@ -42,26 +44,25 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) { niterations <- NROW(x) nchains <- NCOL(x) - # check that all chains are assigned a superchain if (length(superchain_ids) != nchains) { - warning_no_call("Length of superchain_ids not equal to number of chains, returning NA.") + warning_no_call("Length of superchain_ids not equal to number of chains, ", + "returning NA.") return(NA_real_) } - # check that superchains are equal length superchain_id_table <- table(superchain_ids) nchains_per_superchain <- max(superchain_id_table) if (nchains_per_superchain != min(superchain_id_table)) { - warning_no_call("Number of chains per superchain is not the same for each superchain, returning NA.") + warning_no_call("Number of chains per superchain is not the same for ", + "each superchain, returning NA.") return(NA_real_) } superchains <- unique(superchain_ids) - # mean and variance of chains calculated as in rhat chain_mean <- matrixStats::colMeans2(x) chain_var <- matrixStats::colVars(x, center = chain_mean) @@ -69,8 +70,7 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) { # mean of superchains calculated by only including specified chains # (equation 15 in Margossian et al. 2023) superchain_mean <- sapply( - superchains, - function(k) mean(x[, which(superchain_ids == k)]) + superchains, function(k) mean(x[, which(superchain_ids == k)]) ) # overall mean (as defined in equation 16 in Margossian et al. 2023) diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd index 5988b91a..c81a3cc8 100644 --- a/man/rhat_nested.Rd +++ b/man/rhat_nested.Rd @@ -6,7 +6,7 @@ \alias{rhat_nested.rvar} \title{Nested Rhat convergence diagnostic} \usage{ -rhat_nested(x, superchain_ids, ...) +rhat_nested(x, ...) \method{rhat_nested}{default}(x, superchain_ids, ...) @@ -20,10 +20,10 @@ rhat_nested(x, superchain_ids, ...) \item An \code{\link{rvar}}. }} +\item{...}{Arguments passed to individual methods (if applicable).} + \item{superchain_ids}{(numeric) Vector of length nchains specifying which superchain each chain belongs to} - -\item{...}{Arguments passed to individual methods (if applicable).} } \value{ If the input is an array, returns a single numeric value. If any of the draws @@ -46,10 +46,10 @@ proposed in Margossian et al. (2023). } \examples{ mu <- extract_variable_matrix(example_draws(), "mu") -rhat_nested(mu, superchain_ids = c(1,1,2,2)) +rhat_nested(mu, superchain_ids = c(1, 1, 2, 2)) d <- as_draws_rvars(example_draws("multi_normal")) -rhat_nested(d$Sigma, superchain_ids = c(1,1,2,2)) +rhat_nested(d$Sigma, superchain_ids = c(1, 1, 2, 2)) } \references{ From 67b9e0bf6a2ccd20e898333bd44ffa27ecdd7450 Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 25 Oct 2023 15:21:37 +0300 Subject: [PATCH 08/13] simplify between superchain variance calculation in nested rhat --- R/nested_rhat.R | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/R/nested_rhat.R b/R/nested_rhat.R index 72d8280d..c97a2a54 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -97,10 +97,7 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) { } # between-superchain variance (nB in equation 17 in Margossian et al. 2023) - var_between_superchain <- matrixStats::colVars( - as.matrix(superchain_mean), - center = overall_mean - ) + var_between_superchain <- var(superchain_mean) # within-superchain variance (nW in equation 18 in Margossian et al. 2023) var_within_superchain <- mean(var_within_chain + var_between_chain) From 86b64c8db5531848837474bc761c89dd83686e66 Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 25 Oct 2023 15:31:43 +0300 Subject: [PATCH 09/13] remove unnecessary variable in rhat_nested --- R/nested_rhat.R | 3 --- 1 file changed, 3 deletions(-) diff --git a/R/nested_rhat.R b/R/nested_rhat.R index c97a2a54..7c88d5e2 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -73,9 +73,6 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) { superchains, function(k) mean(x[, which(superchain_ids == k)]) ) - # overall mean (as defined in equation 16 in Margossian et al. 2023) - overall_mean <- mean(superchain_mean) - # between-chain variance estimate (B_k in equation 18 in Margossian et al. 2023) if (nchains_per_superchain == 1) { var_between_chain <- 0 From 17bf7575e43d499518d50224b6907e45dc4c775e Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 25 Oct 2023 15:32:00 +0300 Subject: [PATCH 10/13] cleanup rhat nested tests --- tests/testthat/test-rhat_nested.R | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/testthat/test-rhat_nested.R b/tests/testthat/test-rhat_nested.R index 15b3a929..a0d49ac1 100644 --- a/tests/testthat/test-rhat_nested.R +++ b/tests/testthat/test-rhat_nested.R @@ -1,11 +1,11 @@ test_that("rhat_nested returns reasonable values", { tau <- extract_variable_matrix(example_draws(), "tau") - rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) - expect_true(rhat > 0.99 & rhat < 1.05) + nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) + expect_true(nested_rhat > 1 & nested_rhat < 1.05) - rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) - expect_true(rhat > 0.99 & rhat < 1.05) + nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) + expect_true(nested_rhat > 1 & nested_rhat < 1.05) }) @@ -15,17 +15,17 @@ test_that("rhat_nested handles special cases correctly", { expect_true(is.na(rhat_nested(x, superchain_ids = c(1)))) x <- c(rnorm(10), Inf) - expect_true(is.na(rhat_nested(x, superchain_ids = c(1,2,1,2)))) + expect_true(is.na(rhat_nested(x, superchain_ids = c(1, 2, 1, 2)))) tau <- extract_variable_matrix(example_draws(), "tau") expect_warning( - rhat_nested(tau, superchain_ids = c(1,1,1,3)), + rhat_nested(tau, superchain_ids = c(1, 1, 1, 3)), "Number of chains per superchain is not the same for each superchain, returning NA." ) tau <- extract_variable_matrix(example_draws(), "tau") expect_warning( - rhat_nested(tau, superchain_ids = c(1,2)), + rhat_nested(tau, superchain_ids = c(1, 2)), "Length of superchain_ids not equal to number of chains, returning NA." ) }) From d521daea064f8dc90a5196e9d6286af3cd2b4477 Mon Sep 17 00:00:00 2001 From: n-kall Date: Sat, 28 Oct 2023 12:15:27 +0300 Subject: [PATCH 11/13] improve nested rhat documentation add details section with explanation of superchains and discrepancy between Rhat and nested Rhat calculation, specify version of preprint --- R/nested_rhat.R | 21 +++++++++++++++--- man-roxygen/ref-margossian-nestedrhat-2023.R | 2 +- man/rhat_nested.Rd | 23 ++++++++++++++++---- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/R/nested_rhat.R b/R/nested_rhat.R index 7c88d5e2..e80af4aa 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -1,13 +1,28 @@ #' Nested Rhat convergence diagnostic #' -#' Compute the Nested Rhat convergence diagnostic for a single variable -#' proposed in Margossian et al. (2023). +#' Compute the nested Rhat convergence diagnostic for a single +#' variable as proposed in Margossian et al. (2023). #' #' @family diagnostics #' @template args-conv #' @param superchain_ids (numeric) Vector of length nchains specifying -#' which superchain each chain belongs to +#' which superchain each chain belongs to. There should be equal +#' numbers of chains in each superchain. All chains within the same +#' superchain are assumed to have been initialized at the same +#' point. #' @template args-methods-dots +#' +#' @details Nested Rhat is a convergence diagnostic useful when +#' running many short chains. It calculated on superchains, which +#' are groups of chains that have been initialized at the same +#' point. +#' +#' Note that there is a slight difference in the calculation of Rhat +#' and nested Rhat, as nested Rhat is lower bounded by 1. This means +#' that nested Rhat with one chain per superchain will not be +#' exactly equal to basic Rhat (see Footnote 1 in Margossian et +#' al. (2023)). +#' #' @template return-conv #' @template ref-margossian-nestedrhat-2023 #' diff --git a/man-roxygen/ref-margossian-nestedrhat-2023.R b/man-roxygen/ref-margossian-nestedrhat-2023.R index 9cdeed04..474881e2 100644 --- a/man-roxygen/ref-margossian-nestedrhat-2023.R +++ b/man-roxygen/ref-margossian-nestedrhat-2023.R @@ -2,4 +2,4 @@ #' Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel #' Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat: #' Assessing the convergence of Markov chain Monte Carlo when running -#' many short chains. arxiv:arXiv:2110.13017 +#' many short chains. arxiv:arXiv:2110.13017 (version 4) diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd index c81a3cc8..4ad8f3dd 100644 --- a/man/rhat_nested.Rd +++ b/man/rhat_nested.Rd @@ -23,7 +23,10 @@ rhat_nested(x, ...) \item{...}{Arguments passed to individual methods (if applicable).} \item{superchain_ids}{(numeric) Vector of length nchains specifying -which superchain each chain belongs to} +which superchain each chain belongs to. There should be equal +numbers of chains in each superchain. All chains within the same +superchain are assumed to have been initialized at the same +point.} } \value{ If the input is an array, returns a single numeric value. If any of the draws @@ -41,8 +44,20 @@ If the input is an \code{\link{rvar}}, returns an array of the same dimensions a passing the draws array for that element of the \code{\link{rvar}} to this function. } \description{ -Compute the Nested Rhat convergence diagnostic for a single variable -proposed in Margossian et al. (2023). +Compute the nested Rhat convergence diagnostic for a single +variable as proposed in Margossian et al. (2023). +} +\details{ +Nested Rhat is a convergence diagnostic useful when +running many short chains. It calculated on superchains, which +are groups of chains that have been initialized at the same +point. + +Note that there is a slight difference in the calculation of Rhat +and nested Rhat, as nested Rhat is lower bounded by 1. This means +that nested Rhat with one chain per superchain will not be +exactly equal to basic Rhat (see Footnote 1 in Margossian et +al. (2023)). } \examples{ mu <- extract_variable_matrix(example_draws(), "mu") @@ -56,7 +71,7 @@ rhat_nested(d$Sigma, superchain_ids = c(1, 1, 2, 2)) Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat: Assessing the convergence of Markov chain Monte Carlo when running -many short chains. arxiv:arXiv:2110.13017 +many short chains. arxiv:arXiv:2110.13017 (version 4) } \seealso{ Other diagnostics: From 2fb6270f76cc6be266879acea7bb884f95ef0f5d Mon Sep 17 00:00:00 2001 From: n-kall Date: Sat, 28 Oct 2023 12:19:04 +0300 Subject: [PATCH 12/13] typofix in nested rhat docs --- R/nested_rhat.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/nested_rhat.R b/R/nested_rhat.R index e80af4aa..f1b81901 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -13,7 +13,7 @@ #' @template args-methods-dots #' #' @details Nested Rhat is a convergence diagnostic useful when -#' running many short chains. It calculated on superchains, which +#' running many short chains. It is calculated on superchains, which #' are groups of chains that have been initialized at the same #' point. #' From 32f97c450a6a28a3f238b033e3a6b75766ecea73 Mon Sep 17 00:00:00 2001 From: n-kall Date: Sat, 28 Oct 2023 12:19:40 +0300 Subject: [PATCH 13/13] render docs with typofix --- man/rhat_nested.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd index 4ad8f3dd..2e23242d 100644 --- a/man/rhat_nested.Rd +++ b/man/rhat_nested.Rd @@ -49,7 +49,7 @@ variable as proposed in Margossian et al. (2023). } \details{ Nested Rhat is a convergence diagnostic useful when -running many short chains. It calculated on superchains, which +running many short chains. It is calculated on superchains, which are groups of chains that have been initialized at the same point.