Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding pl.Struct support for pl.Dataframe #306

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion __tests__/dataframe.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ describe("dataframe", () => {
pl.Series("foo", [1, 2, 9], pl.Int16),
pl.Series("bar", [6, 2, 8], pl.Int16),
]);

test("dtypes", () => {
const expected = [pl.Float64, pl.String];
const actual = pl.DataFrame({ a: [1, 2, 3], b: ["a", "b", "c"] }).dtypes;
Expand Down Expand Up @@ -1314,6 +1313,101 @@ describe("dataframe", () => {
]);
expect(actual).toFrameEqual(expected);
});
test("df from JSON with multiple struct", () => {
const rows = [
{
id: 1,
name: "one",
attributes: {
b: false,
bb: true,
s: "one",
x: 1,
att2: { s: "two", y: 2, att3: { s: "three", y: 3 } },
},
},
];

const actual = pl.DataFrame(rows);
const expected = `shape: (1,)
Series: 'attributes' [struct[5]]
[
{false,true,"one",1.0,{"two",2.0,{"three",3.0}}}
]`;
expect(actual.select("attributes").toSeries().toString()).toEqual(expected);
});
test("df from JSON with struct", () => {
const rows = [
{
id: 1,
name: "one",
attributes: { b: false, bb: true, s: "one", x: 1 },
},
{
id: 2,
name: "two",
attributes: { b: false, bb: true, s: "two", x: 2 },
},
{
id: 3,
name: "three",
attributes: { b: false, bb: true, s: "three", x: 3 },
},
];

let actual = pl.DataFrame(rows);
expect(actual.schema).toStrictEqual({
id: pl.Float64,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Float64),
]),
});

let expected = `shape: (3, 3)
┌─────┬───────┬──────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ f64 ┆ str ┆ struct[4] │
╞═════╪═══════╪══════════════════════════╡
│ 1.0 ┆ one ┆ {false,true,"one",1.0} │
│ 2.0 ┆ two ┆ {false,true,"two",2.0} │
│ 3.0 ┆ three ┆ {false,true,"three",3.0} │
└─────┴───────┴──────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);

const schema = {
id: pl.Int32,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Int16),
]),
};
actual = pl.DataFrame(rows, { schema: schema });
expected = `shape: (3, 3)
┌─────┬───────┬────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ i32 ┆ str ┆ struct[4] │
╞═════╪═══════╪════════════════════════╡
│ 1 ┆ one ┆ {false,true,"one",1} │
│ 2 ┆ two ┆ {false,true,"two",2} │
│ 3 ┆ three ┆ {false,true,"three",3} │
└─────┴───────┴────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);
expect(actual.getColumn("name").toArray()).toEqual(
rows.map((e) => e["name"]),
);
expect(actual.getColumn("attributes").toArray()).toMatchObject(
rows.map((e) => e["attributes"]),
);
});
test("pivot", () => {
{
const df = pl.DataFrame({
Expand Down
4 changes: 2 additions & 2 deletions src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,8 +830,8 @@ impl FromNapiValue for Wrap<CsvWriterOptions> {
let obj = Object::from_napi_value(env, napi_val)?;
let include_bom = obj.get::<_, bool>("includeBom")?.unwrap_or(false);
let include_header = obj.get::<_, bool>("includeHeader")?.unwrap_or(true);
let batch_size =
NonZero::new(obj.get::<_, i64>("batchSize")?.unwrap_or(1024) as usize).ok_or_else(|| napi::Error::from_reason("Invalid batch size"))?;
let batch_size = NonZero::new(obj.get::<_, i64>("batchSize")?.unwrap_or(1024) as usize)
.ok_or_else(|| napi::Error::from_reason("Invalid batch size"))?;
let maintain_order = obj.get::<_, bool>("maintainOrder")?.unwrap_or(true);
let date_format = obj.get::<_, String>("dateFormat")?;
let time_format = obj.get::<_, String>("timeFormat")?;
Expand Down
179 changes: 119 additions & 60 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,27 +442,26 @@ pub fn from_rows(
infer_schema(pairs, infer_schema_length)
}
};
let len = rows.len();
let it: Vec<Row> = (0..len)
let it: Vec<Row> = (0..rows.len())
.into_iter()
.map(|idx| {
let obj = rows
.get::<Object>(idx as u32)
.unwrap_or(None)
.unwrap_or_else(|| env.create_object().unwrap());

Row(schema
.iter_fields()
.map(|fld| {
let dtype = fld.dtype().clone();
let key = fld.name();
if let Ok(unknown) = obj.get(key) {
let av = match unknown {
Some(unknown) => unsafe {
coerce_js_anyvalue(unknown, dtype).unwrap_or(AnyValue::Null)
},
None => AnyValue::Null,
};
av
let dtype: &DataType = fld.dtype();
let key: &PlSmallStr = fld.name();
if let Ok(unknown) = obj.get::<&polars::prelude::PlSmallStr, JsUnknown>(key) {
match unknown {
Some(unknown) => {
coerce_js_anyvalue(unknown, dtype.clone()).unwrap_or(AnyValue::Null)
}
_ => AnyValue::Null,
}
} else {
AnyValue::Null
}
Expand Down Expand Up @@ -1620,61 +1619,79 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator<Item = Vec<(Stri
let len = std::cmp::min(len, rows.len() as usize);
(0..len).map(move |idx| {
let obj = rows.get::<Object>(idx as u32).unwrap().unwrap();

let keys = Object::keys(&obj).unwrap();
keys.iter()
.map(|key| {
let value = obj.get::<_, napi::JsUnknown>(&key).unwrap_or(None);
let dtype = match value {
Some(val) => {
let ty = val.get_type().unwrap();
match ty {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::String => DataType::String,
ValueType::Object => {
if val.is_array().unwrap() {
let arr: napi::JsObject = unsafe { val.cast() };
let len = arr.get_array_length().unwrap();

if len == 0 {
DataType::List(DataType::Null.into())
} else {
// dont compare too many items, as it could be expensive
let max_take = std::cmp::min(len as usize, 10);
let mut dtypes: Vec<DataType> =
Vec::with_capacity(len as usize);

for idx in 0..max_take {
let item: napi::JsUnknown =
arr.get_element(idx as u32).unwrap();
let ty = item.get_type().unwrap();
let dt: Wrap<DataType> = ty.into();
dtypes.push(dt.0)
}
let dtype = coerce_data_type(&dtypes);

DataType::List(dtype.into())
}
} else if val.is_date().unwrap() {
DataType::Datetime(TimeUnit::Milliseconds, None)
} else {
DataType::Struct(vec![])
}
}
ValueType::BigInt => DataType::UInt64,
_ => DataType::Null,
}
}
None => DataType::Null,
};
(key.to_owned(), dtype)
(key.to_owned(), obj_to_type(value))
})
.collect()
})
}

unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<AnyValue<'a>> {
fn obj_to_type(value: Option<JsUnknown>) -> DataType {
match value {
Some(val) => {
let ty = val.get_type().unwrap();
match ty {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
ValueType::Object => {
if val.is_array().unwrap() {
let arr: napi::JsObject = unsafe { val.cast() };
let len = arr.get_array_length().unwrap();
if len == 0 {
DataType::List(DataType::Null.into())
} else {
// dont compare too many items, as it could be expensive
let max_take = std::cmp::min(len as usize, 10);
let mut dtypes: Vec<DataType> = Vec::with_capacity(len as usize);

for idx in 0..max_take {
let item: napi::JsUnknown = arr.get_element(idx as u32).unwrap();
let ty = item.get_type().unwrap();
let dt: Wrap<DataType> = ty.into();
dtypes.push(dt.0)
}
let dtype = coerce_data_type(&dtypes);

DataType::List(dtype.into())
}
} else if val.is_date().unwrap() {
DataType::Datetime(TimeUnit::Milliseconds, None)
} else {
let inner_val: napi::JsObject = unsafe { val.cast() };
let inner_keys = Object::keys(&inner_val).unwrap();
let mut fldvec: Vec<Field> = Vec::with_capacity(inner_keys.len() as usize);

inner_keys.iter().for_each(|key| {
let inner_val = inner_val.get::<_, napi::JsUnknown>(&key).unwrap();
let dtype = match inner_val.as_ref().unwrap().get_type().unwrap() {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
// determine struct type using a recursive func
ValueType::Object => obj_to_type(inner_val),
_ => DataType::Null,
};

let fld = Field::new(key.into(), dtype);
fldvec.push(fld);
});
DataType::Struct(fldvec)
}
}
_ => DataType::Null,
}
}
None => DataType::Null,
}
}

fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<AnyValue<'a>> {
use DataType::*;
let vtype = val.get_type().unwrap();
match (vtype, dtype) {
Expand Down Expand Up @@ -1749,17 +1766,59 @@ unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<An
}
(ValueType::Object, DataType::Datetime(_, _)) => {
if val.is_date()? {
let d: napi::JsDate = val.cast();
let d: napi::JsDate = unsafe { val.cast() };
let d = d.value_of()?;
Ok(AnyValue::Datetime(d as i64, TimeUnit::Milliseconds, None))
} else {
Ok(AnyValue::Null)
}
}
(ValueType::Object, DataType::List(_)) => {
let s = val.to_series();
let s = unsafe { val.to_series() };
Ok(AnyValue::List(s))
}
(ValueType::Object, DataType::Struct(fields)) => {
let number_of_fields: i8 = fields.len().try_into().map_err(|e| {
napi::Error::from_reason(format!(
"the number of `fields` cannot be larger than i8::MAX {e:?}"
))
})?;

let inner_val: napi::JsObject = unsafe { val.cast() };
let mut val_vec: Vec<polars::prelude::AnyValue<'_>> =
Vec::with_capacity(number_of_fields as usize);
fields.iter().for_each(|fld| {
let single_val = inner_val
.get::<_, napi::JsUnknown>(&fld.name)
.unwrap()
.unwrap();
let vv = match &fld.dtype {
DataType::Boolean => {
AnyValue::Boolean(single_val.coerce_to_bool().unwrap().get_value().unwrap())
}
DataType::String => AnyValue::from_js(single_val).expect("Expecting string"),
DataType::Int16 => AnyValue::Int16(
single_val.coerce_to_number().unwrap().get_int32().unwrap() as i16,
),
DataType::Int32 => {
AnyValue::Int32(single_val.coerce_to_number().unwrap().get_int32().unwrap())
}
DataType::Int64 => {
AnyValue::Int64(single_val.coerce_to_number().unwrap().get_int64().unwrap())
}
DataType::Float64 => AnyValue::Float64(
single_val.coerce_to_number().unwrap().get_double().unwrap(),
),
DataType::Struct(_) => {
coerce_js_anyvalue(single_val, fld.dtype.clone()).unwrap()
}
_ => AnyValue::Null,
};
val_vec.push(vv);
});

Ok(AnyValue::StructOwned(Box::new((val_vec, fields))))
}
_ => Ok(AnyValue::Null),
}
}
Loading