diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 1737dcf..e956105 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,3 +1,8 @@
+Unreleased
+==========
+
+* Added support for array-like Python objects (e.g. tuples, NumPy arrays) as JSON arrays.
+
 1.0.1
 =====
 
diff --git a/jmespath/functions.py b/jmespath/functions.py
index 627b569..6b0e7c1 100644
--- a/jmespath/functions.py
+++ b/jmespath/functions.py
@@ -1,5 +1,6 @@
 import math
 import json
+from collections.abc import Sequence
 
 from jmespath import exceptions
 from jmespath.compat import string_type as STRING_TYPE
@@ -35,6 +36,26 @@
 }
 
 
+def is_array(arg):
+    """Return True if ``arg`` is a non-scalar ``__array__`` object.
+
+    Objects whose ``shape`` is ``()`` (zero-dimensional arrays) or that
+    expose no ``shape`` at all are treated as scalars, not as arrays.
+    """
+    return hasattr(arg, "__array__") and getattr(arg, "shape", ()) != ()
+
+
+def is_arraylike(arg):
+    """Return True if ``arg`` should be handled as a JSON array.
+
+    This covers any ``Sequence`` other than ``str`` and ``bytes`` (for
+    example lists and tuples) as well as anything accepted by ``is_array``.
+    """
+    if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
+        return True
+    return is_array(arg)
+
+
 def signature(*arguments):
     def _record_signature(func):
         func.signature = arguments
@@ -180,7 +201,7 @@ def _func_not_null(self, *arguments):
 
     @signature({'types': []})
     def _func_to_array(self, arg):
-        if isinstance(arg, list):
+        if is_arraylike(arg):
             return arg
         else:
             return [arg]
@@ -297,7 +318,7 @@ def _func_type(self, arg):
             return "string"
         elif isinstance(arg, bool):
             return "boolean"
-        elif isinstance(arg, list):
+        elif is_arraylike(arg):
             return "array"
         elif isinstance(arg, dict):
             return "object"
diff --git a/jmespath/visitor.py b/jmespath/visitor.py
index 15fb177..e4574e7 100644
--- a/jmespath/visitor.py
+++ b/jmespath/visitor.py
@@ -1,15 +1,50 @@
 import operator
 
 from jmespath import functions
+from jmespath.functions import is_array, is_arraylike
 from jmespath.compat import string_type
 from numbers import Number
 
 
+def _arraylike_all(arg):
+    """Collapse an elementwise comparison result to a single truth value.
+
+    Array results are materialised through ``__array__`` and reduced with
+    ``all()``; anything else is returned unchanged.
+    """
+    return arg.__array__().all() if is_array(arg) else arg
+
+
+def _arraylike_to_list(arg):
+    """Recursively convert arraylike values into plain nested lists."""
+    return [_arraylike_to_list(i) for i in arg] if is_arraylike(arg) else arg
+
+
+def _array_equals(x, y):
+    """Compare two values of which at least one is an array object.
+
+    ``==`` on array objects broadcasts, so the dimensions are checked
+    explicitly before any elementwise comparison takes place.
+    """
+    if not (is_arraylike(x) and is_arraylike(y)) or len(x) != len(y):
+        # An array never equals a scalar or an array of another length.
+        return False
+    if is_array(x) and is_array(y):
+        if tuple(x.shape) != tuple(y.shape):
+            return False
+        return bool(_arraylike_all(x == y))
+    # Only one side carries a shape: compare element by element so nested
+    # sequences of a different length cannot be broadcast into a match.
+    return all(_equals(a, b) for a, b in zip(x, y))
+
+
 def _equals(x, y):
     if _is_special_number_case(x, y):
         return False
+    elif is_array(x) or is_array(y):
+        return _array_equals(x, y)
     else:
-        return x == y
+        return _arraylike_to_list(x) == _arraylike_to_list(y)
 
 
 def _is_special_number_case(x, y):
@@ -172,7 +207,7 @@ def visit_function_expression(self, node, value):
 
     def visit_filter_projection(self, node, value):
         base = self.visit(node['children'][0], value)
-        if not isinstance(base, list):
+        if not is_arraylike(base):
             return None
         comparator_node = node['children'][2]
         collected = []
@@ -185,12 +220,12 @@ def visit_filter_projection(self, node, value):
 
     def visit_flatten(self, node, value):
         base = self.visit(node['children'][0], value)
-        if not isinstance(base, list):
-            # Can't flatten the object if it's not a list.
+        if not is_arraylike(base):
+            # Can't flatten the object if it's not arraylike.
             return None
         merged_list = []
         for element in base:
-            if isinstance(element, list):
+            if is_arraylike(element):
                 merged_list.extend(element)
             else:
                 merged_list.append(element)
@@ -202,7 +237,7 @@ def visit_identity(self, node, value):
     def visit_index(self, node, value):
         # Even though we can index strings, we don't
         # want to support that.
-        if not isinstance(value, list):
+        if not is_arraylike(value):
             return None
         try:
             return value[node['value']]
@@ -216,7 +251,7 @@ def visit_index_expression(self, node, value):
         return result
 
     def visit_slice(self, node, value):
-        if not isinstance(value, list):
+        if not is_arraylike(value):
             return None
         s = slice(*node['children'])
         return value[s]
@@ -271,7 +306,7 @@ def visit_pipe(self, node, value):
 
     def visit_projection(self, node, value):
         base = self.visit(node['children'][0], value)
-        if not isinstance(base, list):
+        if not is_arraylike(base):
             return None
         collected = []
         for element in base:
diff --git a/requirements.txt b/requirements.txt
index bd50c2a..694e77a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,11 @@
 wheel==0.38.1
+parameterized==0.9.0
 pytest==6.2.5
 pytest-cov==3.0.0
 hypothesis==3.1.0 ; python_version < '3.8'
 hypothesis==5.5.4 ; python_version == '3.8'
 hypothesis==5.35.4 ; python_version == '3.9'
+astropy>=3.1
+dask>=2.0.0
+numpy>=1.15.0
+xarray>=0.18.0
diff --git a/tests/test_arraylike.py b/tests/test_arraylike.py
new file mode 100644
index 0000000..3f5b213
--- /dev/null
+++ b/tests/test_arraylike.py
@@ -0,0 +1,133 @@
+import astropy.units as u
+import dask.array as da
+import numpy as np
+import xarray as xr
+from parameterized import parameterized, parameterized_class
+
+import jmespath
+import jmespath.functions
+from tests import unittest
+
+
+@parameterized_class(("name", "data"), [
+    ("list", {
+        "value": {
+            "data": [[1,2,3],[4,5,6],[7,8,9]]
+        },
+        "same": {
+            "data": [[1,2,3],[4,5,6],[7,8,9]]
+        },
+        "other": {
+            "data": [[2,2,3],[4,5,6],[7,8,9]]
+        }
+    }),
+    ("tuple", {
+        "value": {
+            "data": ((1,2,3),(4,5,6),(7,8,9))
+        },
+        "same": {
+            "data": ([1,2,3],[4,5,6],[7,8,9])
+        },
+        "other": {
+            "data": [[2,2,3],[4,5,6],[7,8,9]]
+        }
+    }),
+    ("numpy", {
+        "value": {
+            "data": np.array([[1,2,3],[4,5,6],[7,8,9]])
+        },
+        "same": {
+            "data": (np.array([1,2,3]),np.array([4,5,6]),np.array([7,8,9]))
+        },
+        "other": {
+            "data": np.array([[2,2,3],[4,5,6],[7,8,9]])
+        }
+    }),
+    ("dask", {
+        "value": {
+            "data": da.from_array([[1,2,3],[4,5,6],[7,8,9]])
+        },
+        "same": {
+            "data": (da.from_array([1,2,3]),da.from_array([4,5,6]),da.from_array([7,8,9]))
+        },
+        "other": {
+            "data": da.from_array([[2,2,3],[4,5,6],[7,8,9]])
+        }
+    }),
+    ("xarray", {
+        "value": {
+            "data": xr.DataArray([[1,2,3],[4,5,6],[7,8,9]])
+        },
+        "same": {
+            "data": (xr.DataArray([1,2,3]),xr.DataArray([4,5,6]),xr.DataArray([7,8,9]))
+        },
+        "other": {
+            "data": xr.DataArray([[2,2,3],[4,5,6],[7,8,9]])
+        }
+    }),
+    ("astropy", {
+        "value": {
+            "data": u.Quantity([[1,2,3],[4,5,6],[7,8,9]])
+        },
+        "same": {
+            "data": (u.Quantity([1,2,3]),u.Quantity([4,5,6]),u.Quantity([7,8,9]))
+        },
+        "other": {
+            "data": u.Quantity([[2,2,3],[4,5,6],[7,8,9]])
+        }
+    }),
+])
+class TestArrayNumeric(unittest.TestCase):
+    @parameterized.expand([
+        ["self", "@", lambda data: data],
+        ["get", "value.data", lambda data: data["value"]["data"]],
+        ["slice_horizontal", "value.data[1][:]", lambda data: np.array(data["value"]["data"])[1,:]],
+        ["slice_horizontal2", "value.data[:3:2][:]", lambda data: np.array(data["value"]["data"])[:3:2,:]],
+        ["slice_vertical", "value.data[:][1]", lambda data: np.array(data["value"]["data"])[:,1]],
+        ["slice_vertical2", "value.data[:][:3:2]", lambda data: np.array(data["value"]["data"])[:,:3:2]],
+        ["flatten", "value.data[]", lambda data: np.array(data["value"]["data"]).flatten()],
+        ["compare_self", "value.data == value.data", lambda _: True],
+        ["compare_same", "value.data == same.data", lambda _: True],
+        ["compare_other", "value.data == other.data", lambda _: False],
+        ["compare_literal_scalar", "value.data[0][0] == `1`", lambda _: True],
+        ["compare_literal_slice", "value.data[1][:] == `[4, 5, 6]`", lambda _: True],
+        ["compare_literal", "value.data == `[[1,2,3],[4,5,6],[7,8,9]]`", lambda _: True],
+        ["compare_flattened", "value.data[] == `[1,2,3,4,5,6,7,8,9]`", lambda _: True],
+        ["compare_broadcast", "value.data[0] == `[[1,2,3]]`", lambda _: False],
+    ])
+    def test_search(self, test_name, query, expected):
+        result = jmespath.search(query, self.data)
+        np.testing.assert_array_equal(result, expected(self.data), test_name)
+
+
+@parameterized_class(("name", "data"), [
+    ("numpy", {
+        "value": {
+            "data": np.array([["test", "messages"],["in", "numpy"]])
+        },
+        "same": {
+            "data": np.array([["test", "messages"],["in", "numpy"]])
+        },
+        "other": {
+            "data": np.array([["test", "messages"],["other", "numpy"]])
+        }
+    })
+])
+class TestArrayStr(unittest.TestCase):
+    @parameterized.expand([
+        ["self", "@", lambda data: data],
+        ["get", "value.data", lambda data: data["value"]["data"]],
+        ["slice_horizontal", "value.data[1][:]", lambda data: data["value"]["data"][1,:]],
+        ["slice_vertical", "value.data[:][1]", lambda data: data["value"]["data"][:,1]],
+        ["flatten", "value.data[]", lambda data: data["value"]["data"].flatten()],
+        ["compare_self", "value.data == value.data", lambda _: True],
+        ["compare_same", "value.data == same.data", lambda _: True],
+        ["compare_other", "value.data == other.data", lambda _: False],
+        ["compare_literal_scalar", "value.data[0][0] == 'test'", lambda _: True],
+        ["compare_literal_slice", "value.data[1][:] == ['in', 'numpy']", lambda _: True],
+        ["compare_literal", "value.data == [['test', 'messages'],['in', 'numpy']]", lambda _: True],
+        ["compare_flattened", "value.data[] == ['test', 'messages', 'in', 'numpy']", lambda _: True],
+    ])
+    def test_search(self, test_name, query, expected):
+        result = jmespath.search(query, self.data)
+        np.testing.assert_array_equal(result, expected(self.data), test_name)