Skip to content

Commit

Permalink
dependencies: fix numpy version for subprojects that depend on it (#3062
Browse files Browse the repository at this point in the history
)

Numpy changed something in their typing annotations and now our main is
broken. We should probably move to something that locks our dependencies
for our CI, in the meantime, this PR specifies a numpy version and fixes
the type errors.
  • Loading branch information
superlopuh authored Aug 19, 2024
1 parent e603f11 commit f0684ae
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ dev = [
"pyright==1.1.345",
]
gui = ["textual==0.76.0", "pyclip==0.7"]
jax = ["jax==0.4.31"]
onnx = ["onnx==1.16.2"]
jax = ["jax==0.4.31", "numpy==2.1.0"]
onnx = ["onnx==1.16.2", "numpy==2.1.0"]
riscv = ["riscemu==2.2.7"]
wgpu = ["wgpu==0.16.0"]

Expand Down
15 changes: 5 additions & 10 deletions xdsl/interpreters/onnx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, cast

import numpy as np
import numpy.typing as npt

from xdsl.dialects import onnx
from xdsl.dialects.builtin import TensorType
Expand Down Expand Up @@ -34,7 +35,7 @@ def to_dtype(


def from_dtype(
dtype: type[np.float32] | type[np.float64] | type[np.int32] | type[np.int64],
dtype: npt.DTypeLike,
) -> ptr.XType[float] | ptr.XType[int]:
if dtype == np.float32:
return ptr.float32
Expand All @@ -50,26 +51,20 @@ def from_dtype(

def to_ndarray(
shaped_array: ShapedArray[int] | ShapedArray[float],
) -> np.ndarray[Any, np.dtype[np.float64 | np.float32 | np.int64 | np.int32]]:
) -> npt.NDArray[np.float32 | np.float64 | np.int32 | np.int64]:
dtype = to_dtype(shaped_array.data_ptr.xtype)
flat = np.frombuffer(shaped_array.data_ptr.raw.memory, dtype)
shaped = flat.reshape(shaped_array.shape)
return shaped


def from_ndarray(
ndarray: np.ndarray[
Any,
np.dtype[np.float32]
| np.dtype[np.float64]
| np.dtype[np.int32]
| np.dtype[np.int64],
],
ndarray: npt.NDArray[np.number[Any]],
) -> ShapedArray[float] | ShapedArray[int]:
return ShapedArray(
ptr.TypedPtr(
ptr.RawPtr(bytearray(ndarray.data)),
xtype=from_dtype(ndarray.dtype.type),
xtype=from_dtype(np.dtype(ndarray.dtype)),
),
list(ndarray.shape),
)
Expand Down

0 comments on commit f0684ae

Please sign in to comment.