Skip to content

Commit

Permalink
Merge branch 'stan-dev:master' into autothin
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall authored Jan 10, 2024
2 parents 8e34f21 + 1210f4a commit bb5996a
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 63 deletions.
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,16 @@ S3method(order_draws,draws_list)
S3method(order_draws,draws_matrix)
S3method(order_draws,draws_rvars)
S3method(order_draws,rvar)
S3method(pareto_convergence_rate,default)
S3method(pareto_convergence_rate,rvar)
S3method(pareto_diags,default)
S3method(pareto_diags,rvar)
S3method(pareto_khat,default)
S3method(pareto_khat,rvar)
S3method(pareto_khat_threshold,default)
S3method(pareto_khat_threshold,rvar)
S3method(pareto_min_ss,default)
S3method(pareto_min_ss,rvar)
S3method(pareto_smooth,default)
S3method(pareto_smooth,rvar)
S3method(pillar_shaft,rvar)
Expand Down Expand Up @@ -455,8 +461,11 @@ export(ndraws)
export(niterations)
export(nvariables)
export(order_draws)
export(pareto_convergence_rate)
export(pareto_diags)
export(pareto_khat)
export(pareto_khat_threshold)
export(pareto_min_ss)
export(pareto_smooth)
export(quantile2)
export(r_scale)
Expand Down
108 changes: 80 additions & 28 deletions R/pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ pareto_diags.default <- function(x,
extra_diags = TRUE,
verbose = verbose,
smooth_draws = FALSE,
are_log_weights = FALSE,
are_log_weights = are_log_weights,
...)

return(smoothed$diagnostics)
Expand Down Expand Up @@ -181,24 +181,23 @@ pareto_diags.rvar <- function(x, ...) {
#'
#' @template args-pareto
#' @param return_k (logical) Should the Pareto khat be included in
#' output? If `TRUE`, output will be a list containing of smoothed
#' draws and diagnostics. Default is `TRUE`.
#' output? If `TRUE`, output will be a list containing smoothed
#' draws and diagnostics, otherwise it will be a numeric of the
#' smoothed draws. Default is `FALSE`.
#' @param extra_diags (logical) Should extra Pareto khat diagnostics
#' be included in output? If `TRUE`, `min_ss`, `khat_threshold` and
#' `convergence_rate` for the estimated k value will be
#' returned. Default is `FALSE`.
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
#' @return Either a vector `x` of smoothed values or a named list
#' containing the vector `x` and a named list `diagnostics` containing Pareto smoothing
#' diagnostics:
#' * `khat`: estimated Pareto k shape parameter, and
#' optionally
#' * `min_ss`: minimum sample size for reliable Pareto
#' smoothed estimate
#' * `khat_threshold`: khat-threshold for reliable
#' containing the vector `x` and a named list `diagnostics`
#' containing Pareto smoothing diagnostics: * `khat`: estimated
#' Pareto k shape parameter, and optionally * `min_ss`: minimum
#' sample size for reliable Pareto smoothed estimate *
#' `khat_threshold`: khat-threshold for reliable Pareto smoothed
#' estimates * `convergence_rate`: Relative convergence rate for
#' Pareto smoothed estimates
#' * `convergence_rate`: Relative convergence rate for Pareto smoothed estimates
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_diags`] for additional diagnostics.
Expand All @@ -213,7 +212,7 @@ pareto_smooth <- function(x, ...) UseMethod("pareto_smooth")

#' @rdname pareto_smooth
#' @export
pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) {
pareto_smooth.rvar <- function(x, return_k = FALSE, extra_diags = FALSE, ...) {

if (extra_diags) {
return_k <- TRUE
Expand Down Expand Up @@ -253,14 +252,14 @@ pareto_smooth.default <- function(x,
tail = c("both", "right", "left"),
r_eff = 1,
ndraws_tail = NULL,
return_k = TRUE,
return_k = FALSE,
extra_diags = FALSE,
verbose = FALSE,
verbose = TRUE,
are_log_weights = FALSE,
...) {

checkmate::expect_numeric(ndraws_tail, null.ok = TRUE)
checkmate::expect_numeric(r_eff, null.ok = TRUE)
checkmate::assert_numeric(ndraws_tail, null.ok = TRUE)
checkmate::assert_numeric(r_eff, null.ok = TRUE)
extra_diags <- as_one_logical(extra_diags)
return_k <- as_one_logical(return_k)
verbose <- as_one_logical(verbose)
Expand Down Expand Up @@ -338,6 +337,7 @@ pareto_smooth.default <- function(x,
x,
ndraws_tail = ndraws_tail,
tail = tail,
are_log_weights = are_log_weights,
...
)
k <- smoothed$k
Expand Down Expand Up @@ -370,6 +370,65 @@ pareto_smooth.default <- function(x,
return(out)
}

#' @rdname pareto_diags
#' @export
pareto_khat_threshold <- function(x, ...) {
UseMethod("pareto_khat_threshold")
}

#' @rdname pareto_diags
#' @export
pareto_khat_threshold.default <- function(x, ...) {
c(khat_threshold = ps_khat_threshold(length(x)))
}

#' @rdname pareto_diags
#' @export
pareto_khat_threshold.rvar <- function(x, ...) {
c(khat_threshold = ps_khat_threshold(ndraws(x)))
}

#' @rdname pareto_diags
#' @export
pareto_min_ss <- function(x, ...) {
UseMethod("pareto_min_ss")
}

#' @rdname pareto_diags
#' @export
pareto_min_ss.default <- function(x, ...) {
k <- pareto_khat(x)$k
c(min_ss = ps_min_ss(k))
}

#' @rdname pareto_diags
#' @export
pareto_min_ss.rvar <- function(x, ...) {
k <- pareto_khat(x)$k
c(min_ss = ps_min_ss(k))
}

#' @rdname pareto_diags
#' @export
pareto_convergence_rate <- function(x, ...) {
UseMethod("pareto_convergence_rate")
}

#' @rdname pareto_diags
#' @export
pareto_convergence_rate.default <- function(x, ...) {
k <- pareto_khat(x)$khat
c(convergence_rate = ps_convergence_rate(k, length(x)))
}

#' @rdname pareto_diags
#' @export
pareto_convergence_rate.rvar <- function(x, ...) {
k <- pareto_khat(x)
c(convergence_rate = ps_convergence_rate(k, ndraws(x)))
}


#' Pareto smooth tail
#' internal function to pareto smooth the tail of a vector
#' @noRd
Expand Down Expand Up @@ -493,7 +552,6 @@ ps_min_ss <- function(k, ...) {
out
}


#' Pareto-smoothing k-hat threshold
#'
#' Given sample size S computes khat threshold for reliable Pareto
Expand Down Expand Up @@ -561,26 +619,20 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) {
if (!are_weights) {

if (khat > 1) {
msg <- paste0(msg, "All estimates are unreliable. If the distribution of draws is bounded,\n",
"further draws may improve the estimates, but it is not possible to predict\n",
"whether any feasible sample size is sufficient.")
msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.")
} else {
if (khat > khat_threshold) {
msg <- paste0(msg, "S is too small, and sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n")
} else {
msg <- paste0(msg, "To halve the RMSE, approximately ", round(2^(2 / convergence_rate), 1), " times bigger S is needed.\n")
msg <- paste0(msg, " Sample size is too small, for given Pareto k-hat. Sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n")
}
if (khat > 0.7) {
msg <- paste0(msg, "Bias dominates RMSE, and the variance based MCSE is underestimated.\n")
msg <- paste0(msg, " Bias dominates when k-hat > 0.7, making empirical mean estimate of the Pareto-smoothed draws unreliable.\n")
}
}

} else {

if (khat > khat_threshold || khat > 0.7) {
msg <- paste0(msg, "Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n")
msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n")
}
}
message(msg)
message("Pareto k-hat = ", round(khat, 2), ".", msg)
invisible(diags)
}
4 changes: 2 additions & 2 deletions R/weight_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ weights.draws <- function(object, log = FALSE, normalize = TRUE, ...) {

# validate weights and return log weights
validate_weights <- function(weights, draws, log = FALSE) {
checkmate::expect_numeric(weights)
checkmate::expect_flag(log)
checkmate::assert_numeric(weights)
checkmate::assert_flag(log)
if (length(weights) != ndraws(draws)) {
stop_no_call("Number of weights must match the number of draws.")
}
Expand Down
27 changes: 27 additions & 0 deletions man/pareto_diags.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 12 additions & 15 deletions man/pareto_smooth.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit bb5996a

Please sign in to comment.