Skip to content

Commit

Permalink
Implement pl$arg_where() (#922)
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher authored Mar 15, 2024
1 parent 54c8c18 commit afbb94c
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 34 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

- New functions `pl$datetime()`, `pl$date()`, and `pl$time()` to easily create
Expr of class datetime, date, and time via columns and literals (#918).
- New function `pl$arg_where()` to get the indices that match a condition (#922).

## Polars R Package 0.15.1

Expand Down
2 changes: 2 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ all_horizontal <- function(dotdotdot) .Call(wrap__all_horizontal, dotdotdot)

# Forwards `dotdotdot` unchanged to the native routine `wrap__any_horizontal` via .Call().
any_horizontal <- function(dotdotdot) .Call(wrap__any_horizontal, dotdotdot)

# Forwards `condition` unchanged to the native routine `wrap__arg_where` via .Call().
arg_where <- function(condition) .Call(wrap__arg_where, condition)

# Forwards `exprs` unchanged to the native routine `wrap__coalesce_exprs` via .Call().
coalesce_exprs <- function(exprs) .Call(wrap__coalesce_exprs, exprs)

# Forwards all ten date/time arguments, in the same order, to the native routine `wrap__datetime` via .Call().
datetime <- function(year, month, day, hour, minute, second, microsecond, time_unit, time_zone, ambiguous) .Call(wrap__datetime, year, month, day, hour, minute, second, microsecond, time_unit, time_zone, ambiguous)
Expand Down
16 changes: 16 additions & 0 deletions R/functions__lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -1177,3 +1177,19 @@ pl_time = function(hour = NULL, minute = NULL, second = NULL, microsecond = NULL
result() |>
unwrap("in pl$time():")
}

#' Return indices that match a condition
#'
#' Evaluates a boolean expression and returns the positions at which it is
#' `TRUE`. The returned indices are zero-based: the first element has index 0.
#'
#' @param condition An Expr that evaluates to a boolean.
#'
#' @return Expr
#'
#' @examples
#' df = pl$DataFrame(a = c(1, 2, 3, 4, 5))
#'
#' # the even values are the 2nd and 4th elements, i.e. indices 1 and 3
#' df$select(
#'   pl$arg_where(pl$col("a") %% 2 == 0)
#' )
pl_arg_where = function(condition) {
  # Use the same "in pl$<fn>():" error context as the other pl$ functions,
  # so a failure points at the function the user actually called.
  arg_where(condition) |>
    unwrap("in pl$arg_where():")
}
23 changes: 23 additions & 0 deletions man/pl_arg_where.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/pl_pl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions src/rust/src/rlib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,17 @@ pub fn datetime(
Ok(polars::lazy::dsl::datetime(args).into())
}

// Entry point behind the R wrapper `arg_where`: turns the incoming R object
// into a `PLExpr` via `robj_to!` (returning early with the conversion error if
// that fails), then hands it to `pl::arg_where` and converts the result into
// the R-facing expression type.
#[extendr]
fn arg_where(condition: Robj) -> RResult<RPolarsExpr> {
    let condition = robj_to!(PLExpr, condition)?;
    let indices = pl::arg_where(condition);
    Ok(indices.into())
}

extendr_module! {
mod rlib;

fn all_horizontal;
fn any_horizontal;
fn arg_where;
fn coalesce_exprs;
fn datetime;
fn duration;
Expand Down
67 changes: 34 additions & 33 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,40 @@
[27] "UInt8" "Unknown"
[29] "Utf8" "all"
[31] "all_horizontal" "any_horizontal"
[33] "approx_n_unique" "class_names"
[35] "coalesce" "col"
[37] "concat" "concat_list"
[39] "concat_str" "corr"
[41] "count" "cov"
[43] "date" "date_range"
[45] "datetime" "disable_string_cache"
[47] "dtypes" "duration"
[49] "element" "enable_string_cache"
[51] "expr_to_r" "first"
[53] "fold" "from_epoch"
[55] "get_global_rpool_cap" "head"
[57] "implode" "is_schema"
[59] "last" "len"
[61] "lit" "max"
[63] "max_horizontal" "mean"
[65] "median" "mem_address"
[67] "min" "min_horizontal"
[69] "n_unique" "numeric_dtypes"
[71] "raw_list" "read_csv"
[73] "read_ndjson" "read_parquet"
[75] "reduce" "rolling_corr"
[77] "rolling_cov" "same_outer_dt"
[79] "scan_csv" "scan_ipc"
[81] "scan_ndjson" "scan_parquet"
[83] "select" "set_global_rpool_cap"
[85] "show_all_public_functions" "show_all_public_methods"
[87] "std" "struct"
[89] "sum" "sum_horizontal"
[91] "tail" "thread_pool_size"
[93] "threadpool_size" "time"
[95] "using_string_cache" "var"
[97] "when" "with_string_cache"
[33] "approx_n_unique" "arg_where"
[35] "class_names" "coalesce"
[37] "col" "concat"
[39] "concat_list" "concat_str"
[41] "corr" "count"
[43] "cov" "date"
[45] "date_range" "datetime"
[47] "disable_string_cache" "dtypes"
[49] "duration" "element"
[51] "enable_string_cache" "expr_to_r"
[53] "first" "fold"
[55] "from_epoch" "get_global_rpool_cap"
[57] "head" "implode"
[59] "is_schema" "last"
[61] "len" "lit"
[63] "max" "max_horizontal"
[65] "mean" "median"
[67] "mem_address" "min"
[69] "min_horizontal" "n_unique"
[71] "numeric_dtypes" "raw_list"
[73] "read_csv" "read_ndjson"
[75] "read_parquet" "reduce"
[77] "rolling_corr" "rolling_cov"
[79] "same_outer_dt" "scan_csv"
[81] "scan_ipc" "scan_ndjson"
[83] "scan_parquet" "select"
[85] "set_global_rpool_cap" "show_all_public_functions"
[87] "show_all_public_methods" "std"
[89] "struct" "sum"
[91] "sum_horizontal" "tail"
[93] "thread_pool_size" "threadpool_size"
[95] "time" "using_string_cache"
[97] "var" "when"
[99] "with_string_cache"

---

Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/test-lazy_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,17 @@ test_that("pl$date() works", {
# time_floats = pl$time(pl$lit("abc"), -2, 1)
# )
# })

test_that("pl$arg_where() works", {
  df = pl$DataFrame(a = c(1, 2, 3, 4, 5))

  # some matches: the even values (2 and 4) are reported as indices 1 and 3
  is_even = pl$col("a") %% 2 == 0
  expect_identical(
    df$select(pl$arg_where(is_even))$to_list(),
    list(a = c(1, 3))
  )

  # no matches: none of the values is a multiple of 10, so the result is empty
  is_multiple_of_ten = pl$col("a") %% 10 == 0
  expect_identical(
    df$select(pl$arg_where(is_multiple_of_ten))$to_list(),
    list(a = numeric(0))
  )
})

0 comments on commit afbb94c

Please sign in to comment.