Skip to content

Commit

Permalink
add log_weights option to pareto functions
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Nov 15, 2023
1 parent f95d847 commit 16843c1
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 18 deletions.
40 changes: 26 additions & 14 deletions R/pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pareto_khat.default <- function(x,
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
log_weights = FALSE,
...) {
smoothed <- pareto_smooth.default(
x,
Expand All @@ -38,6 +39,7 @@ pareto_khat.default <- function(x,
verbose = verbose,
return_k = TRUE,
smooth_draws = FALSE,
log_weights = log_weights,
...)
return(smoothed$diagnostics)
}
Expand Down Expand Up @@ -120,11 +122,12 @@ pareto_diags <- function(x, ...) UseMethod("pareto_diags")
#' @rdname pareto_diags
#' @export
pareto_diags.default <- function(x,
tail = c("both", "right", "left"),
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
...) {
tail = c("both", "right", "left"),
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
log_weights = FALSE,
...) {

smoothed <- pareto_smooth.default(
x,
Expand All @@ -135,6 +138,7 @@ pareto_diags.default <- function(x,
extra_diags = TRUE,
verbose = verbose,
smooth_draws = FALSE,
log_weights = FALSE,
...)

return(smoothed$diagnostics)
Expand Down Expand Up @@ -234,8 +238,8 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) {
)
}
out <- list(
x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)),
diagnostics = diags
x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)),
diagnostics = diags
)
} else {
out <- rvar(apply(draws_diags, margins, function(x) x[[1]]), nchains = nchains(x))
Expand All @@ -252,20 +256,26 @@ pareto_smooth.default <- function(x,
return_k = TRUE,
extra_diags = FALSE,
verbose = FALSE,
log_weights = FALSE,
...) {

checkmate::assert_number(ndraws_tail, null.ok = TRUE)
checkmate::assert_number(r_eff, null.ok = TRUE)
checkmate::assert_logical(extra_diags)
checkmate::assert_logical(return_k)
checkmate::assert_logical(verbose)
checkmate::assert_logical(log_weights)

# check for infinite or na values
if (should_return_NA(x)) {
warning_no_call("Input contains infinite or NA values, Pareto smoothing not performed.")
return(list(x = x, diagnostics = NA_real_))
}

if (log_weights) {
tail = "right"
}

tail <- match.arg(tail)
S <- length(x)

Expand Down Expand Up @@ -299,6 +309,7 @@ pareto_smooth.default <- function(x,
x,
ndraws_tail = ndraws_tail,
tail = "left",
log_weights = log_weights,
...
)
left_k <- smoothed$k
Expand All @@ -308,6 +319,7 @@ pareto_smooth.default <- function(x,
x = smoothed$x,
ndraws_tail = ndraws_tail,
tail = "right",
log_weights = log_weights,
...
)
right_k <- smoothed$k
Expand Down Expand Up @@ -358,11 +370,11 @@ pareto_smooth.default <- function(x,
ndraws_tail,
smooth_draws = TRUE,
tail = c("right", "left"),
log = FALSE,
log_weights = FALSE,
...
) {

if (log) {
if (log_weights) {
# shift log values for safe exponentiation
x <- x - max(x)
}
Expand Down Expand Up @@ -395,7 +407,7 @@ pareto_smooth.default <- function(x,
k <- NA
} else {
# save time not sorting since x already sorted
if (log) {
if (log_weights) {
draws_tail <- exp(draws_tail)
cutoff <- exp(cutoff)
}
Expand All @@ -405,7 +417,7 @@ pareto_smooth.default <- function(x,
if (is.finite(k) && smooth_draws) {
p <- (seq_len(ndraws_tail) - 0.5) / ndraws_tail
smoothed <- qgeneralized_pareto(p = p, mu = cutoff, k = k, sigma = sigma)
if (log) {
if (log_weights) {
smoothed <- log(smoothed)
}
} else {
Expand Down Expand Up @@ -467,11 +479,11 @@ pareto_smooth.default <- function(x,
#' @noRd
ps_min_ss <- function(k, ...) {
if (k < 1) {
out <- 10^(1 / (1 - max(0, k)))
out <- 10^(1 / (1 - max(0, k)))
} else {
out <- Inf
out <- Inf
}
out
out
}


Expand Down
6 changes: 5 additions & 1 deletion man-roxygen/args-pareto.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
#' @param ndraws_tail (numeric) number of draws for the tail. If
#' `ndraws_tail` is not specified, it will be calculated as
#' ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and
#' length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)).
#' length(x) / 5 otherwise (see Appendix H in Vehtari et
#' al. (2022)).
#' @param r_eff (numeric) relative effective sample size estimate. If
#' `r_eff` is omitted, it will be calculated assuming the draws are
#' from MCMC.
#' @param verbose (logical) Should diagnostic messages be printed? If
#' `TRUE`, messages related to Pareto diagnostics will be
#' printed. Default is `FALSE`.
#' @param log_weights (logical) Are the draws log weights? Default is
#' `FALSE`. If `TRUE` computation will take into account that the
#' draws are log weights, and only right tail will be smoothed.
8 changes: 7 additions & 1 deletion man/pareto_diags.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/pareto_khat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/pareto_smooth.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 16843c1

Please sign in to comment.