From 4122d67331845954d563fd4aee1fb7e7d1c023fe Mon Sep 17 00:00:00 2001 From: n-kall Date: Mon, 11 Dec 2023 11:29:17 +0200 Subject: [PATCH 1/4] add exposed individual pareto diagnostics --- R/pareto_smooth.R | 78 +++++++++++++++++++++++++++++++++++--------- man/pareto_smooth.Rd | 6 ++-- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 52ec6bba..d32fe384 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -213,7 +213,7 @@ pareto_smooth <- function(x, ...) UseMethod("pareto_smooth") #' @rdname pareto_smooth #' @export -pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) { +pareto_smooth.rvar <- function(x, return_k = FALSE, extra_diags = FALSE, ...) { if (extra_diags) { return_k <- TRUE @@ -253,9 +253,9 @@ pareto_smooth.default <- function(x, tail = c("both", "right", "left"), r_eff = 1, ndraws_tail = NULL, - return_k = TRUE, + return_k = FALSE, extra_diags = FALSE, - verbose = FALSE, + verbose = TRUE, are_log_weights = FALSE, ...) { @@ -370,6 +370,61 @@ pareto_smooth.default <- function(x, return(out) } +#' Threshold for Pareto k-hat diagnostic based on sample size +#' +#' @param x +#' @param ... +#' @return +pareto_khat_threshold <- function(x, ...) { + UseMethod("pareto_khat_threshold") +} + + +pareto_khat_threshold.default <- function(x, ...) { + c(khat_threshold = ps_khat_threshold(length(x))) +} + +pareto_khat_threshold.rvar <- function(x, ...) { + c(khat_threshold = ps_khat_threshold(ndraws(x))) +} + +#' Minimum sample size for Pareto diagnostics +#' +#' @param ... +#' @return +pareto_min_ss <- function(x, ...) { + UseMethod("pareto_min_ss") +} + +pareto_min_ss.default <- function(x, ...) { + k <- pareto_khat(x)$k + c(min_ss = ps_min_ss(k)) +} + +pareto_min_ss.rvar <- function(x, ...) { + k <- pareto_khat(x)$k + c(min_ss = ps_min_ss(k)) +} + +#' Convergence rate based on Pareto diagnostics +#' +#' @param ... +#' @return +pareto_convergence_rate <- function(x, ...) { + UseMethod("pareto_convergence_rate") +} + +pareto_convergence_rate.default <- function(x, ...) { + k <- pareto_khat(x)$khat + c(convergence_rate = ps_convergence_rate(k, length(x))) +} + +pareto_convergence_rate.rvar <- function(x, ...) { + k <- pareto_khat(x) + c(convergence_rate = ps_convergence_rate(k, ndraws(x))) +} + + #' Pareto smooth tail #' internal function to pareto smooth the tail of a vector #' @noRd @@ -493,7 +548,6 @@ ps_min_ss <- function(k, ...) { out } - #' Pareto-smoothing k-hat threshold #' #' Given sample size S computes khat threshold for reliable Pareto @@ -561,26 +615,20 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { if (!are_weights) { if (khat > 1) { - msg <- paste0(msg, "All estimates are unreliable. If the distribution of draws is bounded,\n", - "further draws may improve the estimates, but it is not possible to predict\n", - "whether any feasible sample size is sufficient.") + msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.") } else { if (khat > khat_threshold) { - msg <- paste0(msg, "S is too small, and sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n") - } else { - msg <- paste0(msg, "To halve the RMSE, approximately ", round(2^(2 / convergence_rate), 1), " times bigger S is needed.\n") + msg <- paste0(msg, "Sample size is too small, for given Pareto k-hat. Sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n") } if (khat > 0.7) { - msg <- paste0(msg, "Bias dominates RMSE, and the variance based MCSE is underestimated.\n") + msg <- paste0(msg, " Bias dominates when k-hat > 0.7, making empirical mean estimate of the Pareto-smoothed draws unreliable.\n") } } - } else { - if (khat > khat_threshold || khat > 0.7) { - msg <- paste0(msg, "Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") + msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } - message(msg) + message("Pareto k-hat = ", round(khat, 2), ". ", msg) invisible(diags) } diff --git a/man/pareto_smooth.Rd b/man/pareto_smooth.Rd index c0e6f017..8d92f8c3 100644 --- a/man/pareto_smooth.Rd +++ b/man/pareto_smooth.Rd @@ -8,16 +8,16 @@ \usage{ pareto_smooth(x, ...) -\method{pareto_smooth}{rvar}(x, return_k = TRUE, extra_diags = FALSE, ...) +\method{pareto_smooth}{rvar}(x, return_k = FALSE, extra_diags = FALSE, ...) \method{pareto_smooth}{default}( x, tail = c("both", "right", "left"), r_eff = 1, ndraws_tail = NULL, - return_k = TRUE, + return_k = FALSE, extra_diags = FALSE, - verbose = FALSE, + verbose = TRUE, are_log_weights = FALSE, ... ) From db811072d2b1841dad3841d0582a3164f47d7517 Mon Sep 17 00:00:00 2001 From: n-kall Date: Mon, 11 Dec 2023 15:53:11 +0200 Subject: [PATCH 2/4] cleanup pareto messages and corresponding tests --- NAMESPACE | 9 +++++ R/pareto_smooth.R | 55 +++++++++++++++-------------- man/pareto_diags.Rd | 27 ++++++++++++++ man/pareto_smooth.Rd | 21 +++++------ tests/testthat/test-pareto_smooth.R | 34 +++++++++--------- 5 files changed, 90 insertions(+), 56 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 2d81d65d..824c5036 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -223,10 +223,16 @@ S3method(order_draws,draws_list) S3method(order_draws,draws_matrix) S3method(order_draws,draws_rvars) S3method(order_draws,rvar) +S3method(pareto_convergence_rate,default) +S3method(pareto_convergence_rate,rvar) S3method(pareto_diags,default) S3method(pareto_diags,rvar) S3method(pareto_khat,default) S3method(pareto_khat,rvar) +S3method(pareto_khat_threshold,default) +S3method(pareto_khat_threshold,rvar) +S3method(pareto_min_ss,default) +S3method(pareto_min_ss,rvar) S3method(pareto_smooth,default) S3method(pareto_smooth,rvar) S3method(pillar_shaft,rvar) @@ -455,8 +461,11 @@ export(ndraws) export(niterations) export(nvariables) export(order_draws) +export(pareto_convergence_rate) export(pareto_diags) export(pareto_khat) +export(pareto_khat_threshold) +export(pareto_min_ss) export(pareto_smooth) export(quantile2) export(r_scale) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index d32fe384..d064e916 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -181,8 +181,9 @@ pareto_diags.rvar <- function(x, ...) { #' #' @template args-pareto #' @param return_k (logical) Should the Pareto khat be included in -#' output? If `TRUE`, output will be a list containing of smoothed -#' draws and diagnostics. Default is `TRUE`. +#' output? If `TRUE`, output will be a list containing smoothed +#' draws and diagnostics, otherwise it will be a numeric of the +#' smoothed draws. Default is `FALSE`. #' @param extra_diags (logical) Should extra Pareto khat diagnostics #' be included in output? If `TRUE`, `min_ss`, `khat_threshold` and #' `convergence_rate` for the estimated k value will be @@ -190,15 +191,13 @@ pareto_diags.rvar <- function(x, ...) { #' @template args-methods-dots #' @template ref-vehtari-paretosmooth-2022 #' @return Either a vector `x` of smoothed values or a named list -#' containing the vector `x` and a named list `diagnostics` containing Pareto smoothing -#' diagnostics: -#' * `khat`: estimated Pareto k shape parameter, and -#' optionally -#' * `min_ss`: minimum sample size for reliable Pareto -#' smoothed estimate -#' * `khat_threshold`: khat-threshold for reliable +#' containing the vector `x` and a named list `diagnostics` +#' containing Pareto smoothing diagnostics: * `khat`: estimated +#' Pareto k shape parameter, and optionally * `min_ss`: minimum +#' sample size for reliable Pareto smoothed estimate * +#' `khat_threshold`: khat-threshold for reliable Pareto smoothed +#' estimates * `convergence_rate`: Relative convergence rate for #' Pareto smoothed estimates -#' * `convergence_rate`: Relative convergence rate for Pareto smoothed estimates #' #' @seealso [`pareto_khat`] for only calculating khat, and #' [`pareto_diags`] for additional diagnostics. @@ -370,55 +369,59 @@ pareto_smooth.default <- function(x, return(out) } -#' Threshold for Pareto k-hat diagnostic based on sample size -#' -#' @param x -#' @param ... -#' @return +#' @rdname pareto_diags +#' @export pareto_khat_threshold <- function(x, ...) { UseMethod("pareto_khat_threshold") } - +#' @rdname pareto_diags +#' @export pareto_khat_threshold.default <- function(x, ...) { c(khat_threshold = ps_khat_threshold(length(x))) } +#' @rdname pareto_diags +#' @export pareto_khat_threshold.rvar <- function(x, ...) { c(khat_threshold = ps_khat_threshold(ndraws(x))) } -#' Minimum sample size for Pareto diagnostics -#' -#' @param ... -#' @return +#' @rdname pareto_diags +#' @export pareto_min_ss <- function(x, ...) { UseMethod("pareto_min_ss") } +#' @rdname pareto_diags +#' @export pareto_min_ss.default <- function(x, ...) { k <- pareto_khat(x)$k c(min_ss = ps_min_ss(k)) } +#' @rdname pareto_diags +#' @export pareto_min_ss.rvar <- function(x, ...) { k <- pareto_khat(x)$k c(min_ss = ps_min_ss(k)) } -#' Convergence rate based on Pareto diagnostics -#' -#' @param ... -#' @return +#' @rdname pareto_diags +#' @export pareto_convergence_rate <- function(x, ...) { UseMethod("pareto_convergence_rate") } +#' @rdname pareto_diags +#' @export pareto_convergence_rate.default <- function(x, ...) { k <- pareto_khat(x)$khat c(convergence_rate = ps_convergence_rate(k, length(x))) } +#' @rdname pareto_diags +#' @export pareto_convergence_rate.rvar <- function(x, ...) { k <- pareto_khat(x) c(convergence_rate = ps_convergence_rate(k, ndraws(x))) @@ -618,7 +621,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.") } else { if (khat > khat_threshold) { - msg <- paste0(msg, "Sample size is too small, for given Pareto k-hat. Sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n") + msg <- paste0(msg, " Sample size is too small, for given Pareto k-hat. Sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n") } if (khat > 0.7) { msg <- paste0(msg, " Bias dominates when k-hat > 0.7, making empirical mean estimate of the Pareto-smoothed draws unreliable.\n") @@ -629,6 +632,6 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } - message("Pareto k-hat = ", round(khat, 2), ". ", msg) + message("Pareto k-hat = ", round(khat, 2), ".", msg) invisible(diags) } diff --git a/man/pareto_diags.Rd b/man/pareto_diags.Rd index 9a1d5776..46370c49 100644 --- a/man/pareto_diags.Rd +++ b/man/pareto_diags.Rd @@ -4,6 +4,15 @@ \alias{pareto_diags} \alias{pareto_diags.default} \alias{pareto_diags.rvar} +\alias{pareto_khat_threshold} +\alias{pareto_khat_threshold.default} +\alias{pareto_khat_threshold.rvar} +\alias{pareto_min_ss} +\alias{pareto_min_ss.default} +\alias{pareto_min_ss.rvar} +\alias{pareto_convergence_rate} +\alias{pareto_convergence_rate.default} +\alias{pareto_convergence_rate.rvar} \title{Pareto smoothing diagnostics} \usage{ pareto_diags(x, ...) @@ -19,6 +28,24 @@ pareto_diags(x, ...) ) \method{pareto_diags}{rvar}(x, ...) + +pareto_khat_threshold(x, ...) + +\method{pareto_khat_threshold}{default}(x, ...) + +\method{pareto_khat_threshold}{rvar}(x, ...) + +pareto_min_ss(x, ...) + +\method{pareto_min_ss}{default}(x, ...) + +\method{pareto_min_ss}{rvar}(x, ...) + +pareto_convergence_rate(x, ...) + +\method{pareto_convergence_rate}{default}(x, ...) + +\method{pareto_convergence_rate}{rvar}(x, ...) } \arguments{ \item{x}{(multiple options) One of: diff --git a/man/pareto_smooth.Rd b/man/pareto_smooth.Rd index 8d92f8c3..24273139 100644 --- a/man/pareto_smooth.Rd +++ b/man/pareto_smooth.Rd @@ -33,8 +33,9 @@ pareto_smooth(x, ...) \item{...}{Arguments passed to individual methods (if applicable).} \item{return_k}{(logical) Should the Pareto khat be included in -output? If \code{TRUE}, output will be a list containing of smoothed -draws and diagnostics. Default is \code{TRUE}.} +output? If \code{TRUE}, output will be a list containing smoothed +draws and diagnostics, otherwise it will be a numeric of the +smoothed draws. Default is \code{FALSE}.} \item{extra_diags}{(logical) Should extra Pareto khat diagnostics be included in output? If \code{TRUE}, \code{min_ss}, \code{khat_threshold} and @@ -70,17 +71,13 @@ draws are log weights, and only right tail will be smoothed.} } \value{ Either a vector \code{x} of smoothed values or a named list -containing the vector \code{x} and a named list \code{diagnostics} containing Pareto smoothing -diagnostics: -\itemize{ -\item \code{khat}: estimated Pareto k shape parameter, and -optionally -\item \code{min_ss}: minimum sample size for reliable Pareto -smoothed estimate -\item \code{khat_threshold}: khat-threshold for reliable +containing the vector \code{x} and a named list \code{diagnostics} +containing Pareto smoothing diagnostics: * \code{khat}: estimated +Pareto k shape parameter, and optionally * \code{min_ss}: minimum +sample size for reliable Pareto smoothed estimate * +\code{khat_threshold}: khat-threshold for reliable Pareto smoothed +estimates * \code{convergence_rate}: Relative convergence rate for Pareto smoothed estimates -\item \code{convergence_rate}: Relative convergence rate for Pareto smoothed estimates -} } \description{ Smooth the tail draws of x by replacing tail draws by order diff --git a/tests/testthat/test-pareto_smooth.R b/tests/testthat/test-pareto_smooth.R index 6b67d2b0..9994d3df 100644 --- a/tests/testthat/test-pareto_smooth.R +++ b/tests/testthat/test-pareto_smooth.R @@ -56,26 +56,24 @@ test_that("pareto_khat diagnostics messages are as expected", { ) expect_message(pareto_k_diagmsg(diags), - paste0('To halve the RMSE, approximately 4.1 times bigger S is needed.')) + paste0("Pareto k-hat = 0.5.\n")) diags$khat <- 0.6 expect_message(pareto_k_diagmsg(diags), - paste0('S is too small, and sample size larger than 10 is needed for reliable results.\n')) + paste0("Pareto k-hat = 0.6. Sample size is too small, for given Pareto k-hat. Sample size larger than 10 is needed for reliable results.\n")) diags$khat <- 0.71 diags$khat_threshold <- 0.8 expect_message(pareto_k_diagmsg(diags), - paste0('To halve the RMSE, approximately 4.1 times bigger S is needed.\n', 'Bias dominates RMSE, and the variance based MCSE is underestimated.\n')) + paste0("Pareto k-hat = 0.71. Bias dominates when k-hat > 0.7, making empirical mean estimate of the Pareto-smoothed draws unreliable.\n")) diags$khat <- 1.1 expect_message(pareto_k_diagmsg(diags), - paste0('All estimates are unreliable. If the distribution of draws is bounded,\n', - 'further draws may improve the estimates, but it is not possible to predict\n', - 'whether any feasible sample size is sufficient.')) + paste0("Pareto k-hat = 1.1. Mean does not exist, making empirical mean estimate of the draws not applicable.\n")) }) @@ -131,8 +129,8 @@ test_that("pareto_khat functions work with matrix with chains", { expect_equal(pareto_khat(tau_chains, ndraws_tail = 20), pareto_khat(tau_nochains, ndraws_tail = 20)) - ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20) - ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20) + ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, return_k = TRUE) + ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, return_k = TRUE) expect_equal(as.numeric(ps_chains$x), as.numeric(ps_nochains$x)) @@ -159,22 +157,22 @@ test_that("pareto_khat functions work with rvars with and without chains", { expect_equal(pareto_diags(tau_rvar_chains, ndraws_tail = 20), pareto_diags(tau_rvar_nochains, ndraws_tail = 20)) - ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20) - ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20) + ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, return_k = TRUE) + ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20, return_k = TRUE) - ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20) - ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20) + ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, return_k = TRUE) + ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20, return_k = TRUE) expect_equal(ps_rvar_chains$x, rvar(ps_chains$x, with_chains = TRUE)) expect_equal(ps_rvar_nochains$x, rvar(ps_nochains$x)) - ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, extra_diags = TRUE) - ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20, extra_diags = TRUE) + ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE) + ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE) - ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, extra_diags = TRUE) - ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20, extra_diags = TRUE) + ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE) + ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE) expect_equal(ps_rvar_chains$x, rvar(ps_chains$x, with_chains = TRUE)) @@ -185,7 +183,7 @@ test_that("pareto_khat functions work with rvars with and without chains", { test_that("pareto_smooth returns x with smoothed tail", { tau <- extract_variable_matrix(example_draws(), "tau") - tau_smoothed <- pareto_smooth(tau, ndraws_tail = 10, tail = "right")$x + tau_smoothed <- pareto_smooth(tau, ndraws_tail = 10, tail = "right", return_k = TRUE)$x expect_equal(sort(tau)[1:390], sort(tau_smoothed)[1:390]) @@ -197,7 +195,7 @@ test_that("pareto_smooth works for log_weights", { w <- c(1:25, 1e3, 1e3, 1e3) lw <- log(w) - ps <- pareto_smooth(lw, are_log_weights = TRUE, verbose = FALSE, ndraws_tail = 10) + ps <- pareto_smooth(lw, are_log_weights = TRUE, verbose = FALSE, ndraws_tail = 10, return_k = TRUE) # only right tail is smoothed expect_equal(ps$x[1:15], lw[1:15]) From 27153ff02b116325e2700f972677b41377f345d8 Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 14 Dec 2023 14:27:29 +0200 Subject: [PATCH 3/4] replace checkmate::expect_* with checkmate::assert_* outside tests --- R/pareto_smooth.R | 4 ++-- R/weight_draws.R | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 52ec6bba..4593ce98 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -259,8 +259,8 @@ pareto_smooth.default <- function(x, are_log_weights = FALSE, ...) { - checkmate::expect_numeric(ndraws_tail, null.ok = TRUE) - checkmate::expect_numeric(r_eff, null.ok = TRUE) + checkmate::assert_numeric(ndraws_tail, null.ok = TRUE) + checkmate::assert_numeric(r_eff, null.ok = TRUE) extra_diags <- as_one_logical(extra_diags) return_k <- as_one_logical(return_k) verbose <- as_one_logical(verbose) diff --git a/R/weight_draws.R b/R/weight_draws.R index 34494820..fa8bfd8b 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -179,8 +179,8 @@ weights.draws <- function(object, log = FALSE, normalize = TRUE, ...) { # validate weights and return log weights validate_weights <- function(weights, draws, log = FALSE) { - checkmate::expect_numeric(weights) - checkmate::expect_flag(log) + checkmate::assert_numeric(weights) + checkmate::assert_flag(log) if (length(weights) != ndraws(draws)) { stop_no_call("Number of weights must match the number of draws.") } From e0c29e46403469ea6ad554607bbe62c5108e767f Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 9 Jan 2024 15:05:50 +0200 Subject: [PATCH 4/4] fix to log weight smoothing --- R/pareto_smooth.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 7da95a75..a0bb8486 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -138,7 +138,7 @@ pareto_diags.default <- function(x, extra_diags = TRUE, verbose = verbose, smooth_draws = FALSE, - are_log_weights = FALSE, + are_log_weights = are_log_weights, ...) return(smoothed$diagnostics) @@ -337,6 +337,7 @@ pareto_smooth.default <- function(x, x, ndraws_tail = ndraws_tail, tail = tail, + are_log_weights = are_log_weights, ... ) k <- smoothed$k