Skip to content

Commit

Permalink
feat: add ArrowJSONType to extend pyarrow for JSONDtype
Browse files Browse the repository at this point in the history
  • Loading branch information
chelsea-lin committed Jan 9, 2025
1 parent e132ed6 commit 6239bc3
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 4 deletions.
7 changes: 4 additions & 3 deletions db_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

from db_dtypes import core
from db_dtypes.version import __version__
from . import _versions_helpers

from . import _versions_helpers

# String names for the date and time dtypes.
# NOTE(review): presumably the names under which the corresponding pandas
# extension dtypes are registered — confirm against their definitions.
date_dtype_name = "dbdate"
time_dtype_name = "dbtime"
Expand All @@ -50,7 +50,7 @@
# To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal
# of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0.
if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"):
    from db_dtypes.json import ArrowJSONType, JSONArray, JSONDtype
else:
    # Bind every JSON-related name to None so that later truthiness checks on
    # them (e.g. when choosing the contents of __all__) do not raise NameError
    # on older pandas versions.
    ArrowJSONType = None
    JSONArray = None
    JSONDtype = None
Expand Down Expand Up @@ -359,7 +359,7 @@ def __sub__(self, other):
)


if not JSONArray or not JSONDtype:
if not JSONArray or not JSONDtype or not ArrowJSONType:
__all__ = [
"__version__",
"DateArray",
Expand All @@ -370,6 +370,7 @@ def __sub__(self, other):
else:
__all__ = [
"__version__",
"ArrowJSONType",
"DateArray",
"DateDtype",
"JSONDtype",
Expand Down
51 changes: 50 additions & 1 deletion db_dtypes/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def construct_array_type(cls):
"""Return the array type associated with this dtype."""
return JSONArray

def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> JSONArray:
    """Build the extension array for this dtype from pyarrow data.

    Args:
        array: The pyarrow array (or chunked array) to wrap.

    Returns:
        A ``JSONArray`` constructed from ``array``.
    """
    extension_array = JSONArray(array)
    return extension_array


class JSONArray(arrays.ArrowExtensionArray):
"""Extension array that handles BigQuery JSON data, leveraging a string-based
Expand Down Expand Up @@ -92,6 +96,10 @@ def __init__(self, values) -> None:
else:
raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}")

def __arrow_array__(self):
"""Convert to an arrow array. This is required for pyarrow extension."""
return self.pa_data

@classmethod
def _box_pa(
cls, value, pa_type: pa.DataType | None = None
Expand Down Expand Up @@ -151,7 +159,12 @@ def _serialize_json(value):
def _deserialize_json(value):
"""A static method that converts a JSON string back into its original value."""
if not pd.isna(value):
return json.loads(value)
# Attempt to interpret the value as a JSON object.
# If it's not valid JSON, treat it as a regular string.
try:
return json.loads(value)
except json.JSONDecodeError:
return value
else:
return value

Expand Down Expand Up @@ -244,3 +257,39 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
result[mask] = self._dtype.na_value
result[~mask] = data[~mask].pa_data.to_numpy()
return result


class ArrowJSONType(pa.ExtensionType):
    """Arrow extension type for the `dbjson` Pandas extension type.

    The type takes no parameters: it is always built on ``pa.string()``
    storage with the extension name ``"dbjson"``.
    """

    def __init__(self) -> None:
        super().__init__(pa.string(), "dbjson")

    def __arrow_ext_serialize__(self) -> bytes:
        """Return the serialized type parameters (empty: none are necessary)."""
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowJSONType:
        """Rebuild the type from its serialized form.

        Both arguments are ignored because the type carries no parameters.
        """
        # return an instance of this subclass
        return ArrowJSONType()

    def __eq__(self, other):
        """Compare equal to extension types of exactly the same class.

        Returns ``NotImplemented`` for objects that are not pyarrow
        extension types.
        """
        # Use the `pa` alias, like every other pyarrow reference in this
        # class, so the check does not depend on a separate `pyarrow` binding.
        if isinstance(other, pa.BaseExtensionType):
            return type(self) is type(other)
        return NotImplemented

    def __ne__(self, other) -> bool:
        """Return the negation of ``self == other``."""
        return not self == other

    def __hash__(self) -> int:
        """Hash on the string form of the type (defined alongside ``__eq__``)."""
        return hash(str(self))

    def to_pandas_dtype(self):
        """Return a ``JSONDtype`` instance as the matching pandas dtype."""
        return JSONDtype()


# Register an instance of the type with pyarrow so that it is included in
# RecordBatches, sent over IPC and received in another Python process.
pa.register_extension_type(ArrowJSONType())
35 changes: 35 additions & 0 deletions tests/unit/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import db_dtypes
Expand Down Expand Up @@ -114,3 +115,37 @@ def test_as_numpy_array():
]
)
pd._testing.assert_equal(result, expected)


def test_arrow_json_storage_type():
    """ArrowJSONType reports the "dbjson" extension name and string storage."""
    json_type = db_dtypes.ArrowJSONType()

    assert json_type.extension_name == "dbjson"
    assert pa.types.is_string(json_type.storage_type)


def test_arrow_json_constructors():
    """wrap_array and ExtensionArray.from_storage build equal extension arrays."""
    json_strings = ["0", "str", '{"b": 2}', '{"a": [1, 2, 3]}']
    storage = pa.array(json_strings, type=pa.string())

    # Path 1: wrap the storage array through the extension type itself.
    wrapped = db_dtypes.ArrowJSONType().wrap_array(storage)
    assert isinstance(wrapped, pa.ExtensionArray)

    # Path 2: build the extension array from the type plus its storage.
    from_storage = pa.ExtensionArray.from_storage(db_dtypes.ArrowJSONType(), storage)
    assert isinstance(from_storage, pa.ExtensionArray)

    assert wrapped == from_storage


def test_arrow_json_to_pandas():
    """Converting a wrapped array to pandas yields a JSONDtype series of decoded values."""
    json_strings = [None, "0", "str", '{"b": 2}', '{"a": [1, 2, 3]}']
    storage = pa.array(json_strings, type=pa.string())
    extension_array = db_dtypes.ArrowJSONType().wrap_array(storage)

    series = extension_array.to_pandas()

    assert isinstance(series.dtypes, db_dtypes.JSONDtype)
    # A null stays missing; every other entry is decoded from its JSON text,
    # with the non-JSON text "str" coming back unchanged.
    assert pd.isna(series[0])
    assert series[1] == 0
    assert series[2] == "str"
    assert series[3]["b"] == 2
    assert series[4]["a"] == [1, 2, 3]

0 comments on commit 6239bc3

Please sign in to comment.