Skip to content
Merged
Next Next commit
Supports subgraph in the light API
  • Loading branch information
xadupre committed Nov 12, 2023
commit e0233dc327de5302a2d9865ad85b02de38065ad9
37 changes: 34 additions & 3 deletions _unittests/ut_light_api/test_light_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from typing import Callable, Optional
import numpy as np
from onnx import ModelProto
from onnx import GraphProto, ModelProto
from onnx.defs import (
get_all_schemas_with_history,
onnx_opset_version,
Expand All @@ -12,7 +12,7 @@
)
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start, OnnxGraph, Var
from onnx_array_api.light_api import start, OnnxGraph, Var, g
from onnx_array_api.light_api._op_var import OpsVar
from onnx_array_api.light_api._op_vars import OpsVars

Expand Down Expand Up @@ -442,7 +442,38 @@ def test_topk_reverse(self):
self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0])
self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1])

def test_if(self):
gg = g().cst(np.array([0], dtype=np.int64)).rename("Z").vout()
onx = gg.to_onnx()
self.assertIsInstance(onx, GraphProto)
self.assertEqual(len(onx.input), 0)
self.assertEqual(len(onx.output), 1)
self.assertEqual([o.name for o in onx.output], ["Z"])
onx = (
start()
.vin("X", np.float32)
.ReduceSum()
.rename("Xs")
.cst(np.array([0], dtype=np.float32))
.left_bring("Xs")
.Greater()
.If(
then_branch=g().cst(np.array([1], dtype=np.int64)).rename("Z").vout(),
else_branch=g().cst(np.array([0], dtype=np.int64)).rename("Z").vout(),
)
.rename("W")
.vout()
.to_onnx()
)
self.assertIsInstance(onx, ModelProto)
ref = ReferenceEvaluator(onx)
x = np.array([0, 1, 2, 3, 9, 8, 7, 6], dtype=np.float32)
got = ref.run(None, {"X": x})
self.assertEqualArray(np.array([1], dtype=np.int64), got[0])
got = ref.run(None, {"X": -x})
self.assertEqualArray(np.array([0], dtype=np.int64), got[0])


if __name__ == "__main__":
# TestLightApi().test_topk()
TestLightApi().test_if()
unittest.main(verbosity=2)
56 changes: 54 additions & 2 deletions _unittests/ut_light_api/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from onnx.defs import onnx_opset_version
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start, translate
from onnx_array_api.light_api import start, translate, g
from onnx_array_api.light_api.emitter import EventType

OPSET_API = min(19, onnx_opset_version() - 1)
Expand Down Expand Up @@ -133,7 +133,59 @@ def test_topk_reverse(self):
).strip("\n")
self.assertEqual(expected, code)

def test_export_if(self):
onx = (
start()
.vin("X", np.float32)
.ReduceSum()
.rename("Xs")
.cst(np.array([0], dtype=np.float32))
.left_bring("Xs")
.Greater()
.If(
then_branch=g().cst(np.array([1], dtype=np.int64)).rename("Z").vout(),
else_branch=g().cst(np.array([0], dtype=np.int64)).rename("Z").vout(),
)
.rename("W")
.vout()
.to_onnx()
)

self.assertIsInstance(onx, ModelProto)
ref = ReferenceEvaluator(onx)
x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
k = np.array([2], dtype=np.int64)
got = ref.run(None, {"X": x, "K": k})
self.assertEqualArray(np.array([1], dtype=np.int64), got[0])

code = translate(onx)
selse = "g().cst(np.array([0], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
sthen = "g().cst(np.array([1], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
expected = dedent(
f"""
(
start(opset=20)
.cst(np.array([0.0], dtype=np.float32))
.rename('r')
.vin('X', elem_type=TensorProto.FLOAT)
.bring('X')
.ReduceSum(keepdims=1, noop_with_empty_axes=0)
.rename('Xs')
.bring('Xs', 'r')
.Greater()
.rename('r1_0')
.bring('r1_0')
.If(else_branch={selse}, then_branch={sthen})
.rename('W')
.bring('W')
.vout(elem_type=TensorProto.FLOAT)
.to_onnx()
)"""
).strip("\n")
self.maxDiff = None
self.assertEqual(expected, code)


if __name__ == "__main__":
# TestLightApi().test_topk()
TestTranslate().test_export_if()
unittest.main(verbosity=2)
14 changes: 10 additions & 4 deletions onnx_array_api/light_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, Optional
from onnx import ModelProto
from .model import OnnxGraph
from .model import OnnxGraph, ProtoType
from .translate import Translater
from .var import Var, Vars
from .inner_emitter import InnerEmitter
Expand All @@ -9,13 +9,11 @@
def start(
opset: Optional[int] = None,
opsets: Optional[Dict[str, int]] = None,
is_function: bool = False,
) -> OnnxGraph:
"""
Starts an onnx model.

:param opset: main opset version
:param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
:param opsets: others opsets as a dictionary
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`

Expand Down Expand Up @@ -48,7 +46,15 @@ def start(
)
print(onx)
"""
return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function)
return OnnxGraph(opset=opset, opsets=opsets)


def g() -> OnnxGraph:
"""
Starts a subgraph.
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
"""
return OnnxGraph(proto_type=ProtoType.GRAPH)


def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
Expand Down
30 changes: 29 additions & 1 deletion onnx_array_api/light_api/_op_var.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Union


class OpsVar:
Expand Down Expand Up @@ -109,6 +109,34 @@ def HardSigmoid(
def Hardmax(self, axis: int = -1) -> "Var":
return self.make_node("Hardmax", self, axis=axis)

def If(
self,
then_branch: Optional[Union["Var", "Vars", "OnnxGraph"]] = None,
else_branch: Optional[Union["Var", "Vars", "OnnxGraph"]] = None,
) -> Union["Var", "Vars"]:
attr = {}
n_outputs = None
for name, att in zip(
["then_branch", "else_branch"], [then_branch, else_branch]
):
if att is None:
raise ValueError(f"Parameter {name!r} cannot be None.")
if hasattr(att, "to_onnx"):
# Let's overwrite the opsets.
att.parent.opset = self.parent.opset
att.parent.opsets = self.parent.opsets
graph = att.to_onnx()
attr[name] = graph
if n_outputs is None:
n_outputs = len(graph.output)
elif n_outputs != len(graph.output):
raise ValueError(
"then and else branches have different number of outputs."
)
else:
raise ValueError(f"Unexpeted type {type(att)} for parameter {name!r}.")
return self.make_node("If", self, **attr)

def IsInf(self, detect_negative: int = 1, detect_positive: int = 1) -> "Var":
return self.make_node(
"IsInf",
Expand Down
9 changes: 9 additions & 0 deletions onnx_array_api/light_api/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
):
return [], str(v.tolist())

if value[0].type == AttributeProto.GRAPH:
from .translate import Translater

tr = Translater(value[0].g, emitter=self)
rows = tr.export(as_str=False, single_line=False)
# last instruction is to_onnx, let's drop it.
srows = ".".join(rows[:-1])
return [], f"g().{srows}"

raise ValueError(
f"Unable to render an attribute {type(v)}, "
f"attribute type={value[0].type}, "
Expand Down
33 changes: 30 additions & 3 deletions onnx_array_api/light_api/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional, Union
from enum import IntEnum
import numpy as np
from onnx import NodeProto, SparseTensorProto, TensorProto, ValueInfoProto
from onnx.checker import check_model
Expand All @@ -22,6 +23,12 @@
)


class ProtoType(IntEnum):
FUNCTION = 1
GRAPH = 2
MODEL = 3


class OnnxGraph:
"""
Contains every piece needed to create an onnx model in a single instructions.
Expand All @@ -36,7 +43,7 @@ def __init__(
self,
opset: Optional[int] = None,
opsets: Optional[Dict[str, int]] = None,
is_function: bool = False,
proto_type: ProtoType = ProtoType.MODEL,
):
if opsets is not None and "" in opsets:
if opset is None:
Expand All @@ -45,11 +52,11 @@ def __init__(
raise ValueError(
"The main opset can be specified twice with different values."
)
if is_function:
if proto_type == ProtoType.FUNCTION:
raise NotImplementedError(
"The first version of this API does not support functions."
)
self.is_function = is_function
self.proto_type = proto_type
self.opsets = opsets
self.opset = opset
self.nodes: List[Union[NodeProto, TensorProto]] = []
Expand All @@ -59,6 +66,10 @@ def __init__(
self.unique_names_: Dict[str, Any] = {}
self.renames_: Dict[str, str] = {}

@property
def is_function(self) -> bool:
return self.proto_type == ProtoType.FUNCTION

def __repr__(self) -> str:
"usual"
sts = [f"{self.__class__.__name__}("]
Expand Down Expand Up @@ -233,6 +244,19 @@ def make_node(
self.nodes.append(node)
return node

def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":
"""
Adds an initializer

:param value: constant tensor
:param name: input name
:return: instance of :class:`onnx_array_api.light_api.Var`
"""
from .var import Var

c = self.make_constant(value, name=name)
return Var(self, c.name, elem_type=c.data_type, shape=tuple(c.dims))

def true_name(self, name: str) -> str:
"""
Some names were renamed. If name is one of them, the function
Expand Down Expand Up @@ -363,6 +387,9 @@ def to_onnx(self) -> GRAPH_PROTO:
if self.opsets:
for k, v in self.opsets.items():
opsets.append(make_opsetid(k, v))
if self.proto_type == ProtoType.GRAPH:
# If no opsets, it a subgraph, not a model.
return graph
model = make_model(graph, opset_imports=opsets)
check_model(model)
return model