Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix `all` when the input shape is empty and has one dimension
  • Loading branch information
xadupre committed Jun 10, 2023
commit caa99a7328745029463e7720515333ca64e24e9c
36 changes: 34 additions & 2 deletions _unittests/ut_npx/test_npx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,15 @@ def test_identity(self):
got = ref.run(None, {})
self.assertEqualArray(z, got[0])

def test_identity_uint8(self):
    """identity_inline(2, dtype=uint8) must embed a "dtype" attribute and match np.identity."""
    expected = np.identity(2).astype(np.uint8)
    fct = identity_inline(2, dtype=np.uint8)
    model = fct.to_onnx(constraints={(0, False): Float64[None]})
    # the dtype keyword must survive as a node attribute in the ONNX graph
    self.assertIn('name: "dtype"', str(model))
    sess = ReferenceEvaluator(model)
    results = sess.run(None, {})
    self.assertEqualArray(expected, results[0])

def test_isnan(self):
self.common_test_inline(isnan_inline, np.isnan)

Expand Down Expand Up @@ -2493,9 +2502,32 @@ def test_numpy_all(self):
self.assertEqualArray(y, got[0])

def test_numpy_all_empty(self):
    """all_inline on an empty 1-D boolean array must agree with np.all.

    np.all over an empty array returns True; the ONNX translation must
    reproduce that even though the input has zero elements.
    """
    # NOTE(review): the diff left a stale duplicate assignment here
    # (`data = np.zeros((0, 1), ...)` immediately overwritten); removed.
    data = np.zeros((0,), dtype=np.bool_)
    y = np.all(data)

    f = all_inline(Input("A"))
    self.assertIsInstance(f, Var)
    onx = f.to_onnx(constraints={"A": Bool[None]})
    ref = ReferenceEvaluator(onx)
    got = ref.run(None, {"A": data})
    self.assertEqualArray(y, got[0])

@unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0")
def test_numpy_all_empty_axis_0(self):
    """Reduction along the zero-sized axis; skipped until the runtime supports it."""
    data = np.zeros((0, 1), dtype=np.bool_)
    expected = np.all(data, axis=0)

    var_all = all_inline(Input("A"), axis=0)
    self.assertIsInstance(var_all, Var)
    model = var_all.to_onnx(constraints={"A": Bool[None]})
    sess = ReferenceEvaluator(model)
    results = sess.run(None, {"A": data})
    self.assertEqualArray(expected, results[0])

def test_numpy_all_empty_axis_1(self):
data = np.zeros((0, 1), dtype=np.bool_)
y = np.all(data, axis=1)

f = all_inline(Input("A"), axis=1)
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Bool[None]})
Expand All @@ -2505,5 +2537,5 @@ def test_numpy_all_empty(self):


if __name__ == "__main__":
TestNpx().test_numpy_all_empty()
# TestNpx().test_numpy_all_empty_axis_0()
unittest.main(verbosity=2)
7 changes: 6 additions & 1 deletion onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def all(
xi = var(x, op="Cast", to=TensorProto.INT64)

if axis is None:
red = xi.min(keepdims=keepdims)
new_shape = cst(np.array([-1], dtype=np.int64))
xifl = var(xi, new_shape, op="Reshape")
# in case xifl is empty, we need to add one element
one = cst(np.array([1], dtype=np.int64))
xifl1 = var(xifl, one, op="Concat", axis=0)
red = xifl1.min(keepdims=keepdims)
else:
if isinstance(axis, int):
axis = [axis]
Expand Down
3 changes: 2 additions & 1 deletion onnx_array_api/npx/npx_numpy_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.reference import ReferenceEvaluator

from .npx_numpy_tensors_ops import ConstantOfShape
from .npx_tensors import EagerTensor, JitTensor
from .npx_types import DType, TensorType

Expand All @@ -25,7 +26,7 @@ class Evaluator:
"""

def __init__(self, tensor_class: type, input_names: List[str], onx: ModelProto):
    """Build an evaluator for *onx*.

    :param tensor_class: tensor class used to wrap the outputs
    :param input_names: ordered names of the model inputs
    :param onx: model to evaluate

    The local :class:`ConstantOfShape` kernel is registered through
    ``new_ops`` so that its implementation overrides the default one.
    """
    # NOTE(review): the diff left both the old and the new assignment to
    # self.ref; the first one built a discarded evaluator and was removed.
    self.ref = ReferenceEvaluator(onx, new_ops=[ConstantOfShape])
    self.input_names = input_names
    self.tensor_class = tensor_class

Expand Down
45 changes: 45 additions & 0 deletions onnx_array_api/npx/npx_numpy_tensors_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np

from onnx.reference.op_run import OpRun


class ConstantOfShape(OpRun):
    """Custom reference implementation of the ONNX ConstantOfShape operator.

    Registered through ``new_ops`` so that shapes containing a zero
    dimension are handled (``np.full`` accepts them).
    """

    @staticmethod
    def _process(value):
        """Extract and validate the scalar fill value.

        :param value: the node's ``value`` attribute, usually a 1-element
            numpy array; may also be a plain Python scalar or None
        :return: a numpy scalar of a supported dtype
        :raises TypeError: if the value cannot be mapped to a supported
            numeric or boolean numpy scalar
        """
        cst = value[0] if isinstance(value, np.ndarray) else value
        if isinstance(cst, int):
            cst = np.int64(cst)
        elif isinstance(cst, float):
            cst = np.float64(cst)
        elif cst is None:
            # ONNX default when the attribute is absent: float32 zero.
            cst = np.float32(0)
        if not isinstance(
            cst,
            (
                np.float16,
                np.float32,
                np.float64,
                np.int64,
                np.int32,
                np.int16,
                np.int8,
                np.uint64,
                np.uint32,
                np.uint16,
                np.uint8,
                np.bool_,
            ),
        ):
            raise TypeError(f"value must be a real not {type(cst)}")
        # BUG FIX: the validated scalar was never returned, so _run always
        # received None and np.full produced an object array filled with None.
        return cst

    def _run(self, data, value=None):
        """Return a tuple with one array of shape *data* filled with *value*."""
        cst = self._process(value)
        try:
            res = np.full(tuple(data), cst)
        except TypeError as e:
            raise RuntimeError(
                f"Unable to create a constant of shape "
                f"{data!r} with value {cst!r} "
                f"(raw value={value!r})."
            ) from e
        return (res,)