Skip to content

Commit

Permalink
Implement a few more numpy functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleaoman committed Dec 18, 2024
1 parent 240b991 commit 8880a13
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 75 deletions.
213 changes: 171 additions & 42 deletions swiftsimio/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,14 +947,38 @@ def array_equiv(a1, a2):
return _return_helper(res, helper_result, ret_cf)


# @implements(np.linspace)
# def linspace(...):
# from unyt._array_functions import linspace as unyt_linspace
@implements(np.linspace)
def linspace(
start,
stop,
num=50,
endpoint=True,
retstep=False,
dtype=None,
axis=0,
*,
device=None,
):
from unyt._array_functions import linspace as unyt_linspace

# helper_result = _prepare_array_func_args(...)
# ret_cf = ...()
# res = unyt_linspace(*helper_result["args"], **helper_result["kwargs"])
# return _return_helper(res, helper_result, ret_cf, out=out)
helper_result = _prepare_array_func_args(
start,
stop,
num=num,
endpoint=endpoint,
retstep=retstep,
dtype=dtype,
axis=axis,
device=device,
)
ret_cf = _preserve_cosmo_factor(
helper_result["ca_cfs"][0], helper_result["ca_cfs"][1]
)
ress = unyt_linspace(*helper_result["args"], **helper_result["kwargs"])
if retstep:
return tuple(_return_helper(res, helper_result, ret_cf) for res in ress)
else:
return _return_helper(ress, helper_result, ret_cf)


# @implements(np.logspace)
Expand Down Expand Up @@ -1017,14 +1041,35 @@ def prod(
return _return_helper(res, helper_result, ret_cf, out=out)


# @implements(np.var)
# def var(...):
# from unyt._array_functions import var as unyt_var
@implements(np.var)
def var(
a,
axis=None,
dtype=None,
out=None,
ddof=0,
keepdims=np._NoValue,
*,
where=np._NoValue,
mean=np._NoValue,
correction=np._NoValue
):
from unyt._array_functions import var as unyt_var

# helper_result = _prepare_array_func_args(...)
# ret_cf = ...()
# res = unyt_var(*helper_result["args"], **helper_result["kwargs"])
# return _return_helper(res, helper_result, ret_cf, out=out)
helper_result = _prepare_array_func_args(
a,
axis=axis,
dtype=dtype,
out=out,
ddof=ddof,
keepdims=keepdims,
where=where,
mean=mean,
correction=correction,
)
ret_cf = _preserve_cosmo_factor(helper_result["ca_cfs"][0])
res = unyt_var(*helper_result["args"], **helper_result["kwargs"])
return _return_helper(res, helper_result, ret_cf, out=out)


@implements(np.trace)
Expand All @@ -1044,44 +1089,128 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
return _return_helper(res, helper_result, ret_cf, out=out)


# @implements(np.percentile)
# def percentile(...):
# from unyt._array_functions import percentile as unyt_percentile
@implements(np.percentile)
def percentile(
a,
q,
axis=None,
out=None,
overwrite_input=False,
method="linear",
keepdims=False,
*,
weights=None,
interpolation=None
):
from unyt._array_functions import percentile as unyt_percentile

# helper_result = _prepare_array_func_args(...)
# ret_cf = ...()
# res = unyt_percentile(*helper_result["args"], **helper_result["kwargs"])
# return _return_helper(res, helper_result, ret_cf, out=out)
helper_result = _prepare_array_func_args(
a,
q,
axis=axis,
out=out,
overwrite_input=overwrite_input,
method=method,
keepdims=keepdims,
weights=weights,
interpolation=interpolation,
)
ret_cf = _preserve_cosmo_factor(helper_result["ca_cfs"][0])
res = unyt_percentile(*helper_result["args"], **helper_result["kwargs"])
return _return_helper(res, helper_result, ret_cf, out=out)


# @implements(np.quantile)
# def quantile(...):
# from unyt._array_functions import quantile as unyt_quantile
@implements(np.quantile)
def quantile(
a,
q,
axis=None,
out=None,
overwrite_input=False,
method='linear',
keepdims=False,
*,
weights=None,
interpolation=None
):
from unyt._array_functions import quantile as unyt_quantile

# helper_result = _prepare_array_func_args(...)
# ret_cf = ...()
# res = unyt_quantile(*helper_result["args"], **helper_result["kwargs"])
# return _return_helper(res, helper_result, ret_cf, out=out)
helper_result = _prepare_array_func_args(
a,
q,
axis=axis,
out=out,
overwrite_input=overwrite_input,
method=method,
keepdims=keepdims,
weights=weights,
interpolation=interpolation,
)
ret_cf = _preserve_cosmo_factor(helper_result["ca_cfs"][0])
res = unyt_quantile(*helper_result["args"], **helper_result["kwargs"])
return _return_helper(res, helper_result, ret_cf, out=out)


# @implements(np.nanpercentile)
# def nanpercentile(...):
# from unyt._array_functions import nanpercentile as unyt_nanpercentile
@implements(np.nanpercentile)
def percentile(
a,
q,
axis=None,
out=None,
overwrite_input=False,
method="linear",
keepdims=False,
*,
weights=None,
interpolation=None
):
from unyt._array_functions import nanpercentile as unyt_nanpercentile

# helper_result = _prepare_array_func_args(...)
# ret_cf = ...()
# res = unyt_nanpercentile(*helper_result["args"], **helper_result["kwargs"])
# return _return_helper(res, helper_result, ret_cf, out=out)
helper_result = _prepare_array_func_args(
a,
q,
axis=axis,
out=out,
overwrite_input=overwrite_input,
method=method,
keepdims=keepdims,
weights=weights,
interpolation=interpolation,
)
ret_cf = _preserve_cosmo_factor(helper_result["ca_cfs"][0])
res = unyt_nanpercentile(*helper_result["args"], **helper_result["kwargs"])
return _return_helper(res, helper_result, ret_cf, out=out)


# @implements(np.nanquantile)
# def nanquantile(...):
# from unyt._array_functions import nanquantile as unyt_nanquantile
@implements(np.nanquantile)
def nanquantile(
a,
q,
axis=None,
out=None,
overwrite_input=False,
method='linear',
keepdims=False,
*,
weights=None,
interpolation=None
):
from unyt._array_functions import nanquantile as unyt_nanquantile

# helper_result = _prepare_array_func_args(...)
# ret_cf = ...()
# res = unyt_nanquantile(*helper_result["args"], **helper_result["kwargs"])
# return _return_helper(res, helper_result, ret_cf, out=out)
helper_result = _prepare_array_func_args(
a,
q,
axis=axis,
out=out,
overwrite_input=overwrite_input,
method=method,
keepdims=keepdims,
weights=weights,
interpolation=interpolation,
)
ret_cf = _preserve_cosmo_factor(helper_result["ca_cfs"][0])
res = unyt_nanquantile(*helper_result["args"], **helper_result["kwargs"])
return _return_helper(res, helper_result, ret_cf, out=out)


@implements(np.linalg.det)
Expand Down
Loading

0 comments on commit 8880a13

Please sign in to comment.