diff --git a/crates/duckdb/src/appender/arrow.rs b/crates/duckdb/src/appender/arrow.rs index 942242cc..7506bfa3 100644 --- a/crates/duckdb/src/appender/arrow.rs +++ b/crates/duckdb/src/appender/arrow.rs @@ -31,7 +31,7 @@ impl Appender<'_> { let schema = record_batch.schema(); let mut logical_type: Vec = vec![]; for field in schema.fields() { - let logical_t = to_duckdb_logical_type(field.data_type()) + let logical_t = to_duckdb_logical_type(field.data_type(), field.metadata()) .map_err(|_op| Error::ArrowTypeToDuckdbType(field.to_string(), field.data_type().clone()))?; logical_type.push(logical_t); } diff --git a/crates/duckdb/src/core/data_chunk.rs b/crates/duckdb/src/core/data_chunk.rs index 3ef35992..6636e7a9 100644 --- a/crates/duckdb/src/core/data_chunk.rs +++ b/crates/duckdb/src/core/data_chunk.rs @@ -4,7 +4,7 @@ use super::{ }; use crate::ffi::{ duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_column_count, duckdb_data_chunk_get_size, - duckdb_data_chunk_get_vector, duckdb_data_chunk_set_size, duckdb_destroy_data_chunk, + duckdb_data_chunk_get_vector, duckdb_data_chunk_set_size, duckdb_destroy_data_chunk, duckdb_vector_get_column_type, }; /// Handle to the DataChunk in DuckDB. @@ -59,6 +59,15 @@ impl DataChunkHandle { StructVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) }) } + /// Get the logical type of the vector at the column index: `idx`. + pub fn logical_type(&self, idx: usize) -> LogicalTypeHandle { + unsafe { + LogicalTypeHandle::new(duckdb_vector_get_column_type(duckdb_data_chunk_get_vector( + self.ptr, idx as u64, + ))) + } + } + /// Set the size of the data chunk pub fn set_len(&self, new_len: usize) { unsafe { duckdb_data_chunk_set_size(self.ptr, new_len as u64) }; diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index 219f6f71..665b2575 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -1,5 +1,5 @@ use super::{BindInfo, DataChunkHandle, Free, FunctionInfo, InitInfo, LogicalTypeHandle, LogicalTypeId, VTab}; -use std::ptr::null_mut; +use std::{collections::HashMap, ptr::null_mut}; use crate::core::{ArrayVector, FlatVector, Inserter, ListVector, StructVector, Vector}; use arrow::{ @@ -88,7 +88,7 @@ impl VTab for ArrowVTab { for f in rb.schema().fields() { let name = f.name(); let data_type = f.data_type(); - let logical_type = to_duckdb_logical_type(data_type)?; + let logical_type = to_duckdb_logical_type(data_type, f.metadata())?; bind.add_result_column(name, logical_type); } (*data).rb = Box::into_raw(Box::new(rb)); @@ -128,8 +128,15 @@ impl VTab for ArrowVTab { } } +const EXTENSION_NAME_KEY: &str = "ARROW:extension:name"; +const UUID_EXTENSION_NAME: &str = "arrow.uuid"; +const UUID_LENGTH: usize = 16; + /// Convert arrow DataType to duckdb type id -pub fn to_duckdb_type_id(data_type: &DataType) -> Result> { +pub fn to_duckdb_type_id( + data_type: &DataType, + metadata: &HashMap, +) -> Result> { use LogicalTypeId::*; let type_id = match data_type { @@ -157,7 +164,17 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result Time, DataType::Duration(_) => Interval, DataType::Interval(_) => Interval, - DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => Blob, + DataType::FixedSizeBinary(_) => { + if metadata + .get(EXTENSION_NAME_KEY) + .map_or(false, |name| name == UUID_EXTENSION_NAME) + { + Uuid + } else { + Blob + } + } + DataType::Binary | DataType::LargeBinary => Blob, DataType::Utf8 | DataType::LargeUtf8 => Varchar, DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => List, DataType::Struct(_) => Struct, @@ -177,21 +194,28 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result Result> { +pub fn to_duckdb_logical_type( + data_type: &DataType, + metadata: &HashMap, +) -> Result> { match data_type { - DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type), + DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type, &HashMap::new()), DataType::Struct(fields) => { let mut shape = vec![]; for field in fields.iter() { - shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?)); + shape.push(( + field.name().as_str(), + to_duckdb_logical_type(field.data_type(), field.metadata())?, + )); } Ok(LogicalTypeHandle::struct_type(shape.as_slice())) } - DataType::List(child) | DataType::LargeList(child) => { - Ok(LogicalTypeHandle::list(&to_duckdb_logical_type(child.data_type())?)) - } + DataType::List(child) | DataType::LargeList(child) => Ok(LogicalTypeHandle::list(&to_duckdb_logical_type( + child.data_type(), + child.metadata(), + )?)), DataType::FixedSizeList(child, array_size) => Ok(LogicalTypeHandle::array( - &to_duckdb_logical_type(child.data_type())?, + &to_duckdb_logical_type(child.data_type(), child.metadata())?, *array_size as u64, )), DataType::Decimal128(width, scale) if *scale > 0 => { @@ -203,8 +227,8 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type)?)), - dtype if dtype.is_primitive() => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type)?)), + | DataType::FixedSizeBinary(_) => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type, metadata)?)), + dtype if dtype.is_primitive() => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type, metadata)?)), _ => Err(format!( "Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs" ) @@ -245,8 +269,15 @@ pub fn record_batch_to_duckdb_data_chunk( DataType::Binary => { binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i)); } - DataType::FixedSizeBinary(_) => { - fixed_size_binary_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i)); + DataType::FixedSizeBinary(length) => { + if chunk.logical_type(i).id() == LogicalTypeId::Uuid { + if *length != UUID_LENGTH as i32 { + return Err(format!("UUID FixedSizeBinaryArray must have value length of 16").into()); + } + uuid_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i)); + } else { + fixed_size_binary_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i)); + } } DataType::LargeBinary => { large_binary_array_to_vector( @@ -284,6 +315,20 @@ pub fn record_batch_to_duckdb_data_chunk( Ok(()) } +fn uuid_array_to_vector(array: &FixedSizeBinaryArray, out_vector: &mut FlatVector) { + let out_data: &mut [i128] = out_vector.as_mut_slice(); + for (i, value) in array.values().chunks_exact(UUID_LENGTH).enumerate() { + let value: [u8; UUID_LENGTH] = value.try_into().unwrap(); + let value = i128::from_be_bytes(value); + // For whatever reason, DuckDB internally uses a signed integer to represent UUIDs + // but we need to swap the sign bit in order to maintain fidelity. + const MASK: i128 = (1u128 << 127) as _; + let value = value ^ MASK; + out_data[i] = value; + } + set_nulls_in_flat_vector(array, out_vector); +} + fn primitive_array_to_flat_vector(array: &PrimitiveArray, out_vector: &mut FlatVector) { // assert!(array.len() <= out_vector.capacity()); out_vector.copy::(array.values()); @@ -698,24 +743,29 @@ fn set_nulls_in_list_vector(array: &dyn Array, out_vector: &mut ListVector) { #[cfg(test)] mod test { - use super::{arrow_recordbatch_to_query_params, ArrowVTab}; - use crate::{Connection, Result}; + use super::{arrow_recordbatch_to_query_params, ArrowVTab, UUID_LENGTH}; + use crate::{ + vtab::arrow::{EXTENSION_NAME_KEY, UUID_EXTENSION_NAME}, + Connection, Result, + }; use arrow::{ array::{ Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, - DurationSecondArray, FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array, - IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, ListArray, - OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, GenericListArray, + Int32Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + LargeStringArray, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, + Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, }, - buffer::{OffsetBuffer, ScalarBuffer}, + buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}, datatypes::{ i256, ArrowPrimitiveType, ByteArrayType, DataType, DurationSecondType, Field, Fields, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, Schema, }, - record_batch::RecordBatch, + record_batch::RecordBatch, util::pretty, }; - use std::{error::Error, sync::Arc}; + use std::{collections::HashMap, error::Error, sync::Arc}; + use uuid::Uuid; #[test] fn test_vtab_arrow() -> Result<(), Box> { @@ -1264,4 +1314,62 @@ mod test { assert_eq!(column.len(), 1); assert_eq!(column.value(0), b"test"); } + + #[test] + fn test_uuid_roundtrip() -> Result<(), Box> { + let db = Connection::open_in_memory()?; + db.register_table_function::("arrow")?; + let schema = Schema::new(vec![Field::new( + "a", + DataType::FixedSizeBinary(UUID_LENGTH as i32), + true, + ) + .with_metadata(HashMap::from([( + EXTENSION_NAME_KEY.to_string(), + UUID_EXTENSION_NAME.to_string(), + )]))]); + + let uuids = vec![ + Uuid::from_u128(0xa1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8u128), + Uuid::from_u128(0xb1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8u128), + Uuid::from_u128(0), + Uuid::from_u128(0x41), + Uuid::from_u128(0x42), + ]; + let buf = uuids + .iter() + .flat_map(|uuid| uuid.as_bytes()) + .copied() + .collect::>(); + let buf = Buffer::from_vec(buf); + let rb = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(FixedSizeBinaryArray::new( + UUID_LENGTH as i32, + buf, + Some(NullBuffer::from(vec![true, true, false, true, true])), + ))], + )?; + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("SELECT a, typeof(a) FROM arrow(?, ?)")?; + let mut arr = stmt.query_arrow(param)?; + let rb = arr.next().expect("no record batch"); + let rb = [rb]; + let printed = pretty::pretty_format_batches(&rb).unwrap(); + assert_eq!( + "\ ++--------------------------------------+-----------+ +| a | typeof(a) | ++--------------------------------------+-----------+ +| a1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8 | UUID | +| b1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8 | UUID | +| | UUID | +| 00000000-0000-0000-0000-000000000041 | UUID | +| 00000000-0000-0000-0000-000000000042 | UUID | ++--------------------------------------+-----------+", + printed.to_string(), + "{printed}" + ); + Ok(()) + } }