Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
many changes to follow the Array API
  • Loading branch information
xadupre committed Jun 5, 2023
commit da3a7c855cee8674150259aeb6f84ff59e421a9c
7 changes: 7 additions & 0 deletions _doc/api/array_api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
onnx_array_api.array_api
========================

.. toctree::

array_api_onnx_numpy
array_api_onnx_ort
6 changes: 3 additions & 3 deletions _doc/api/array_api_numpy.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
onnx_array_api.array_api_onnx_numpy.onnx_numpy_array_api
========================================================
onnx_array_api.array_api.onnx_numpy
=============================================

.. autoclass:: onnx_array_api.array_api_onnx_numpy.onnx_numpy_array_api
.. automodule:: onnx_array_api.array_api.onnx_numpy
:members:
6 changes: 3 additions & 3 deletions _doc/api/array_api_ort.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
onnx_array_api.array_api_onnx_ort.onnx_ort_array_api
====================================================
onnx_array_api.array_api.onnx_ort
=================================

.. autoclass:: onnx_array_api.array_api_onnx_ort.onnx_ort_array_api
.. automodule:: onnx_array_api.array_api.onnx_ort
:members:
3 changes: 1 addition & 2 deletions _doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ API
.. toctree::
:maxdepth: 1

array_api_numpy
array_api_ort
array_api
npx_functions
npx_var
npx_jit
Expand Down
10 changes: 8 additions & 2 deletions _unittests/ut_npx/test_npx.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
from onnx_array_api.npx.npx_types import (
Bool,
DType,
Float32,
Float64,
Int64,
Expand Down Expand Up @@ -2318,7 +2319,7 @@ def compute_labels(X, centers):
self.assertEqual(f.n_versions, 1)
self.assertEqual(len(f.available_versions), 1)
self.assertEqual(f.available_versions, [((np.float64, 2), (np.float64, 2))])
key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2))
key = ((DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2))
onx = f.get_onnx(key)
self.assertIsInstance(onx, ModelProto)
self.assertRaise(lambda: f.get_onnx(2), ValueError)
Expand Down Expand Up @@ -2379,7 +2380,12 @@ def compute_labels(X, centers, use_sqrt=False):
self.assertEqualArray(got[1], dist)
self.assertEqual(f.n_versions, 1)
self.assertEqual(len(f.available_versions), 1)
key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2), "use_sqrt", True)
key = (
(DType(TensorProto.DOUBLE), 2),
(DType(TensorProto.DOUBLE), 2),
"use_sqrt",
True,
)
self.assertEqual(f.available_versions, [key])
onx = f.get_onnx(key)
self.assertIsInstance(onx, ModelProto)
Expand Down
3 changes: 2 additions & 1 deletion _unittests/ut_npx/test_sklearn_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from onnx.defs import onnx_opset_version
from sklearn import config_context, __version__ as sklearn_version
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor


Expand All @@ -16,6 +16,7 @@ class TestSklearnArrayAPI(ExtTestCase):
Version(sklearn_version) <= Version("1.2.2"),
reason="reshape ArrayAPI not followed",
)
@ignore_warnings(DeprecationWarning)
def test_sklearn_array_api_linear_discriminant(self):
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
y = np.array([1, 1, 1, 2, 2, 2])
Expand Down
18 changes: 18 additions & 0 deletions onnx_array_api/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from onnx import TensorProto


def _finalize_array_api(module):
module.float16 = TensorProto.FLOAT16
module.float32 = TensorProto.FLOAT
module.float64 = TensorProto.DOUBLE
module.int8 = TensorProto.INT8
module.int16 = TensorProto.INT16
module.int32 = TensorProto.INT32
module.int64 = TensorProto.INT64
module.uint8 = TensorProto.UINT8
module.uint16 = TensorProto.UINT16
module.uint32 = TensorProto.UINT32
module.uint64 = TensorProto.UINT64
module.bfloat16 = TensorProto.BFLOAT16
setattr(module, "bool", TensorProto.BOOL)
# setattr(module, "str", TensorProto.STRING)
55 changes: 55 additions & 0 deletions onnx_array_api/array_api/onnx_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Array API valid for an :class:`EagerNumpyTensor`.
"""
from typing import Any, Optional
import numpy as np
from ..npx.npx_array_api import BaseArrayApi
from ..npx.npx_functions import (
abs,
absolute,
astype,
isdtype,
reshape,
take,
)
from ..npx.npx_functions import asarray as generic_asarray
from ..npx.npx_numpy_tensors import EagerNumpyTensor
from . import _finalize_array_api

__all__ = [
"abs",
"absolute",
"asarray",
"astype",
"isdtype",
"reshape",
"take",
]


def asarray(
a: Any,
dtype: Any = None,
order: Optional[str] = None,
like: Any = None,
copy: bool = False,
):
"""
Converts anything into an array.
"""
if isinstance(a, BaseArrayApi):
return generic_asarray(a, dtype=dtype, order=order, like=like, copy=copy)
if isinstance(a, int):
return EagerNumpyTensor(np.array(a, dtype=np.int64))
if isinstance(a, float):
return EagerNumpyTensor(np.array(a, dtype=np.float32))
raise NotImplementedError(f"asarray not implemented for type {type(a)}.")


def _finalize():
from . import onnx_numpy

_finalize_array_api(onnx_numpy)


_finalize()
55 changes: 55 additions & 0 deletions onnx_array_api/array_api/onnx_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Array API valid for an :class:`EagerOrtTensor`.
"""
from typing import Any, Optional
import numpy as np
from ..npx.npx_array_api import BaseArrayApi
from ..npx.npx_functions import (
abs,
absolute,
astype,
isdtype,
reshape,
take,
)
from ..npx.npx_functions import asarray as generic_asarray
from ..ort.ort_tensors import EagerOrtTensor
from . import _finalize_array_api

__all__ = [
"abs",
"absolute",
"asarray",
"astype",
"isdtype",
"reshape",
"take",
]


def asarray(
a: Any,
dtype: Any = None,
order: Optional[str] = None,
like: Any = None,
copy: bool = False,
):
"""
Converts anything into an array.
"""
if isinstance(a, BaseArrayApi):
return generic_asarray(a, dtype=dtype, order=order, like=like, copy=copy)
if isinstance(a, int):
return EagerOrtTensor(np.array(a, dtype=np.int64))
if isinstance(a, float):
return EagerOrtTensor(np.array(a, dtype=np.float32))
raise NotImplementedError(f"asarray not implemented for type {type(a)}.")


def _finalize():
from . import onnx_ort

_finalize_array_api(onnx_ort)


_finalize()
20 changes: 0 additions & 20 deletions onnx_array_api/array_api_onnx_numpy.py

This file was deleted.

20 changes: 0 additions & 20 deletions onnx_array_api/array_api_onnx_ort.py

This file was deleted.

18 changes: 0 additions & 18 deletions onnx_array_api/npx/npx_array_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Optional

import numpy as np
from onnx import TensorProto

from .npx_types import OptParType, ParType, TupleType

Expand All @@ -19,19 +18,6 @@ class BaseArrayApi:
List of supported method by a tensor.
"""

float16 = TensorProto.FLOAT16
float32 = TensorProto.FLOAT
float64 = TensorProto.DOUBLE
int8 = TensorProto.INT8
int16 = TensorProto.INT16
int32 = TensorProto.INT32
int64 = TensorProto.INT64
uint8 = TensorProto.UINT8
uint16 = TensorProto.UINT16
uint32 = TensorProto.UINT32
uint64 = TensorProto.UINT64
bfloat16 = TensorProto.BFLOAT16

def __array_namespace__(self, api_version: Optional[str] = None):
"""
This method must be overloaded.
Expand Down Expand Up @@ -187,7 +173,3 @@ def __getitem__(self, index: Any) -> "BaseArrayApi":

def __setitem__(self, index: Any, values: Any):
return self.generic_method("__setitem__", index, values)


setattr(BaseArrayApi, "bool", TensorProto.BOOL)
setattr(BaseArrayApi, "str", TensorProto.STRING)
7 changes: 6 additions & 1 deletion onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import array_api_compat.numpy as np_array_api
import numpy as np
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.helper import np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype
from onnx.numpy_helper import from_array

from .npx_constants import FUNCTION_DOMAIN
Expand Down Expand Up @@ -407,6 +407,11 @@ def isdtype(
See :epkg:`BaseArrayAPI:isdtype`.
This function is not converted into an onnx graph.
"""
if isinstance(dtype, DType):
dti = tensor_dtype_to_np_dtype(dtype.code)
return np_array_api.isdtype(dti, kind)
if isinstance(dtype, int):
raise TypeError(f"Unexpected type {type(dtype)}.")
return np_array_api.isdtype(dtype, kind)


Expand Down
7 changes: 5 additions & 2 deletions onnx_array_api/npx/npx_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
rename_in_onnx_graph,
)
from .npx_types import (
DType,
ElemType,
OptParType,
ParType,
Expand Down Expand Up @@ -226,6 +227,8 @@ def make_node(
protos.append(att)
elif v.value is not None:
new_kwargs[k] = v.value
elif isinstance(v, DType):
new_kwargs[k] = v.code
else:
new_kwargs[k] = v

Expand Down Expand Up @@ -337,7 +340,7 @@ def _io(
if tensor_type.shape is None:
type_proto = TypeProto()
tensor_type_proto = type_proto.tensor_type
tensor_type_proto.elem_type = tensor_type.dtypes[0].dtype
tensor_type_proto.elem_type = tensor_type.dtypes[0].dtype.code
value_info_proto = ValueInfoProto()
value_info_proto.name = name
# tensor_type_proto.shape.dim.extend([])
Expand All @@ -348,7 +351,7 @@ def _io(
# with fixed rank. This can be changed here and in methods
# `make_key`.
shape = [None for _ in tensor_type.shape]
info = make_tensor_value_info(name, tensor_type.dtypes[0].dtype, shape)
info = make_tensor_value_info(name, tensor_type.dtypes[0].dtype.code, shape)
# check_value_info fails if the shape is left undefined
check_value_info(info, self.check_context)
return info
Expand Down
11 changes: 6 additions & 5 deletions onnx_array_api/npx/npx_numpy_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import numpy as np
from onnx import ModelProto
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.reference import ReferenceEvaluator

from .npx_tensors import EagerTensor, JitTensor
from .npx_types import TensorType
from .npx_types import DType, TensorType


class NumpyTensor:
Expand Down Expand Up @@ -80,9 +81,9 @@ def numpy(self):
return self._tensor

@property
def dtype(self) -> Any:
def dtype(self) -> DType:
"Returns the element type of this tensor."
return self._tensor.dtype
return DType(np_dtype_to_tensor_dtype(self._tensor.dtype))

@property
def key(self) -> Any:
Expand Down Expand Up @@ -176,9 +177,9 @@ def __array_namespace__(self, api_version: Optional[str] = None):
Returns the module holding all the available functions.
"""
if api_version is None or api_version == "2022.12":
from onnx_array_api.array_api_onnx_numpy import onnx_numpy_array_api
from onnx_array_api.array_api import onnx_numpy

return onnx_numpy_array_api
return onnx_numpy
raise ValueError(
f"Unable to return an implementation for api_version={api_version!r}."
)
Expand Down
Loading