-
Notifications
You must be signed in to change notification settings - Fork 69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for Python array API standard #197
base: main
Are you sure you want to change the base?
Changes from all commits
e7dc2c3
13a6054
0fb0312
381f15a
50c3693
576c129
db785c9
5de3b6e
35cb9d7
75f9d90
200402e
714b520
13ec280
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
""" | ||
Required functions for optimized contractions of arrays using array API-compliant backends. | ||
""" | ||
import sys | ||
from typing import Callable | ||
from types import ModuleType | ||
|
||
import numpy as np | ||
|
||
from ..sharing import to_backend_cache_wrap | ||
|
||
|
||
def discover_array_apis(): | ||
"""Discover array API backends.""" | ||
if sys.version_info >= (3, 8): | ||
from importlib.metadata import entry_points | ||
|
||
if sys.version_info >= (3, 10): | ||
eps = entry_points(group="array_api") | ||
else: | ||
# Deprecated - will raise warning in Python versions >= 3.10 | ||
eps = entry_points().get("array_api", []) | ||
return [ep.load() for ep in eps] | ||
else: | ||
# importlib.metadata was introduced in Python 3.8, so it isn't available here. Unable to discover any array APIs. | ||
return [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Numpy does not officially support below 3.8: https://numpy.org/neps/nep-0029-deprecation_policy.html It would be worth considering dropping Python 3.7 as well. @jcmgray what do you think? |
||
|
||
|
||
def make_to_array_function(array_api: ModuleType) -> Callable: | ||
"""Make a ``to_[array_api]`` function for the given array API.""" | ||
|
||
@to_backend_cache_wrap | ||
def to_array(array): # pragma: no cover | ||
if isinstance(array, np.ndarray): | ||
return array_api.asarray(array) | ||
return array | ||
|
||
return to_array | ||
|
||
|
||
def make_build_expression_function(array_api: ModuleType) -> Callable: | ||
"""Make a ``build_expression`` function for the given array API.""" | ||
_to_array_api = to_array_api[array_api.__name__] | ||
|
||
def build_expression(_, expr): # pragma: no cover | ||
"""Build an array API function based on ``arrays`` and ``expr``.""" | ||
|
||
def array_api_contract(*arrays): | ||
return expr._contract([_to_array_api(x) for x in arrays], backend=array_api.__name__) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jcmgray How about this? |
||
|
||
return array_api_contract | ||
|
||
return build_expression | ||
|
||
|
||
def make_evaluate_constants_function(array_api: ModuleType) -> Callable: | ||
_to_array_api = to_array_api[array_api.__name__] | ||
|
||
def evaluate_constants(const_arrays, expr): # pragma: no cover | ||
"""Convert constant arguments to cupy arrays, and perform any possible constant contractions.""" | ||
return expr( | ||
*[_to_array_api(x) for x in const_arrays], | ||
backend=array_api.__name__, | ||
evaluate_constants=True, | ||
) | ||
|
||
return evaluate_constants | ||
|
||
|
||
_array_apis = discover_array_apis() | ||
to_array_api = {api.__name__: make_to_array_function(api) for api in _array_apis} | ||
build_expression = {api.__name__: make_build_expression_function(api) for api in _array_apis} | ||
evaluate_constants = {api.__name__: make_evaluate_constants_function(api) for api in _array_apis} | ||
|
||
__all__ = ["discover_array_apis", "to_array_api", "build_expression", "evaluate_constants"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
from . import tensorflow as _tensorflow | ||
from . import theano as _theano | ||
from . import torch as _torch | ||
from . import array_api as _array_api | ||
|
||
__all__ = [ | ||
"get_func", | ||
|
@@ -122,6 +123,7 @@ def has_tensordot(backend: str) -> bool: | |
"cupy": _cupy.build_expression, | ||
"torch": _torch.build_expression, | ||
"jax": _jax.build_expression, | ||
**_array_api.build_expression, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any way to check if |
||
} | ||
|
||
EVAL_CONSTS_BACKENDS = { | ||
|
@@ -130,6 +132,7 @@ def has_tensordot(backend: str) -> bool: | |
"cupy": _cupy.evaluate_constants, | ||
"torch": _torch.evaluate_constants, | ||
"jax": _jax.evaluate_constants, | ||
**_array_api.evaluate_constants, | ||
} | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possible to type this function?