Skip to content

Commit

Permalink
Feat: Develop query polars (#169)
Browse files Browse the repository at this point in the history
* feat(polars): query polars

* feat(polars): add doc

* feat(polars): remove wrapper, fix docstring, re-export

* remove as_string

* disable polars in windows runner

* jobs to 1

* try disable debuginfo

* feat: ffi arrow2 + polars

---------

Co-authored-by: wangfenjin <wangfenj@gmail.com>
  • Loading branch information
therealhieu and wangfenjin authored Jun 18, 2023
1 parent dd8803b commit eebd09f
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 21 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/rust.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,10 @@ jobs:
targets: x86_64-pc-windows-msvc
- run: cargo install cargo-examples
# make the feature set the same so build faster
- run: cargo test --features "modern-full extensions-full vtab-loadable"
- run: cargo examples --skip hello-ext --features "modern-full extensions-full vtab-loadable"
- run: cargo build --example hello-ext --features "modern-full extensions-full vtab-loadable"
# don't test modern-full as polars requires too much memory
- run: cargo test --features "extensions-full vtab-loadable"
- run: cargo examples --skip hello-ext --features "extensions-full vtab-loadable"
- run: cargo build --example hello-ext --features "extensions-full vtab-loadable"

Sanitizer:
name: Address Sanitizer
Expand All @@ -140,7 +141,7 @@ jobs:
components: rust-src
- name: Tests with asan
env:
RUSTFLAGS: -Zsanitizer=address
RUSTFLAGS: -Zsanitizer=address -C debuginfo=0
RUSTDOCFLAGS: -Zsanitizer=address
ASAN_OPTIONS: "detect_stack_use_after_return=1:detect_leaks=1"
# Work around https://github.com/rust-lang/rust/issues/59125 by
Expand Down
13 changes: 5 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,8 @@ vtab-arrow = ["vtab", "num"]
vtab-full = ["vtab-excel", "vtab-arrow"]
extensions-full = ["httpfs", "json", "parquet", "vtab-full"]
buildtime_bindgen = ["libduckdb-sys/buildtime_bindgen"]
modern-full = [
"chrono",
"serde_json",
"url",
"r2d2",
"uuid",
]
modern-full = ["chrono", "serde_json", "url", "r2d2", "uuid", "polars"]
polars = ["dep:polars"]

[dependencies]
# time = { version = "0.3.2", features = ["formatting", "parsing"], optional = true }
Expand All @@ -63,6 +58,7 @@ r2d2 = { version = "0.8.9", optional = true }
calamine = { version = "0.21.0", optional = true }
num = { version = "0.4", optional = true, default-features = false, features = ["std"] }
duckdb-loadable-macros = { version = "0.1.0", path="./duckdb-loadable-macros", optional = true }
polars = { version = "0.30.0", features = ["dtype-full"], optional = true}

[dev-dependencies]
doc-comment = "0.3"
Expand All @@ -73,6 +69,7 @@ uuid = { version = "1.0", features = ["v4"] }
unicase = "2.6.0"
rand = "0.8.3"
tempdir = "0.3.7"
polars-core = "0.30.0"
# criterion = "0.3"

# [[bench]]
Expand All @@ -96,4 +93,4 @@ all-features = false
[[example]]
name = "hello-ext"
crate-type = ["cdylib"]
required-features = ["vtab-loadable"]
required-features = ["vtab-loadable"]
8 changes: 7 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,13 @@ pub use crate::{
transaction::{DropBehavior, Savepoint, Transaction, TransactionBehavior},
types::ToSql,
};
#[cfg(feature = "polars")]
pub use polars_dataframe::Polars;

// re-export dependencies from arrow-rs to minimise version maintenance for crate users
// re-export dependencies to minimise version maintenance for crate users
pub use arrow;
#[cfg(feature = "polars")]
pub use polars::{self, export::arrow as arrow2};

#[macro_use]
mod error;
Expand All @@ -100,6 +104,8 @@ mod column;
mod config;
mod inner_connection;
mod params;
#[cfg(feature = "polars")]
mod polars_dataframe;
mod pragma;
#[cfg(feature = "r2d2")]
mod r2d2;
Expand Down
86 changes: 86 additions & 0 deletions src/polars_dataframe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use polars::prelude::DataFrame;

use super::{arrow::datatypes::SchemaRef, Statement};

/// A handle for the resulting Polars DataFrames of a query.
///
/// Iterating yields one `DataFrame` per record batch produced by the
/// underlying statement (see `Statement::query_polars`). Nothing is fetched
/// until the iterator is consumed.
#[must_use = "Polars is lazy and will do nothing unless consumed"]
pub struct Polars<'stmt> {
    // Executed statement being streamed; `None` means no source to pull from.
    pub(crate) stmt: Option<&'stmt Statement<'stmt>>,
}

impl<'stmt> Polars<'stmt> {
    /// Wrap an already-executed statement whose batches will be surfaced as
    /// Polars DataFrames.
    #[inline]
    pub(crate) fn new(stmt: &'stmt Statement<'stmt>) -> Polars<'stmt> {
        Polars { stmt: Some(stmt) }
    }

    /// Return the arrow schema of the underlying result.
    ///
    /// # Panics
    ///
    /// Panics if this handle no longer refers to a statement.
    #[inline]
    pub fn get_schema(&self) -> SchemaRef {
        // `expect` (rather than a bare `unwrap`) states the invariant in the
        // panic message, which makes the failure diagnosable.
        self.stmt
            .expect("Polars handle has no underlying statement")
            .stmt
            .schema()
    }
}

impl<'stmt> Iterator for Polars<'stmt> {
    type Item = DataFrame;

    /// Pull the next record batch and convert it into a Polars DataFrame.
    fn next(&mut self) -> Option<Self::Item> {
        let stmt = self.stmt?;
        let batch = stmt.step2()?;
        Some(DataFrame::try_from(batch).expect("Failed to construct DataFrame from StructArray"))
    }
}

#[cfg(test)]
mod tests {
    use polars::prelude::*;
    use polars_core::utils::accumulate_dataframes_vertical_unchecked;

    use crate::{test::checked_memory_handle, Result};

    /// A handful of rows should arrive as exactly one DataFrame batch.
    #[test]
    fn test_query_polars_small() -> Result<()> {
        let conn = checked_memory_handle();
        conn.execute_batch(
            "BEGIN TRANSACTION;
             CREATE TABLE test(t INTEGER);
             INSERT INTO test VALUES (1); INSERT INTO test VALUES (2); INSERT INTO test VALUES (3); INSERT INTO test VALUES (4); INSERT INTO test VALUES (5);
             END TRANSACTION;",
        )?;

        let mut stmt = conn.prepare("select t from test order by t desc")?;
        let mut frames = stmt.query_polars([])?;

        let first = frames.next().expect("Failed to get DataFrame");
        let expected = df!(
            "t" => [5i32, 4, 3, 2, 1],
        )
        .expect("Failed to construct DataFrame");
        assert_eq!(first, expected);
        // The whole result fits in one batch, so the stream is now exhausted.
        assert!(frames.next().is_none());

        Ok(())
    }

    /// Enough rows to span several batches; stitch them back together and
    /// check the aggregate totals.
    #[test]
    fn test_query_polars_large() -> Result<()> {
        let conn = checked_memory_handle();
        conn.execute_batch("BEGIN TRANSACTION")?;
        conn.execute_batch("CREATE TABLE test(t INTEGER);")?;

        let inserts = "INSERT INTO test VALUES (1); INSERT INTO test VALUES (2); INSERT INTO test VALUES (3); INSERT INTO test VALUES (4); INSERT INTO test VALUES (5);";
        for _ in 0..600 {
            conn.execute_batch(inserts)?;
        }
        conn.execute_batch("END TRANSACTION")?;

        let mut stmt = conn.prepare("select t from test order by t")?;
        let combined = accumulate_dataframes_vertical_unchecked(stmt.query_polars([])?);

        // 600 iterations x 5 rows, values 1..=5 -> 600 * 15 = 9000.
        assert_eq!(combined.height(), 3000);
        assert_eq!(combined.column("t").unwrap().i32().unwrap().sum().unwrap(), 9000);

        Ok(())
    }
}
65 changes: 61 additions & 4 deletions src/raw_statement.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use std::{convert::TryFrom, ffi::CStr, ptr, sync::Arc};

use super::{ffi, Result};
use crate::error::result_from_duckdb_arrow;

use arrow::{
array::{ArrayData, StructArray},
datatypes::{DataType, Schema, SchemaRef},
ffi::{ArrowArray, FFI_ArrowArray, FFI_ArrowSchema},
};

use super::{ffi, Result};
#[cfg(feature = "polars")]
use crate::arrow2;
use crate::error::result_from_duckdb_arrow;

// Private newtype for raw sqlite3_stmts that finalize themselves when dropped.
// TODO: destroy statement and result
#[derive(Debug)]
Expand Down Expand Up @@ -83,10 +85,12 @@ impl RawStatement {
if ffi::duckdb_query_arrow_array(
self.result_unwrap(),
&mut std::ptr::addr_of_mut!(arrays) as *mut _ as *mut ffi::duckdb_arrow_array,
) != ffi::DuckDBSuccess
)
.ne(&ffi::DuckDBSuccess)
{
return None;
}

if arrays.is_empty() {
return None;
}
Expand All @@ -107,6 +111,59 @@ impl RawStatement {
}
}

/// Fetch the next record batch of the current result as an arrow2
/// `StructArray`, for consumption by Polars.
///
/// Returns `None` when there is no active result, when DuckDB reports an
/// error while exporting the array or schema, or when the stream is
/// exhausted (DuckDB hands back an empty array, which arrow2 rejects as
/// out-of-spec on import).
///
/// # Panics
///
/// Panics if the exported schema or array cannot be imported for any
/// reason other than the empty-batch case above.
#[cfg(feature = "polars")]
#[inline]
pub fn step2(&self) -> Option<arrow2::array::StructArray> {
    self.result?;

    unsafe {
        // Ask DuckDB to fill an arrow2 C-ABI array with the next chunk.
        let mut ffi_arrow2_array = arrow2::ffi::ArrowArray::empty();
        if ffi::duckdb_query_arrow_array(
            self.result_unwrap(),
            &mut std::ptr::addr_of_mut!(ffi_arrow2_array) as *mut _ as *mut ffi::duckdb_arrow_array,
        ) != ffi::DuckDBSuccess
        {
            return None;
        }

        // And the matching C-ABI schema describing that chunk.
        let mut ffi_arrow2_schema = arrow2::ffi::ArrowSchema::empty();
        if ffi::duckdb_query_arrow_schema(
            self.result_unwrap(),
            &mut std::ptr::addr_of_mut!(ffi_arrow2_schema) as *mut _ as *mut ffi::duckdb_arrow_schema,
        ) != ffi::DuckDBSuccess
        {
            return None;
        }

        let arrow2_field =
            arrow2::ffi::import_field_from_c(&ffi_arrow2_schema).expect("Failed to import arrow2 Field from C");

        // When the array is empty, import_array_from_c returns an error with
        // message "OutOfSpec("An ArrowArray of type X must have non-null
        // children")"; we treat that as end-of-stream rather than a failure.
        // A single `match` consumes the Result once instead of the previous
        // check-then-unwrap double handling.
        let arrow2_array = match arrow2::ffi::import_array_from_c(ffi_arrow2_array, arrow2_field.data_type) {
            Ok(array) => array,
            Err(arrow2::error::Error::OutOfSpec(_)) => return None,
            Err(err) => panic!("Failed to import arrow2 Array from C: {}", err),
        };

        Some(
            arrow2_array
                .as_any()
                .downcast_ref::<arrow2::array::StructArray>()
                .expect("Failed to downcast arrow2 Array to arrow2 StructArray")
                .to_owned(),
        )
    }
}

#[inline]
pub fn column_count(&self) -> usize {
unsafe { ffi::duckdb_arrow_column_count(self.result_unwrap()) as usize }
Expand Down
61 changes: 57 additions & 4 deletions src/statement.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use std::{convert, ffi::c_void, fmt, iter::IntoIterator, mem, os::raw::c_char, ptr, str};

use arrow::{array::StructArray, datatypes::DataType};

use super::{ffi, AndThenRows, Connection, Error, MappedRows, Params, RawStatement, Result, Row, Rows, ValueRef};
#[cfg(feature = "polars")]
use crate::{arrow2, polars_dataframe::Polars};
use crate::{
arrow_batch::Arrow,
error::result_from_duckdb_prepare,
types::{TimeUnit, ToSql, ToSqlOutput},
};

use arrow::{array::StructArray, datatypes::DataType};

/// A prepared statement.
pub struct Statement<'conn> {
conn: &'conn Connection,
Expand Down Expand Up @@ -107,6 +109,50 @@ impl Statement<'_> {
Ok(Arrow::new(self))
}

/// Execute the prepared statement, returning a handle that yields the
/// result as a stream of polars DataFrames (one per record batch).
///
/// ## Example
///
/// ```rust,no_run
/// # use duckdb::{Result, Connection};
/// # use polars::prelude::DataFrame;
///
/// fn get_polars_dfs(conn: &Connection) -> Result<Vec<DataFrame>> {
///     let dfs: Vec<DataFrame> = conn
///         .prepare("SELECT * FROM test")?
///         .query_polars([])?
///         .collect();
///
///     Ok(dfs)
/// }
/// ```
///
/// To fold the Vec\<DataFrame> into a single DataFrame, use
/// [polars_core::utils::accumulate_dataframes_vertical_unchecked](https://docs.rs/polars-core/latest/polars_core/utils/fn.accumulate_dataframes_vertical_unchecked.html).
///
/// ```rust,no_run
/// # use duckdb::{Result, Connection};
/// # use polars::prelude::DataFrame;
/// # use polars_core::utils::accumulate_dataframes_vertical_unchecked;
///
/// fn get_polars_df(conn: &Connection) -> Result<DataFrame> {
///     let mut stmt = conn.prepare("SELECT * FROM test")?;
///     let pl = stmt.query_polars([])?;
///     let df = accumulate_dataframes_vertical_unchecked(pl);
///
///     Ok(df)
/// }
/// ```
#[cfg(feature = "polars")]
#[inline]
pub fn query_polars<P: Params>(&mut self, params: P) -> Result<Polars<'_>> {
    // Bind and run first so any execution error surfaces here; the returned
    // handle then lazily pulls arrow2 batches via `step2`.
    self.execute(params)?;
    Ok(Polars::new(self))
}

/// Execute the prepared statement, returning a handle to the resulting
/// rows.
///
Expand Down Expand Up @@ -219,7 +265,7 @@ impl Statement<'_> {
///
/// ### Use with positional params
///
/// ```rust,no_run
/// ```no_run
/// # use duckdb::{Connection, Result};
/// fn get_names(conn: &Connection) -> Result<Vec<String>> {
/// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = ?")?;
Expand Down Expand Up @@ -285,12 +331,19 @@ impl Statement<'_> {
self.stmt.row_count()
}

/// Get next batch records in arrow-rs.
///
/// Returns `None` once the result stream is exhausted or on a DuckDB error.
#[inline]
pub fn step(&self) -> Option<StructArray> {
    self.stmt.step()
}

#[cfg(feature = "polars")]
/// Get next batch records in arrow2.
///
/// Returns `None` once the result stream is exhausted or on a DuckDB error.
#[inline]
pub fn step2(&self) -> Option<arrow2::array::StructArray> {
    self.stmt.step2()
}

#[inline]
pub(crate) fn bind_parameters<P>(&mut self, params: P) -> Result<()>
where
Expand Down

0 comments on commit eebd09f

Please sign in to comment.