From 6d8ca79038c212965accb92679d36ab49df14887 Mon Sep 17 00:00:00 2001 From: Bidek56 Date: Mon, 20 May 2024 16:02:06 -0400 Subject: [PATCH] Initial commit for Decimal type --- __tests__/dataframe.test.ts | 36 ++++++++++++++++++++---------------- polars/dataframe.ts | 1 + polars/datatypes/datatype.ts | 8 +++++++- polars/datatypes/index.ts | 4 ++++ polars/index.ts | 1 + src/conversion.rs | 25 +++++++++++++++++++++++++ src/datatypes.rs | 9 +++++++++ src/series.rs | 6 ++++++ 8 files changed, 73 insertions(+), 17 deletions(-) diff --git a/__tests__/dataframe.test.ts b/__tests__/dataframe.test.ts index c2e62f318..d493f96ec 100644 --- a/__tests__/dataframe.test.ts +++ b/__tests__/dataframe.test.ts @@ -1621,6 +1621,19 @@ describe("io", () => { fs.rmSync("./test.csv"); done(); }); + test("writeParquet", (done) => { + const df = pl.DataFrame([ + pl.Series("doo", [1, 2, 3], pl.Decimal), + pl.Series("foo", [1, 2, 3], pl.UInt32), + pl.Series("bar", ["a", "b", "c"]), + ]); + const pqFile = "./test.parquet"; + df.writeParquet(pqFile); + const newDF = pl.readParquet(pqFile); + expect(newDF).toFrameEqual(df); + fs.rmSync(pqFile); + done(); + }); test("JSON.stringify", () => { const df = pl.DataFrame({ foo: [1], @@ -1951,30 +1964,21 @@ describe("create", () => { { a: [1, 2, 3], b: ["1", "2", "3"], + d: [1, 2, 3], }, { schema: { x: pl.Int32, y: pl.String, + z: pl.Decimal, }, }, ); - expect(df.schema).toStrictEqual({ x: pl.Int32, y: pl.String }); - }); - test("with schema", () => { - const df = pl.DataFrame( - { - a: [1, 2, 3], - b: ["1", "2", "3"], - }, - { - schema: { - x: pl.Int32, - y: pl.String, - }, - }, - ); - expect(df.schema).toStrictEqual({ x: pl.Int32, y: pl.String }); + expect(df.schema).toStrictEqual({ + x: pl.Int32, + y: pl.String, + z: pl.Decimal, + }); }); test("with schema overrides", () => { const df = pl.DataFrame( diff --git a/polars/dataframe.ts b/polars/dataframe.ts index 81ac27c83..2522438da 100644 --- a/polars/dataframe.ts +++ b/polars/dataframe.ts @@ -1853,6 +1853,7 @@ function mapPolarsTypeToJSONSchema(colType: DataType): string { UInt64: "integer", Float32: "number", Float64: "number", + Decimal: "number", Date: "string", Datetime: "string", Utf8: "string", diff --git a/polars/datatypes/datatype.ts b/polars/datatypes/datatype.ts index 1ac4f72ee..25f50b7f1 100644 --- a/polars/datatypes/datatype.ts +++ b/polars/datatypes/datatype.ts @@ -56,7 +56,6 @@ export abstract class DataType { public static get UInt64(): DataType { return new _UInt64(); } - /** A `f32` */ public static get Float32(): DataType { return new _Float32(); @@ -65,6 +64,10 @@ export abstract class DataType { public static get Float64(): DataType { return new _Float64(); } + /** A `decimal` */ + public static get Decimal(): DataType { + return new _Decimal(); + } public static get Date(): DataType { return new _Date(); } @@ -163,6 +166,7 @@ class _UInt32 extends DataType {} class _UInt64 extends DataType {} class _Float32 extends DataType {} class _Float64 extends DataType {} +class _Decimal extends DataType {} class _Date extends DataType {} class _Time extends DataType {} class _Object extends DataType {} @@ -299,6 +303,8 @@ export namespace DataType { export type Float32 = _Float32; /** Float64 */ export type Float64 = _Float64; + /** Decimal */ + export type Decimal = _Decimal; /** Date dtype */ export type Date = _Date; /** Datetime */ diff --git a/polars/datatypes/index.ts b/polars/datatypes/index.ts index 5f86f2885..114dc28a8 100644 --- a/polars/datatypes/index.ts +++ b/polars/datatypes/index.ts @@ -42,6 +42,7 @@ export const DTYPE_TO_FFINAME = { UInt64: "U64", Float32: "F32", Float64: "F64", + Decimal: "Decimal", Bool: "Bool", Utf8: "Str", String: "Str", @@ -61,6 +62,9 @@ const POLARS_TYPE_TO_CONSTRUCTOR: Record = { Float64(name, values, strict?) { return pli.JsSeries.newOptF64(name, values, strict); }, + Decimal(name, values, strict?) { + return pli.JsSeries.newOptDecimal(name, values, strict); + }, Int8(name, values, strict?) { return pli.JsSeries.newOptI32(name, values, strict); }, diff --git a/polars/index.ts b/polars/index.ts index 4abd9caf1..8a73a9ead 100644 --- a/polars/index.ts +++ b/polars/index.ts @@ -98,6 +98,7 @@ export namespace pl { export import UInt64 = DataType.UInt64; export import Float32 = DataType.Float32; export import Float64 = DataType.Float64; + export import Decimal = DataType.Decimal; export import Bool = DataType.Bool; export import Utf8 = DataType.Utf8; // biome-ignore lint/suspicious/noShadowRestrictedNames: pl.String diff --git a/src/conversion.rs b/src/conversion.rs index 525a94f73..b2bd7ebc4 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -261,6 +261,29 @@ impl FromNapiValue for Wrap> { } } +impl FromNapiValue for Wrap> { + unsafe fn from_napi_value(env: sys::napi_env, napi_val: sys::napi_value) -> JsResult { + let arr = Array::from_napi_value(env, napi_val)?; + let len = arr.len() as usize; + let mut builder = PrimitiveChunkedBuilder::::new("", len); + for i in 0..len { + match arr.get::(i as u32) { + Ok(val) => match val { + Some(v) => { + let (v, _b) = v.get_i128(); + builder.append_value(v) + } + None => builder.append_null(), + }, + Err(_) => { + builder.append_null() + } + } + } + Ok(Wrap(builder.finish())) + } +} + impl FromNapiValue for Wrap { unsafe fn from_napi_value(env: sys::napi_env, napi_val: sys::napi_value) -> JsResult { let obj = Object::from_napi_value(env, napi_val)?; @@ -655,6 +678,7 @@ impl FromNapiValue for Wrap { "UInt64" => DataType::UInt64, "Float32" => DataType::Float32, "Float64" => DataType::Float64, + "Decimal" => DataType::Decimal(None, None), "Bool" => DataType::Boolean, "Utf8" => DataType::String, "String" => DataType::String, @@ -927,6 +951,7 @@ impl ToNapiValue for Wrap { DataType::UInt64 => String::to_napi_value(env, "UInt64".to_owned()), DataType::Float32 => String::to_napi_value(env, "Float32".to_owned()), DataType::Float64 => String::to_napi_value(env, "Float64".to_owned()), + DataType::Decimal(_,_) => String::to_napi_value(env, "Decimal".to_owned()), DataType::Boolean => String::to_napi_value(env, "Bool".to_owned()), DataType::String => String::to_napi_value(env, "String".to_owned()), DataType::List(inner) => { diff --git a/src/datatypes.rs b/src/datatypes.rs index 4fb27f676..5f6e58f66 100644 --- a/src/datatypes.rs +++ b/src/datatypes.rs @@ -12,6 +12,7 @@ pub enum JsDataType { UInt64, Float32, Float64, + Decimal, Bool, Utf8, String, @@ -36,6 +37,7 @@ impl JsDataType { "UInt64" => JsDataType::UInt64, "Float32" => JsDataType::Float32, "Float64" => JsDataType::Float64, + "Decimal" => JsDataType::Decimal, "Bool" => JsDataType::Bool, "Utf8" => JsDataType::Utf8, "String" => JsDataType::String, @@ -65,6 +67,7 @@ impl From<&DataType> for JsDataType { DataType::UInt64 => UInt64, DataType::Float32 => Float32, DataType::Float64 => Float64, + DataType::Decimal(..) => Decimal, DataType::Boolean => Bool, DataType::String => Utf8, DataType::List(_) => List, @@ -115,6 +118,7 @@ pub enum JsAnyValue { Int64(i64), Float32(f32), Float64(f64), + Decimal(i128, usize), Date(i32), Datetime(i64, TimeUnit, Option), Duration(i64, TimeUnit), @@ -227,6 +231,7 @@ impl ToNapiValue for JsAnyValue { JsAnyValue::UInt64(n) => u64::to_napi_value(env, n), JsAnyValue::Float32(n) => f64::to_napi_value(env, n as f64), JsAnyValue::Float64(n) => f64::to_napi_value(env, n), + JsAnyValue::Decimal(n, _u) => i128::to_napi_value(env, n), JsAnyValue::Utf8(s) => String::to_napi_value(env, s), JsAnyValue::String(s) => String::to_napi_value(env, s), JsAnyValue::Date(v) => { @@ -276,6 +281,7 @@ impl<'a> From for AnyValue<'a> { JsAnyValue::Int64(v) => AnyValue::Int64(v), JsAnyValue::Float32(v) => AnyValue::Float32(v), JsAnyValue::Float64(v) => AnyValue::Float64(v), + JsAnyValue::Decimal(v, u) => AnyValue::Decimal(v, u), JsAnyValue::Date(v) => AnyValue::Date(v), JsAnyValue::Datetime(v, w, _) => AnyValue::Datetime(v, w, &None), JsAnyValue::Duration(v, _) => AnyValue::Duration(v, TimeUnit::Milliseconds), @@ -302,6 +308,7 @@ impl From> for JsAnyValue { AnyValue::Int64(v) => JsAnyValue::Int64(v), AnyValue::Float32(v) => JsAnyValue::Float32(v), AnyValue::Float64(v) => JsAnyValue::Float64(v), + AnyValue::Decimal(v, u) => JsAnyValue::Decimal(v, u), AnyValue::Date(v) => JsAnyValue::Date(v), AnyValue::Datetime(v, w, _) => JsAnyValue::Datetime(v, w, None), AnyValue::Duration(v, _) => JsAnyValue::Duration(v, TimeUnit::Milliseconds), @@ -328,6 +335,7 @@ impl From<&JsAnyValue> for DataType { JsAnyValue::Int64(_) => DataType::Int64, JsAnyValue::Float32(_) => DataType::Float32, JsAnyValue::Float64(_) => DataType::Float64, + JsAnyValue::Decimal(_,_) => DataType::Decimal(None, None), JsAnyValue::Date(_) => DataType::Date, JsAnyValue::Datetime(_, _, _) => DataType::Datetime(TimeUnit::Milliseconds, None), JsAnyValue::Duration(_, _) => DataType::Duration(TimeUnit::Milliseconds), @@ -392,6 +400,7 @@ impl Into for JsDataType { JsDataType::UInt64 => UInt64, JsDataType::Float32 => Float32, JsDataType::Float64 => Float64, + JsDataType::Decimal => Decimal(None, None), JsDataType::Bool => Boolean, JsDataType::Utf8 => String, JsDataType::String => String, diff --git a/src/series.rs b/src/series.rs index 0180a8525..80508515b 100644 --- a/src/series.rs +++ b/src/series.rs @@ -156,6 +156,12 @@ impl JsSeries { s.rename(&name); JsSeries::new(s) } + #[napi(factory, catch_unwind)] + pub fn new_opt_decimal(name: String, val: Wrap, _strict: bool) -> JsSeries { + let mut s = val.0.into_series(); + s.rename(&name); + JsSeries::new(s) + } #[napi(factory, catch_unwind)] pub fn new_opt_date(