feat: <DataFrame>$partition_by()
eitsupi committed Mar 9, 2024
1 parent e9d96ac commit c658187
Showing 12 changed files with 336 additions and 34 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -117,5 +117,5 @@ Collate:
'zzz.R'
Config/rextendr/version: 0.3.1
VignetteBuilder: knitr
-Config/polars/LibVersion: 0.38.0
+Config/polars/LibVersion: 0.38.1
Config/polars/RustToolchainVersion: nightly-2024-02-23
109 changes: 109 additions & 0 deletions R/dataframe__frame.R
@@ -884,6 +884,8 @@ DataFrame_filter = function(...) {
#' @details Within each group, the order of the rows is always preserved,
#' regardless of the `maintain_order` argument.
#' @return [GroupBy][GroupBy_class] (a DataFrame with special groupby methods like `$agg()`)
#' @seealso
#' - [`<DataFrame>$partition_by()`][DataFrame_partition_by]
#' @examples
#' df = pl$DataFrame(
#' a = c("a", "b", "a", "b", "c"),
@@ -2093,3 +2095,110 @@ DataFrame_group_by_dynamic = function(
by, start_by, check_sorted
)
}


# TODO: support selectors
#' Split a DataFrame into multiple DataFrames
#'
#' Similar to [`$group_by()`][DataFrame_group_by].
#' Group by the given columns and return the groups as separate [DataFrames][DataFrame_class].
#' This is useful in combination with functions like [lapply()] or `purrr::map()`.
#' @param ... Character vectors of column names to group by. Passed to [`pl$col()`][pl_col].
#' @param maintain_order If `TRUE`, ensure that the order of the groups is consistent with the input data.
#' This is slower than a default partition by operation.
#' @param include_key If `TRUE`, include the columns used to partition the DataFrame in the output.
#' @param as_nested_list This affects the format of the output.
#' If `FALSE` (default), the output is a flat [list] of [DataFrames][DataFrame_class].
#' If `TRUE` and at least one of the `maintain_order` or `include_key` arguments is `TRUE`,
#' each element of the output is a list with two children: `key` and `data`.
#' See the examples for more details.
#' @return A list of [DataFrames][DataFrame_class]. See the examples for details.
#' @seealso
#' - [`<DataFrame>$group_by()`][DataFrame_group_by]
#' @examples
#' df = pl$DataFrame(
#' a = c("a", "b", "a", "b", "c"),
#' b = c(1, 2, 1, 3, 3),
#' c = c(5, 4, 3, 2, 1)
#' )
#' df
#'
#' # Pass a single column name to partition by that column.
#' df$partition_by("a")
#'
#' # Partition by multiple columns.
#' df$partition_by("a", "b")
#'
#' # Partition by column data type
#' df$partition_by(pl$String)
#'
#' # If `as_nested_list = TRUE`, the output is a list whose elements have a `key` and a `data` field.
#' # The `key` is a named list of the key values, and the `data` is the DataFrame.
#' df$partition_by("a", "b", as_nested_list = TRUE)
#'
#' # `as_nested_list = TRUE` should be used with `maintain_order = TRUE` or `include_key = TRUE`.
#' tryCatch(
#' df$partition_by("a", "b", maintain_order = FALSE, include_key = FALSE, as_nested_list = TRUE),
#' warning = function(w) w
#' )
#'
#' # Example of using this with lapply(), printing the key and a summary of the data
#' df$partition_by("a", "b", maintain_order = FALSE, as_nested_list = TRUE) |>
#' lapply(\(x) {
#' sprintf("The key value of `a` is %s and the key value of `b` is %s", x$key$a, x$key$b) |>
#' cat()
#' cat("\n")
#' x$data$drop(names(x$key))$describe() |>
#' print()
#' invisible(NULL)
#' }) |>
#' invisible()
DataFrame_partition_by = function(
...,
maintain_order = TRUE,
include_key = TRUE,
as_nested_list = FALSE) {
uw = \(res) unwrap(res, "in $partition_by():")

by = result(dots_to_colnames(self, ...)) |>
uw()

if (!length(by)) {
Err_plain("There is no column to partition by.") |>
uw()
}

partitions = .pr$DataFrame$partition_by(self, by, maintain_order, include_key) |>
uw()

if (isTRUE(as_nested_list)) {
if (include_key) {
out = lapply(seq_along(partitions), \(index) {
data = partitions[[index]]
key = data$select(by)$head(1)$to_list()

list(key = key, data = data)
})

return(out)
} else if (maintain_order) {
key_df = self$select(by)$unique(maintain_order = TRUE)
out = lapply(seq_along(partitions), \(index) {
data = partitions[[index]]
key = key_df$slice(index - 1, 1)$to_list()

list(key = key, data = data)
})

return(out)
} else {
warning(
"can not use `$partition_by` with ",
"`maintain_order = FALSE, include_key = FALSE, as_nested_list = TRUE`. ",
"Fall back to a flat list."
)
}
}

partitions
}
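
To make the new API concrete, here is a minimal usage sketch (not part of the diff) of the split-apply-combine pattern the docs describe, reusing the example data from the roxygen block above.

df = pl$DataFrame(
  a = c("a", "b", "a", "b", "c"),
  b = c(1, 2, 1, 3, 3),
  c = c(5, 4, 3, 2, 1)
)

# Flat list of DataFrames, one per group; key columns are included by default.
parts = df$partition_by("a")

# Apply a computation per group and collect the results.
lapply(parts, \(x) x$select(pl$col("c")$sum())$to_data_frame())
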
16 changes: 13 additions & 3 deletions R/dotdotdot.R
@@ -45,9 +45,9 @@ unpack_list = function(..., .context = NULL, .call = sys.call(1L), skip_classes
l = list2(..., .context = .context, .call = .call)
if (
length(l) == 1L &&
-    is.list(l[[1L]]) &&
-    !(!is.null(skip_classes) && inherits(l[[1L]], skip_classes)) &&
-    is.null(names(l))
+      is.list(l[[1L]]) &&
+      !(!is.null(skip_classes) && inherits(l[[1L]], skip_classes)) &&
+      is.null(names(l))
) {
l[[1L]]
} else {
@@ -79,3 +79,13 @@ unpack_bool_expr_result = function(...) {
}
})
}


#' Convert dots to a character vector of column names
#' @param .df [RPolarsDataFrame]
#' @param ... Arguments to pass to [`pl$col()`][pl_col]
#' @noRd
dots_to_colnames = function(.df, ..., .call = sys.call(1L)) {
result(pl$DataFrame(schema = .df$schema)$select(pl$col(...))$columns) |>
unwrap(call = .call)
}
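
For context on the helper above: it resolves `...` by running the `pl$col()` selection against an empty DataFrame built from the schema, so column names (including data-type selections such as `pl$String`) are expanded without touching any data. A rough, illustrative equivalent:

df = pl$DataFrame(a = c("x", "y"), b = 1:2)

# Roughly what dots_to_colnames(df, pl$String) resolves to:
pl$DataFrame(schema = df$schema)$select(pl$col(pl$String))$columns
#> [1] "a"
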
2 changes: 2 additions & 0 deletions R/extendr-wrappers.R
@@ -178,6 +178,8 @@ RPolarsDataFrame$to_struct <- function(name) .Call(wrap__RPolarsDataFrame__to_st

RPolarsDataFrame$unnest <- function(names) .Call(wrap__RPolarsDataFrame__unnest, self, names)

RPolarsDataFrame$partition_by <- function(by, maintain_order, include_keys) .Call(wrap__RPolarsDataFrame__partition_by, self, by, maintain_order, include_keys)

RPolarsDataFrame$export_stream <- function(stream_ptr) invisible(.Call(wrap__RPolarsDataFrame__export_stream, self, stream_ptr))

RPolarsDataFrame$from_arrow_record_batches <- function(rbr) .Call(wrap__RPolarsDataFrame__from_arrow_record_batches, rbr)
5 changes: 5 additions & 0 deletions man/DataFrame_group_by.Rd


79 changes: 79 additions & 0 deletions man/DataFrame_partition_by.Rd


2 changes: 1 addition & 1 deletion src/rust/Cargo.lock


2 changes: 1 addition & 1 deletion src/rust/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "r-polars"
version = "0.38.0"
version = "0.38.1"
edition = "2021"
rust-version = "1.74.1"
publish = false
20 changes: 20 additions & 0 deletions src/rust/src/rdataframe/mod.rs
@@ -328,6 +328,26 @@ impl RPolarsDataFrame {
self.lazy().unnest(names)?.collect()
}

pub fn partition_by(
&self,
by: Robj,
maintain_order: Robj,
include_keys: Robj,
) -> RResult<List> {
let by = robj_to!(Vec, String, by)?;
let maintain_order = robj_to!(bool, maintain_order)?;
let include_keys = robj_to!(bool, include_keys)?;
let out = if maintain_order {
self.0.clone().partition_by_stable(by, include_keys)
} else {
self.0.partition_by(by, include_keys)
}
.map_err(polars_to_rpolars_err)?;

        // Reinterpret the Vec<pl::DataFrame> as Vec<RPolarsDataFrame> (a newtype wrapping pl::DataFrame).
        let vec = unsafe { std::mem::transmute::<Vec<pl::DataFrame>, Vec<RPolarsDataFrame>>(out) };
Ok(List::from_values(vec))
}

pub fn export_stream(&self, stream_ptr: &str) {
let schema = self.0.schema().to_arrow(false);
let data_type = ArrowDataType::Struct(schema.fields);
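
On the Rust side, `maintain_order = TRUE` routes to polars' `partition_by_stable()`, which returns groups in order of first appearance, while plain `partition_by()` leaves the group order unspecified. An illustrative R-level sketch of the observable difference:

df = pl$DataFrame(g = c("b", "a", "b"), v = 1:3)

# Stable: groups come back in first-appearance order ("b" first, then "a").
df$partition_by("g", maintain_order = TRUE)

# Order of the returned groups may vary from run to run.
df$partition_by("g", maintain_order = FALSE)
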
45 changes: 23 additions & 22 deletions tests/testthat/_snaps/after-wrappers.md
@@ -79,16 +79,16 @@
[21] "group_by_dynamic" "head" "height" "join"
[25] "join_asof" "last" "lazy" "limit"
[29] "max" "mean" "median" "melt"
[33] "min" "n_chunks" "null_count" "pivot"
[37] "print" "quantile" "rechunk" "rename"
[41] "reverse" "rolling" "sample" "schema"
[45] "select" "shape" "shift" "shift_and_fill"
[49] "slice" "sort" "std" "sum"
[53] "tail" "to_data_frame" "to_list" "to_series"
[57] "to_struct" "transpose" "unique" "unnest"
[61] "var" "width" "with_columns" "with_row_count"
[65] "with_row_index" "write_csv" "write_json" "write_ndjson"
[69] "write_parquet"
[33] "min" "n_chunks" "null_count" "partition_by"
[37] "pivot" "print" "quantile" "rechunk"
[41] "rename" "reverse" "rolling" "sample"
[45] "schema" "select" "shape" "shift"
[49] "shift_and_fill" "slice" "sort" "std"
[53] "sum" "tail" "to_data_frame" "to_list"
[57] "to_series" "to_struct" "transpose" "unique"
[61] "unnest" "var" "width" "with_columns"
[65] "with_row_count" "with_row_index" "write_csv" "write_json"
[69] "write_ndjson" "write_parquet"

---

@@ -104,18 +104,19 @@
[13] "get_columns" "lazy"
[15] "melt" "n_chunks"
[17] "new_with_capacity" "null_count"
[19] "pivot_expr" "print"
[21] "rechunk" "sample_frac"
[23] "sample_n" "schema"
[25] "select" "select_at_idx"
[27] "set_column_from_robj" "set_column_from_series"
[29] "set_column_names_mut" "shape"
[31] "to_list" "to_list_tag_structs"
[33] "to_list_unwind" "to_struct"
[35] "transpose" "unnest"
[37] "with_columns" "with_row_index"
[39] "write_csv" "write_json"
[41] "write_ndjson" "write_parquet"
[19] "partition_by" "pivot_expr"
[21] "print" "rechunk"
[23] "sample_frac" "sample_n"
[25] "schema" "select"
[27] "select_at_idx" "set_column_from_robj"
[29] "set_column_from_series" "set_column_names_mut"
[31] "shape" "to_list"
[33] "to_list_tag_structs" "to_list_unwind"
[35] "to_struct" "transpose"
[37] "unnest" "with_columns"
[39] "with_row_index" "write_csv"
[41] "write_json" "write_ndjson"
[43] "write_parquet"

# public and private methods of each class GroupBy

