Skip to content

Commit

Permalink
fix: fix compile errors [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
eitsupi committed Nov 17, 2024
1 parent 2a9bd71 commit f0a6b6f
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 46 deletions.
14 changes: 9 additions & 5 deletions src/rust/src/arrow_interop/to_rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,28 +80,32 @@ pub unsafe fn to_rust_df(rb: Robj) -> Result<pl::DataFrame, String> {
// for instance utf8 -> large-utf8
// dict encoded to categorical

let series_vec = if run_parallel {
let columns = if run_parallel {
POOL.install(|| {
arrays_vec
.into_par_iter()
.zip(names.par_iter())
.map(|(arr, name)| {
let s =
Series::try_from((name.clone(), arr)).map_err(|err| err.to_string())?;
let s = Series::try_from((name.clone(), arr))
.map_err(|err| err.to_string())?
.into_column();
Ok(s)
})
.collect::<Result<Vec<_>, String>>()
})
} else {
let iter = arrays_vec.into_iter().zip(names.iter()).map(|(arr, name)| {
let s = Series::try_from((name.clone(), arr)).map_err(|err| err.to_string())?;
let s = Series::try_from((name.clone(), arr))
.map_err(|err| err.to_string())?
.into_column();
Ok(s)
});
crate::utils::collect_hinted_result(n_columns, iter)
}?;

// no need to check as a record batch has the same guarantees
let df_res: Result<_, String> = Ok(DataFrame::new_no_checks(series_vec));
let df_res: Result<_, String> =
Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) });
df_res
});
let dfs = crate::utils::collect_hinted_result(rb_len, dfs_iter)?;
Expand Down
34 changes: 18 additions & 16 deletions src/rust/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use pl::{Duration, IntoColumn, RollingGroupOptions, SetOperation, TemporalMethod
use polars::lazy::dsl;
use polars::prelude as pl;
use polars::prelude::{ExprEvalExtension, SortOptions};
use std::any::Any;
use std::ops::{Add, Div, Mul, Rem, Sub};
use std::result::Result;
pub type NameGenerator = pl::Arc<dyn Fn(usize) -> String + Send + Sync>;
Expand Down Expand Up @@ -633,7 +632,7 @@ impl RPolarsExpr {
weights,
min_periods,
center,
Some(Arc::new(pl::RollingVarParams { ddof }) as Arc<dyn Any + Send + Sync>),
Some(pl::RollingFnParams::Var(pl::RollingVarParams { ddof })),
)?)
.into())
}
Expand All @@ -657,7 +656,7 @@ impl RPolarsExpr {
window_size,
min_periods,
closed,
Some(Arc::new(pl::RollingVarParams { ddof }) as Arc<dyn Any + Send + Sync>),
Some(pl::RollingFnParams::Var(pl::RollingVarParams { ddof })),
)?,
)
.into())
Expand All @@ -682,7 +681,7 @@ impl RPolarsExpr {
weights,
min_periods,
center,
Some(Arc::new(pl::RollingVarParams { ddof }) as Arc<dyn Any + Send + Sync>),
Some(pl::RollingFnParams::Var(pl::RollingVarParams { ddof })),
)?)
.into())
}
Expand All @@ -706,7 +705,7 @@ impl RPolarsExpr {
window_size,
min_periods,
closed,
Some(Arc::new(pl::RollingVarParams { ddof }) as Arc<dyn Any + Send + Sync>),
Some(pl::RollingFnParams::Var(pl::RollingVarParams { ddof })),
)?,
)
.into())
Expand Down Expand Up @@ -1973,12 +1972,15 @@ impl RPolarsExpr {
pub fn map_batches(&self, lambda: Robj, output_type: Robj, agg_list: Robj) -> RResult<Self> {
// define closure how to request R code evaluated in main thread from a some polars sub thread
let par_fn = ParRObj(lambda);
let f = move |s: pl::Series| {
let f = move |col: pl::Column| {
let thread_com = ThreadCom::try_from_global(&CONFIG)
.expect("polars was thread could not initiate ThreadCommunication to R");
thread_com.send(RFnSignature::FnSeriesToSeries(par_fn.clone(), s));
thread_com.send(RFnSignature::FnSeriesToSeries(
par_fn.clone(),
col.as_materialized_series().clone(),
));
let s = thread_com.recv().unwrap_series();
Ok(Some(s))
Ok(Some(s.into_column()))
};

// set expected type of output from R function
Expand Down Expand Up @@ -2007,12 +2009,12 @@ impl RPolarsExpr {
) -> RResult<Self> {
let raw_func = crate::rbackground::serialize_robj(lambda).unwrap();

let rbgfunc = move |s| {
let rbgfunc = move |col: pl::Column| {
crate::RBGPOOL
.rmap_series(raw_func.clone(), s)
.rmap_series(raw_func.clone(), col.as_materialized_series().clone())
.map_err(rpolars_to_polars_err)?()
.map_err(rpolars_to_polars_err)
.map(Some)
.map(|s| Some(s.into_column()))
};

let ot = robj_to!(Option, PLPolarsDataType, output_type)?;
Expand Down Expand Up @@ -2040,12 +2042,12 @@ impl RPolarsExpr {
) -> Self {
let raw_func = crate::rbackground::serialize_robj(lambda).unwrap();

let rbgfunc = move |s| {
let rbgfunc = move |column: pl::Column| {
crate::RBGPOOL
.rmap_series(raw_func.clone(), s)
.rmap_series(raw_func.clone(), column.as_materialized_series().clone())
.map_err(rpolars_to_polars_err)?()
.map_err(rpolars_to_polars_err)
.map(Some)
.map(|s| Some(s.into_column()))
};

let ot = null_to_opt(output_type).map(|rdt| rdt.0.clone());
Expand Down Expand Up @@ -2821,7 +2823,7 @@ pub fn make_rolling_options_fixed_window(
weights: Robj,
min_periods: Robj,
center: Robj,
fn_params: Option<Arc<dyn Any + Send + Sync>>,
fn_params: Option<pl::RollingFnParams>,
) -> RResult<pl::RollingOptionsFixedWindow> {
Ok(pl::RollingOptionsFixedWindow {
window_size: robj_to!(usize, window_size)?,
Expand All @@ -2836,7 +2838,7 @@ pub fn make_rolling_options_dynamic_window(
window_size: &str,
min_periods: Robj,
closed_window: Robj,
fn_params: Option<Arc<dyn Any + Send + Sync>>,
fn_params: Option<pl::RollingFnParams>,
) -> RResult<pl::RollingOptionsDynamicWindow> {
Ok(pl::RollingOptionsDynamicWindow {
window_size: Duration::parse(window_size),
Expand Down
4 changes: 4 additions & 0 deletions src/rust/src/rbackground.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ pub fn deserialize_series(bits: &[u8]) -> RResult<pl::Series> {
let tn = std::any::type_name::<pl::Series>();
deserialize_dataframe(bits, None, None)?
.get_columns()
.to_vec()
.into_iter()
.map(|c| c.take_materialized_series())
.collect::<Vec<_>>()
.split_first()
.ok_or(RPolarsErr::new())
.mistyped(tn)
Expand Down
46 changes: 30 additions & 16 deletions src/rust/src/rdataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,23 @@ use polars_core::utils::arrow;
use polars::frame::explode::UnpivotArgsIR;
use polars::prelude::pivot::{pivot, pivot_stable};

pub struct OwnedDataFrameIterator {
pub struct DataFrameStreamIterator {
columns: Vec<polars::series::Series>,
data_type: arrow::datatypes::ArrowDataType,
idx: usize,
n_chunks: usize,
compat_level: CompatLevel,
}

impl OwnedDataFrameIterator {
impl DataFrameStreamIterator {
pub fn new(df: polars::frame::DataFrame, compat_level: CompatLevel) -> Self {
let schema = df.schema().to_arrow(compat_level);
// TODO: changed when bumping to 0.43.1, might need refactor
let data_type = ArrowDataType::Struct(schema.iter_values().map(|x| x.clone()).collect());
let vs = df.get_columns().to_vec();
let data_type = ArrowDataType::Struct(schema.into_iter_values().collect());
let vs = df
.get_columns()
.iter()
.map(|v| v.as_materialized_series().clone())
.collect();
Self {
columns: vs,
data_type,
Expand All @@ -52,7 +55,7 @@ impl OwnedDataFrameIterator {
}
}

impl Iterator for OwnedDataFrameIterator {
impl Iterator for DataFrameStreamIterator {
type Item = Result<Box<dyn arrow::array::Array>, PolarsError>;

fn next(&mut self) -> Option<Self::Item> {
Expand All @@ -64,14 +67,14 @@ impl Iterator for OwnedDataFrameIterator {
.columns
.iter()
.map(|s| s.to_arrow(self.idx, self.compat_level))
.collect();
.collect::<Vec<_>>();
self.idx += 1;

let chunk = arrow::record_batch::RecordBatch::new(batch_cols);
let array = arrow::array::StructArray::new(
self.data_type.clone(),
chunk.into_arrays(),
std::option::Option::None,
batch_cols[0].len(),
batch_cols,
None,
);
Some(std::result::Result::Ok(Box::new(array)))
}
Expand Down Expand Up @@ -139,8 +142,8 @@ impl RPolarsDataFrame {

//internal use
pub fn new_with_capacity(capacity: i32) -> Self {
let empty_series: Vec<pl::Series> = Vec::with_capacity(capacity as usize);
RPolarsDataFrame(pl::DataFrame::new(empty_series).unwrap())
let empty_cols: Vec<pl::Column> = Vec::with_capacity(capacity as usize);
RPolarsDataFrame(pl::DataFrame::new(empty_cols).unwrap())
}

//internal use
Expand Down Expand Up @@ -201,7 +204,13 @@ impl RPolarsDataFrame {
}

pub fn get_columns(&self) -> List {
let cols = self.0.get_columns().to_vec();
let cols = self
.0
.get_columns()
.to_vec()
.into_iter()
.map(|c| c.take_materialized_series())
.collect();
let vec = unsafe { std::mem::transmute::<Vec<pl::Series>, Vec<RPolarsSeries>>(cols) };
List::from_values(vec)
}
Expand Down Expand Up @@ -296,14 +305,19 @@ impl RPolarsDataFrame {
let expr_result = {
self.0
.select_at_idx(idx as usize)
.map(|s| RPolarsSeries(s.clone()))
.map(|s| RPolarsSeries(s.as_materialized_series().clone()))
.ok_or_else(|| format!("select_at_idx: no series found at idx {:?}", idx))
};
r_result_list(expr_result)
}

pub fn drop_in_place(&mut self, names: &str) -> RPolarsSeries {
RPolarsSeries(self.0.drop_in_place(names).unwrap())
RPolarsSeries(
self.0
.drop_in_place(names)
.unwrap()
.take_materialized_series(),
)
}

pub fn select(&self, exprs: Robj) -> RResult<Self> {
Expand Down Expand Up @@ -355,7 +369,7 @@ impl RPolarsDataFrame {
let data_type = ArrowDataType::Struct(schema.iter_values().map(|x| x.clone()).collect());
let field = ArrowField::new("".into(), data_type, false);

let iter_boxed = Box::new(OwnedDataFrameIterator::new(self.0.clone(), compat_level));
let iter_boxed = Box::new(DataFrameStreamIterator::new(self.0.clone(), compat_level));
let mut stream = arrow::ffi::export_iterator(iter_boxed, field);
let stream_out_ptr_addr: usize = stream_ptr.parse().unwrap();
let stream_out_ptr = stream_out_ptr_addr as *mut arrow::ffi::ArrowArrayStream;
Expand Down
6 changes: 6 additions & 0 deletions src/rust/src/rdataframe/read_parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ pub fn new_from_parquet(
use_statistics: Robj,
low_memory: Robj,
hive_partitioning: Robj,
schema: Robj,
hive_schema: Robj,
try_parse_hive_dates: Robj,
glob: Robj,
include_file_paths: Robj,
allow_missing_columns: Robj,
//retries: Robj // not supported yet, with CloudOptions
) -> RResult<RPolarsLazyFrame> {
let path = robj_to!(String, path)?;
Expand All @@ -41,6 +43,8 @@ pub fn new_from_parquet(
schema: robj_to!(Option, WrapSchema, hive_schema)?.map(|x| Arc::new(x.0)),
try_parse_dates: robj_to!(bool, try_parse_hive_dates)?,
};
let schema = robj_to!(Option, WrapSchema, schema)?;
let allow_missing_columns = robj_to!(bool, allow_missing_columns)?;
let args = pl::ScanArgsParquet {
n_rows: robj_to!(Option, usize, n_rows)?,
cache: robj_to!(bool, cache)?,
Expand All @@ -50,9 +54,11 @@ pub fn new_from_parquet(
low_memory: robj_to!(bool, low_memory)?,
cloud_options,
use_statistics: robj_to!(bool, use_statistics)?,
schema: schema.map(|x| Arc::new(x.0)),
hive_options,
glob: robj_to!(bool, glob)?,
include_file_paths: robj_to!(Option, String, include_file_paths)?.map(|x| x.into()),
allow_missing_columns,
};

pl::LazyFrame::scan_parquet(path, args)
Expand Down
6 changes: 3 additions & 3 deletions src/rust/src/rdatatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ pub fn literal_to_any_value(litval: pl::LiteralValue) -> RResult<pl::AnyValue<'s

pub fn expr_to_any_value(e: pl::Expr) -> std::result::Result<pl::AnyValue<'static>, String> {
use pl::*;
pl::DataFrame::default()
let av = pl::DataFrame::default()
.lazy()
.select(&[e])
.collect()
Expand All @@ -491,8 +491,8 @@ pub fn expr_to_any_value(e: pl::Expr) -> std::result::Result<pl::AnyValue<'stati
.iter()
.next()
.ok_or_else(|| String::from("series had no first value"))?
.into_static()
.map_err(|err| err.to_string())
.into_static();
Ok(av)
}

pub fn robj_to_width_strategy(robj: Robj) -> RResult<pl::ListToStructWidthStrategy> {
Expand Down
13 changes: 9 additions & 4 deletions src/rust/src/rlib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use extendr_api::prelude::*;
use polars::chunked_array::ops::SortMultipleOptions;
use polars::lazy::dsl;
use polars::prelude as pl;
use polars::prelude::IntoColumn;
use std::result::Result;

#[extendr]
Expand Down Expand Up @@ -236,12 +237,16 @@ fn test_robj_to_rchoice(robj: Robj) -> RResult<String> {
#[extendr]
fn fold(acc: Robj, lambda: Robj, exprs: Robj) -> RResult<RPolarsExpr> {
let par_fn = ParRObj(lambda);
let f = move |acc: pl::Series, x: pl::Series| {
let f = move |acc: pl::Column, x: pl::Column| {
let thread_com = ThreadCom::try_from_global(&CONFIG)
.map_err(|err| pl::polars_err!(ComputeError: err))?;
thread_com.send(RFnSignature::FnTwoSeriesToSeries(par_fn.clone(), acc, x));
thread_com.send(RFnSignature::FnTwoSeriesToSeries(
par_fn.clone(),
acc.as_materialized_series().clone(),
x.as_materialized_series().clone(),
));
let s = thread_com.recv().unwrap_series();
Ok(Some(s))
Ok(Some(s.into_column()))
};
Ok(pl::fold_exprs(robj_to!(PLExpr, acc)?, f, robj_to!(Vec, PLExpr, exprs)?).into())
}
Expand All @@ -258,7 +263,7 @@ fn reduce(lambda: Robj, exprs: Robj) -> RResult<RPolarsExpr> {
x.take_materialized_series(),
));
let s = thread_com.recv().unwrap_series();
Ok(Some(s))
Ok(Some(s.into_column()))
};
Ok(pl::reduce_exprs(f, robj_to!(Vec, PLExpr, exprs)?).into())
}
Expand Down
8 changes: 6 additions & 2 deletions src/rust/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use extendr_api::{extendr, prelude::*, rprintln};
use pl::SeriesMethods;
use polars::datatypes::*;
use polars::prelude as pl;
use polars::prelude::{ArgAgg, IntoSeries};
use polars::prelude::{ArgAgg, IntoColumn, IntoSeries};
use polars_core::series::IsSorted;
pub const R_INT_NA_ENC: i32 = -2147483648;
use crate::rpolarserr::polars_to_rpolars_err;
Expand Down Expand Up @@ -81,7 +81,11 @@ impl From<&RPolarsExpr> for pl::PolarsResult<RPolarsSeries> {
.map(|df| {
df.select_at_idx(0)
.cloned()
.unwrap_or_else(|| pl::Series::new_empty("".into(), &pl::DataType::Null))
.unwrap_or_else(|| {
pl::Series::new_empty("".into(), &pl::DataType::Null).into_column()
})
.as_materialized_series()
.clone()
.into()
})
}
Expand Down

0 comments on commit f0a6b6f

Please sign in to comment.