Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exclude draws #336

Merged
merged 11 commits into from
Jan 15, 2024
36 changes: 32 additions & 4 deletions R/draws-index.R
Original file line number Diff line number Diff line change
Expand Up @@ -478,15 +478,18 @@ ndraws.rvar <- function(x) {
# @param regex should 'variables' be treated as regular expressions?
# @param scalar_only should only scalar variables be matched?
check_existing_variables <- function(variables, x, regex = FALSE,
scalar_only = FALSE) {
scalar_only = FALSE, exclude = FALSE) {
check_draws_object(x)
if (is.null(variables)) {
return(NULL)
}

regex <- as_one_logical(regex)
scalar_only <- as_one_logical(scalar_only)
exclude <- as_one_logical(exclude)
variables <- unique(as.character(variables))
all_variables <- variables(x, reserved = TRUE)

if (regex) {
tmp <- named_list(variables)
for (i in seq_along(variables)) {
Expand Down Expand Up @@ -529,6 +532,12 @@ check_existing_variables <- function(variables, x, regex = FALSE,
stop_no_call("The following variables are missing in the draws object: ",
comma(missing_variables))
}

# handle excluding variables for subset_draws
if (exclude) {
variables <- setdiff(all_variables, variables)
}

invisible(variables)
}

Expand Down Expand Up @@ -564,12 +573,13 @@ check_reserved_variables <- function(variables) {

# check validity of iteration indices
# @param unique should the returned IDs be unique?
check_iteration_ids <- function(iteration_ids, x, unique = TRUE) {
check_iteration_ids <- function(iteration_ids, x, unique = TRUE, exclude = FALSE) {
check_draws_object(x)
if (is.null(iteration_ids)) {
return(NULL)
}
unique <- as_one_logical(unique)
exclude <- as_one_logical(exclude)
iteration_ids <- as.integer(iteration_ids)
if (unique) {
iteration_ids <- unique(iteration_ids)
Expand All @@ -584,17 +594,24 @@ check_iteration_ids <- function(iteration_ids, x, unique = TRUE) {
stop_no_call("Tried to subset iterations up to '", max_iteration, "' ",
"but the object only has '", niterations, "' iterations.")
}

# handle exclude iterations in subset_draws
if (exclude) {
iteration_ids <- setdiff(iteration_ids(x), iteration_ids)
}

invisible(iteration_ids)
}

# check validity of chain indices
# @param unique should the returned IDs be unique?
check_chain_ids <- function(chain_ids, x, unique = TRUE) {
check_chain_ids <- function(chain_ids, x, unique = TRUE, exclude = FALSE) {
check_draws_object(x)
if (is.null(chain_ids)) {
return(NULL)
}
unique <- as_one_logical(unique)
exclude <- as_one_logical(exclude)
chain_ids <- as.integer(chain_ids)
if (unique) {
chain_ids <- unique(chain_ids)
Expand All @@ -609,17 +626,23 @@ check_chain_ids <- function(chain_ids, x, unique = TRUE) {
stop_no_call("Tried to subset chains up to '", max_chain, "' ",
"but the object only has '", nchains, "' chains.")
}

if (exclude) {
chain_ids <- setdiff(chain_ids(x), chain_ids)
}

invisible(chain_ids)
}

# check validity of draw indices
# @param unique should the returned IDs be unique?
check_draw_ids <- function(draw_ids, x, unique = TRUE) {
check_draw_ids <- function(draw_ids, x, unique = TRUE, exclude = FALSE) {
check_draws_object(x)
if (is.null(draw_ids)) {
return(NULL)
}
unique <- as_one_logical(unique)
exclude <- as_one_logical(exclude)
draw_ids <- as.integer(draw_ids)
if (unique) {
draw_ids <- unique(draw_ids)
Expand All @@ -634,5 +657,10 @@ check_draw_ids <- function(draw_ids, x, unique = TRUE) {
stop_no_call("Tried to subset draws up to '", max_draw, "' ",
"but the object only has '", ndraws, "' draws.")
}

if (exclude) {
draw_ids <- setdiff(draw_ids(x), draw_ids)
}

invisible(draw_ids)
}
85 changes: 50 additions & 35 deletions R/subset_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@
#' Subset [`draws`] objects by variables, iterations, chains, and draws indices.
#'
#' @template args-methods-x
#' @param variable (character vector) The variables to select. All elements of
#' non-scalar variables can be selected at once.
#' @param variable (character vector) The variables to select. All
#' elements of non-scalar variables can be selected at once.
#' @param iteration (integer vector) The iteration indices to select.
#' @param chain (integer vector) The chain indices to select.
#' @param draw (integer vector) The draw indices to be select. Subsetting draw
#' indices will lead to an automatic merging of chains via [`merge_chains`].
#' @param draw (integer vector) The draw indices to be
#' select. Subsetting draw indices will lead to an automatic merging
#' of chains via [`merge_chains`].
#' @param regex (logical) Should `variable` should be treated as a
#' (vector of) regular expressions? Any variable in `x` matching at least one
#' of the regular expressions will be selected. Defaults to `FALSE`.
#' @param unique (logical) Should duplicated selection of chains, iterations, or
#' draws be allowed? If `TRUE` (the default) only unique chains, iterations,
#' and draws are selected regardless of how often they appear in the
#' respective selecting arguments.
#' (vector of) regular expressions? Any variable in `x` matching at
#' least one of the regular expressions will be selected. Defaults
#' to `FALSE`.
#' @param unique (logical) Should duplicated selection of chains,
#' iterations, or draws be allowed? If `TRUE` (the default) only
#' unique chains, iterations, and draws are selected regardless of
#' how often they appear in the respective selecting arguments.
#' @param exclude (logical) Should the selected subset be excluded?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not just selection of draws but of variables etc. too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

#' If `FALSE` (the default) only the selected subset will be
#' returned. If `TRUE` everything but the selected subset will be
#' returned.
#'
#' @template args-methods-dots
#' @template return-draws
Expand Down Expand Up @@ -46,15 +52,16 @@ subset_draws <- function(x, ...) {
#' @export
subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}
x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
x <- .subset_draws(x, iteration, chain, draw, variable, reserved = TRUE)
if (!is.null(chain) || !is.null(iteration)) {
Expand All @@ -67,15 +74,17 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
iteration <- draw
Expand All @@ -91,16 +100,18 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
unique <- as_one_logical(unique)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude= exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
x <- .subset_draws(
x, iteration, chain, draw, variable,
Expand All @@ -113,15 +124,17 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
iteration <- draw
Expand All @@ -137,15 +150,17 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
iteration <- draw
Expand Down
33 changes: 22 additions & 11 deletions man/subset_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading