diff --git a/swiftsimio/_array_functions.py b/swiftsimio/_array_functions.py index dac6a642..1a5219b3 100644 --- a/swiftsimio/_array_functions.py +++ b/swiftsimio/_array_functions.py @@ -1188,24 +1188,26 @@ def linalg_det(a): # return _return_helper(res, helper_result, ret_cf, out=out) -# @implements(np.diff) -# def diff(...): -# from unyt._array_functions import diff as unyt_diff +@implements(np.diff) +def diff(a, n=1, axis=-1, prepend=np._NoValue, append=np._NoValue): + from unyt._array_functions import diff as unyt_diff -# helper_result = _prepare_array_func_args(...) -# ret_cf = ...() -# res = unyt_diff(*helper_result["args"], **helper_result["kwargs"]) -# return _return_helper(res, helper_result, ret_cf, out=out) + helper_result = _prepare_array_func_args( + a, n=n, axis=axis, prepend=prepend, append=append + ) + ret_cf = _preserve_cosmo_factor(helper_result["ca_cfs"][0]) + res = unyt_diff(*helper_result["args"], **helper_result["kwargs"]) + return _return_helper(res, helper_result, ret_cf) -# @implements(np.ediff1d) -# def ediff1d(...): -# from unyt._array_functions import ediff1d as unyt_ediff1d +@implements(np.ediff1d) +def ediff1d(ary, to_end=None, to_begin=None): + from unyt._array_functions import ediff1d as unyt_ediff1d -# helper_result = _prepare_array_func_args(...) -# ret_cf = ...() -# res = unyt_ediff1d(*helper_result["args"], **helper_result["kwargs"]) -# return _return_helper(res, helper_result, ret_cf, out=out) + helper_result = _prepare_array_func_args(ary, to_end=to_end, to_begin=to_begin) + ret_cf = _preserve_cosmo_factor(helper_result["ca_cfs"][0]) + res = unyt_ediff1d(*helper_result["args"], **helper_result["kwargs"]) + return _return_helper(res, helper_result, ret_cf) # @implements(np.ptp) diff --git a/swiftsimio/objects.py b/swiftsimio/objects.py index f4ce1861..d6c60d29 100644 --- a/swiftsimio/objects.py +++ b/swiftsimio/objects.py @@ -443,10 +443,7 @@ def _prepare_array_func_args(*args, **kwargs): else: # mixed compressions, strip it off ret_comp = None - args = [unyt_array(arg) if isinstance(arg, cosmo_array) else arg for arg in args] - kwargs = { - k: unyt_array(v) if isinstance(v, cosmo_array) else v for k, v in kwargs.items() - } + # WE SHOULD COMPLAIN HERE IF WE HAVE DIFFERENT SCALE FACTORS IN COSMO_FACTOR'S?? return dict( args=args, kwargs=kwargs, diff --git a/tests/test_cosmo_array.py b/tests/test_cosmo_array.py index a16a3dbc..112876be 100644 --- a/tests/test_cosmo_array.py +++ b/tests/test_cosmo_array.py @@ -155,8 +155,8 @@ def test_handled_funcs(self): # "nanpercentile": (ca(np.arange(3)), 30), # "nanquantile": (ca(np.arange(3)), 0.3), "linalg.det": (ca(np.eye(3)),), - # "diff": (ca(np.arange(3)),), - # "ediff1d": (ca(np.arange(3)),), + "diff": (ca(np.arange(3)),), + "ediff1d": (ca(np.arange(3)),), # "ptp": (ca(np.arange(3)),), "cumprod": (ca(np.arange(3)),), # "pad": (ca(np.arange(3)), 3),