Skip to content

Commit

Permalink
Implement a couple more numpy functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleaoman committed Dec 17, 2024
1 parent 9f5f6e5 commit 16e40cc
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 20 deletions.
30 changes: 16 additions & 14 deletions swiftsimio/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,24 +1188,26 @@ def linalg_det(a):
# return _return_helper(res, helper_result, ret_cf, out=out)


# @implements(np.diff)
# def diff(...):
# from unyt._array_functions import diff as unyt_diff
@implements(np.diff)
def diff(a, n=1, axis=-1, prepend=np._NoValue, append=np._NoValue):
from unyt._array_functions import diff as unyt_diff

# helper_result = _prepare_array_func_args(...)
# ret_cf = ...()
# res = unyt_diff(*helper_result["args"], **helper_result["kwargs"])
# return _return_helper(res, helper_result, ret_cf, out=out)
helper_result = _prepare_array_func_args(
a, n=n, axis=axis, prepend=prepend, append=append
)
ret_cf = _preserve_cosmo_factor(helper_result["ca_cfs"][0])
res = unyt_diff(*helper_result["args"], **helper_result["kwargs"])
return _return_helper(res, helper_result, ret_cf)


# @implements(np.ediff1d)
# def ediff1d(...):
# from unyt._array_functions import ediff1d as unyt_ediff1d
@implements(np.ediff1d)
def ediff1d(ary, to_end=None, to_begin=None):
from unyt._array_functions import ediff1d as unyt_ediff1d

# helper_result = _prepare_array_func_args(...)
# ret_cf = ...()
# res = unyt_ediff1d(*helper_result["args"], **helper_result["kwargs"])
# return _return_helper(res, helper_result, ret_cf, out=out)
helper_result = _prepare_array_func_args(ary, to_end=to_end, to_begin=to_begin)
ret_cf = _preserve_cosmo_factor(helper_result["ca_cfs"][0])
res = unyt_ediff1d(*helper_result["args"], **helper_result["kwargs"])
return _return_helper(res, helper_result, ret_cf)


# @implements(np.ptp)
Expand Down
5 changes: 1 addition & 4 deletions swiftsimio/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,7 @@ def _prepare_array_func_args(*args, **kwargs):
else:
# mixed compressions, strip it off
ret_comp = None
args = [unyt_array(arg) if isinstance(arg, cosmo_array) else arg for arg in args]
kwargs = {
k: unyt_array(v) if isinstance(v, cosmo_array) else v for k, v in kwargs.items()
}
# WE SHOULD COMPLAIN HERE IF WE HAVE DIFFERENT SCALE FACTORS IN COSMO_FACTOR'S??
return dict(
args=args,
kwargs=kwargs,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cosmo_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def test_handled_funcs(self):
# "nanpercentile": (ca(np.arange(3)), 30),
# "nanquantile": (ca(np.arange(3)), 0.3),
"linalg.det": (ca(np.eye(3)),),
# "diff": (ca(np.arange(3)),),
# "ediff1d": (ca(np.arange(3)),),
"diff": (ca(np.arange(3)),),
"ediff1d": (ca(np.arange(3)),),
# "ptp": (ca(np.arange(3)),),
"cumprod": (ca(np.arange(3)),),
# "pad": (ca(np.arange(3)), 3),
Expand Down

0 comments on commit 16e40cc

Please sign in to comment.