Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix operators with two outputs
  • Loading branch information
xadupre committed Nov 8, 2023
commit 56914b6c9d9b3123882b3233fbb79ce97bc0a34b
39 changes: 39 additions & 0 deletions _unittests/ut_light_api/test_light_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,45 @@ def test_operator_bool(self):
got = ref.run(None, {"X": a, "Y": b})[0]
self.assertEqualArray(f(a, b), got)

def test_topk(self):
onx = (
start()
.vin("X", np.float32)
.vin("K", np.int64)
.bring("X", "K")
.TopK()
.rename("Values", "Indices")
.vout()
.to_onnx()
)
self.assertIsInstance(onx, ModelProto)
ref = ReferenceEvaluator(onx)
x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
k = np.array([2], dtype=np.int64)
got = ref.run(None, {"X": x, "K": k})
self.assertEqualArray(np.array([[3, 2], [9, 8]], dtype=np.float32), got[0])
self.assertEqualArray(np.array([[3, 2], [0, 1]], dtype=np.int64), got[1])

def test_topk_reverse(self):
onx = (
start()
.vin("X", np.float32)
.vin("K", np.int64)
.bring("X", "K")
.TopK(largest=0)
.rename("Values", "Indices")
.vout()
.to_onnx()
)
self.assertIsInstance(onx, ModelProto)
ref = ReferenceEvaluator(onx)
x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
k = np.array([2], dtype=np.int64)
got = ref.run(None, {"X": x, "K": k})
self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0])
self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1])


if __name__ == "__main__":
    unittest.main(verbosity=2)
6 changes: 3 additions & 3 deletions onnx_array_api/light_api/_op_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def ArgMin(

def AveragePool(
self,
auto_pad: str = b"NOTSET",
auto_pad: str = "NOTSET",
ceil_mode: int = 0,
count_include_pad: int = 0,
dilations: Optional[List[int]] = None,
Expand Down Expand Up @@ -68,7 +68,7 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
def Celu(self, alpha: float = 1.0) -> "Var":
return self.make_node("Celu", self, alpha=alpha)

def DepthToSpace(self, blocksize: int = 0, mode: str = b"DCR") -> "Var":
def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var":
return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode)

def DynamicQuantizeLinear(
Expand Down Expand Up @@ -137,7 +137,7 @@ def LpNormalization(self, axis: int = -1, p: int = 2) -> "Var":

def LpPool(
self,
auto_pad: str = b"NOTSET",
auto_pad: str = "NOTSET",
ceil_mode: int = 0,
dilations: Optional[List[int]] = None,
kernel_shape: Optional[List[int]] = None,
Expand Down
45 changes: 25 additions & 20 deletions onnx_array_api/light_api/_op_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class OpsVars:
Operators taking multiple inputs.
"""

def BitShift(self, direction: str = b"") -> "Var":
def BitShift(self, direction: str = "") -> "Var":
return self.make_node("BitShift", *self.vars_, direction=direction)

def CenterCropPad(self, axes: Optional[List[int]] = None) -> "Var":
Expand Down Expand Up @@ -42,7 +42,7 @@ def Concat(self, axis: int = 0) -> "Var":

def Conv(
self,
auto_pad: str = b"NOTSET",
auto_pad: str = "NOTSET",
dilations: Optional[List[int]] = None,
group: int = 1,
kernel_shape: Optional[List[int]] = None,
Expand All @@ -66,7 +66,7 @@ def Conv(

def ConvInteger(
self,
auto_pad: str = b"NOTSET",
auto_pad: str = "NOTSET",
dilations: Optional[List[int]] = None,
group: int = 1,
kernel_shape: Optional[List[int]] = None,
Expand All @@ -90,7 +90,7 @@ def ConvInteger(

def ConvTranspose(
self,
auto_pad: str = b"NOTSET",
auto_pad: str = "NOTSET",
dilations: Optional[List[int]] = None,
group: int = 1,
kernel_shape: Optional[List[int]] = None,
Expand Down Expand Up @@ -155,7 +155,7 @@ def DeformConv(
def DequantizeLinear(self, axis: int = 1) -> "Var":
return self.make_node("DequantizeLinear", *self.vars_, axis=axis)

def Einsum(self, equation: str = b"") -> "Var":
def Einsum(self, equation: str = "") -> "Var":
return self.make_node("Einsum", *self.vars_, equation=equation)

def Gather(self, axis: int = 0) -> "Var":
Expand All @@ -174,8 +174,8 @@ def Gemm(
def GridSample(
self,
align_corners: int = 0,
mode: str = b"bilinear",
padding_mode: str = b"zeros",
mode: str = "bilinear",
padding_mode: str = "zeros",
) -> "Var":
return self.make_node(
"GridSample",
Expand Down Expand Up @@ -240,7 +240,7 @@ def Mod(self, fmod: int = 0) -> "Var":
return self.make_node("Mod", *self.vars_, fmod=fmod)

def NegativeLogLikelihoodLoss(
self, ignore_index: int = 0, reduction: str = b"mean"
self, ignore_index: int = 0, reduction: str = "mean"
) -> "Var":
return self.make_node(
"NegativeLogLikelihoodLoss",
Expand All @@ -257,12 +257,12 @@ def NonMaxSuppression(self, center_point_box: int = 0) -> "Var":
def OneHot(self, axis: int = -1) -> "Var":
return self.make_node("OneHot", *self.vars_, axis=axis)

def Pad(self, mode: str = b"constant") -> "Var":
def Pad(self, mode: str = "constant") -> "Var":
return self.make_node("Pad", *self.vars_, mode=mode)

def QLinearConv(
self,
auto_pad: str = b"NOTSET",
auto_pad: str = "NOTSET",
dilations: Optional[List[int]] = None,
group: int = 1,
kernel_shape: Optional[List[int]] = None,
Expand Down Expand Up @@ -431,13 +431,13 @@ def Resize(
self,
antialias: int = 0,
axes: Optional[List[int]] = None,
coordinate_transformation_mode: str = b"half_pixel",
coordinate_transformation_mode: str = "half_pixel",
cubic_coeff_a: float = -0.75,
exclude_outside: int = 0,
extrapolation_value: float = 0.0,
keep_aspect_ratio_policy: str = b"stretch",
mode: str = b"nearest",
nearest_mode: str = b"round_prefer_floor",
keep_aspect_ratio_policy: str = "stretch",
mode: str = "nearest",
nearest_mode: str = "round_prefer_floor",
) -> "Var":
axes = axes or []
return self.make_node(
Expand All @@ -456,8 +456,8 @@ def Resize(

def RoiAlign(
self,
coordinate_transformation_mode: str = b"half_pixel",
mode: str = b"avg",
coordinate_transformation_mode: str = "half_pixel",
mode: str = "avg",
output_height: int = 1,
output_width: int = 1,
sampling_ratio: int = 0,
Expand All @@ -480,12 +480,12 @@ def STFT(self, onesided: int = 1) -> "Var":
def Scatter(self, axis: int = 0) -> "Var":
return self.make_node("Scatter", *self.vars_, axis=axis)

def ScatterElements(self, axis: int = 0, reduction: str = b"none") -> "Var":
def ScatterElements(self, axis: int = 0, reduction: str = "none") -> "Var":
return self.make_node(
"ScatterElements", *self.vars_, axis=axis, reduction=reduction
)

def ScatterND(self, reduction: str = b"none") -> "Var":
def ScatterND(self, reduction: str = "none") -> "Var":
return self.make_node("ScatterND", *self.vars_, reduction=reduction)

def Slice(
Expand All @@ -498,13 +498,18 @@ def Slice(

def TopK(self, axis: int = -1, largest: int = 1, sorted: int = 1) -> "Vars":
return self.make_node(
"TopK", *self.vars_, axis=axis, largest=largest, sorted=sorted
"TopK",
*self.vars_,
axis=axis,
largest=largest,
sorted=sorted,
n_outputs=2,
)

def Trilu(self, upper: int = 1) -> "Var":
return self.make_node("Trilu", *self.vars_, upper=upper)

def Upsample(self, mode: str = b"nearest") -> "Var":
def Upsample(self, mode: str = "nearest") -> "Var":
return self.make_node("Upsample", *self.vars_, mode=mode)

def Where(
Expand Down
94 changes: 70 additions & 24 deletions onnx_array_api/light_api/var.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import TensorProto
from .annotations import (
Expand Down Expand Up @@ -27,6 +27,8 @@ def __init__(
self,
parent: OnnxGraph,
):
if not isinstance(parent, OnnxGraph):
raise RuntimeError(f"Unexpected parent type {type(parent)}.")
self.parent = parent

def make_node(
Expand Down Expand Up @@ -60,9 +62,13 @@ def make_node(
**kwargs,
)
names = node_proto.output
if n_outputs is not None and len(node_proto.output) != len(names):
raise RuntimeError(
f"Expects {n_outputs} outputs but output names are {names}."
)
if len(names) == 1:
return Var(self.parent, names[0])
return Vars(*map(lambda v: Var(self.parent, v), names))
return Vars(self.parent, *list(map(lambda v: Var(self.parent, v), names)))

def vin(
self,
Expand Down Expand Up @@ -91,26 +97,6 @@ def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":
c = self.parent.make_constant(value, name=name)
return Var(self.parent, c.name, elem_type=c.data_type, shape=tuple(c.dims))

def vout(
self,
elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
shape: Optional[SHAPE_TYPE] = None,
) -> "Var":
"""
Declares a new output to the graph.

:param elem_type: element_type
:param shape: shape
:return: instance of :class:`onnx_array_api.light_api.Var`
"""
output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape)
return Var(
self.parent,
output,
elem_type=output.type.tensor_type.elem_type,
shape=make_shape(output.type.tensor_type.shape),
)

def v(self, name: str) -> "Var":
"""
Retrieves another variable than this one.
Expand All @@ -127,6 +113,13 @@ def bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
"""
return Vars(self.parent, *vars)

def vout(self, **kwargs: Dict[str, Any]) -> Union["Var", "Vars"]:
"""
This method needs to be overwritten for Var and Vars depending
on the number of variable to declare as outputs.
"""
raise RuntimeError(f"The method was not overwritten in class {type(self)}.")

def left_bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
"""
Creates a set of variables as an instance of
Expand Down Expand Up @@ -187,6 +180,26 @@ def __str__(self) -> str:
return s
return f"{s}:[{''.join(map(str, self.shape))}]"

def vout(
self,
elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
shape: Optional[SHAPE_TYPE] = None,
) -> "Var":
"""
Declares a new output to the graph.

:param elem_type: element_type
:param shape: shape
:return: instance of :class:`onnx_array_api.light_api.Var`
"""
output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape)
return Var(
self.parent,
output,
elem_type=output.type.tensor_type.elem_type,
shape=make_shape(output.type.tensor_type.shape),
)

def rename(self, new_name: str) -> "Var":
"Renames a variable."
self.parent.rename(self.name, new_name)
Expand Down Expand Up @@ -299,6 +312,39 @@ def _check_nin(self, n_inputs):
raise RuntimeError(f"Expecting {n_inputs} inputs not {len(self)}.")
return self

def rename(self, new_name: str) -> "Var":
def rename(self, *new_names: List[str]) -> "Vars":
"Renames variables."
raise NotImplementedError("Not yet implemented.")
if len(new_names) != len(self):
raise ValueError(
f"Vars has {len(self)} elements but the method received {len(new_names)} names."
)
new_vars = []
for var, name in zip(self.vars_, new_names):
new_vars.append(var.rename(name))
return Vars(self.parent, *new_names)

def vout(
self,
*elem_type_shape: List[
Union[ELEMENT_TYPE, Tuple[ELEMENT_TYPE, Optional[SHAPE_TYPE]]]
],
) -> "Vars":
"""
Declares a new output to the graph.

:param elem_type_shape: list of tuple(element_type, shape)
:return: instance of :class:`onnx_array_api.light_api.Vars`
"""
vars = []
for i, v in enumerate(self.vars_):
if i < len(elem_type_shape):
if isinstance(elem_type_shape[i]) or len(elem_type_shape[i]) < 2:
elem_type = elem_type_shape[i][0]
shape = None
else:
elem_type, shape = elem_type_shape[i]
else:
elem_type = TensorProto.FLOAT
shape = None
vars.append(v.vout(elem_type=elem_type, shape=shape))
return Vars(self.parent, *vars)