Skip to content
Merged
Prev Previous commit
Next Next commit
Add class to yield results
  • Loading branch information
xadupre committed Feb 5, 2024
commit d520cfa8f996aaf5c7224efd00dd0cd76e9775f6
10 changes: 10 additions & 0 deletions _doc/api/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,13 @@ ExtendedReferenceEvaluator
++++++++++++++++++++++++++

.. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator
:members:

YieldEvaluator
++++++++++++++

.. autoclass:: onnx_array_api.reference.ResultType
:members:

.. autoclass:: onnx_array_api.reference.YieldEvaluator
:members:
77 changes: 77 additions & 0 deletions _unittests/ut_reference/test_evaluator_yield.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import unittest
import numpy as np
from onnx import TensorProto
from onnx.helper import (
make_function,
make_graph,
make_model,
make_node,
make_opsetid,
make_tensor_value_info,
)
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.reference import YieldEvaluator, ResultType


class TestArrayTensor(ExtTestCase):
    """Checks :class:`YieldEvaluator` enumerates every intermediate result."""

    def test_evaluator_yield(self):
        new_domain = "custom_domain"
        opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)]

        # A small function computing X @ A + B.
        linear_regression = make_function(
            new_domain,
            "LinearRegression",
            ["X", "A", "B"],
            ["Y"],
            [
                make_node("MatMul", ["X", "A"], ["XA"]),
                make_node("Add", ["XA", "B"], ["Y"]),
            ],
            opset_imports,
            [],
        )

        def _float_info(name, shape):
            return make_tensor_value_info(name, TensorProto.FLOAT, shape)

        # Main graph: Y = Abs(LinearRegression(X, A, B)).
        graph = make_graph(
            [
                make_node(
                    "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
                ),
                make_node("Abs", ["Y1"], ["Y"]),
            ],
            "example",
            [
                _float_info("X", [None, None]),
                _float_info("A", [None, None]),
                _float_info("B", [None, None]),
            ],
            [_float_info("Y", None)],
        )

        onnx_model = make_model(
            graph, opset_imports=opset_imports, functions=[linear_regression]
        )

        cst = np.arange(4).reshape((-1, 2)).astype(np.float32)
        yield_eval = YieldEvaluator(onnx_model)
        results = list(
            yield_eval.enumerate_results(None, {"A": cst, "B": cst, "X": cst})
        )

        # cst @ cst + cst == [[2, 4], [8, 14]]; Abs leaves it unchanged.
        out = np.array([[2.0, 4.0], [8.0, 14.0]], dtype=np.float32)
        expected = [
            (ResultType.INPUT, "A", cst),
            (ResultType.INPUT, "B", cst),
            (ResultType.INPUT, "X", cst),
            (ResultType.RESULT, "Y1", out),
            (ResultType.RESULT, "Y", out),
            (ResultType.OUTPUT, "Y", out),
        ]
        self.assertEqual(len(expected), len(results))
        for exp, got in zip(expected, results):
            self.assertEqual(len(exp), len(got))
            self.assertEqual(exp[0], got[0])
            self.assertEqual(exp[1], got[1])
            self.assertEqual(exp[2].tolist(), got[2].tolist())


# Allow running this test file directly: `python test_evaluator_yield.py`.
if __name__ == "__main__":
    unittest.main(verbosity=2)
1 change: 1 addition & 0 deletions onnx_array_api/reference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from onnx.reference.op_run import to_array_extended
from .evaluator import ExtendedReferenceEvaluator
from .evaluator_yield import YieldEvaluator, ResultType


def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorProto:
Expand Down
98 changes: 98 additions & 0 deletions onnx_array_api/reference/evaluator_yield.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import Any, Dict, List, Iterator, Optional, Tuple
from enum import IntEnum
from onnx import ModelProto
from .evaluator import ExtendedReferenceEvaluator


class ResultType(IntEnum):
    """Kind of a value yielded by :class:`YieldEvaluator`.

    The members are distinct powers of two so that kinds can be combined
    as a bitmask if needed.
    """

    RESULT = 1
    INITIALIZER = 2
    SPARSE_INITIALIZER = 4
    INPUT = 8
    OUTPUT = 16

    def __repr__(self):
        # Render as `ResultType.NAME` instead of IntEnum's default
        # `<ResultType.NAME: value>`.
        return f"{type(self).__name__}.{self.name}"


class YieldEvaluator:
    """
    This class implements method `enumerate_results` which iterates on
    intermediate results. By default, it uses
    :class:`onnx_array_api.reference.ExtendedReferenceEvaluator`.

    :param onnx_model: model to run
    :param recursive: dig into subgraphs and functions as well
        (not implemented yet, must be False)
    :param cls: evaluator class to instantiate on the model,
        ``None`` to create no evaluator
    """

    def __init__(
        self,
        onnx_model: ModelProto,
        recursive: bool = False,
        cls=ExtendedReferenceEvaluator,
    ):
        # `assert` would be stripped under `python -O`; raise explicitly.
        if recursive:
            raise NotImplementedError("recursive=True is not yet implemented")
        self.onnx_model = onnx_model
        # The evaluator must expose `rt_inits_`, `rt_nodes_` and
        # `output_names`, which `enumerate_results` relies on below.
        self.evaluator = cls(onnx_model) if cls is not None else None

    def enumerate_results(
        self,
        output_names: Optional[List[str]] = None,
        feed_inputs: Optional[Dict[str, Any]] = None,
    ) -> Iterator[Tuple[ResultType, str, Any]]:
        """
        Executes the onnx model.

        Args:
            output_names: requested outputs by names, None for all
            feed_inputs: dictionary `{ input name: input value }`,
                None for no inputs

        Returns:
            iterator on tuple(result kind, name, value), yielding
            initializers first, then inputs, then every node result in
            execution order, and finally the requested outputs

        Raises:
            RuntimeError: if the evaluator is not an
                ExtendedReferenceEvaluator, if a node input is unknown,
                or if a requested output was never computed
        """
        if not isinstance(self.evaluator, ExtendedReferenceEvaluator):
            raise RuntimeError(
                f"This implementation only works with "
                f"ExtendedReferenceEvaluator not {type(self.evaluator)}"
            )
        if feed_inputs is None:
            # Tolerate a missing feed: a model may have no graph input.
            feed_inputs = {}
        attributes = {}
        if output_names is None:
            output_names = self.evaluator.output_names

        # The empty name denotes a missing optional input in ONNX.
        results = {"": None}
        results.update(self.evaluator.rt_inits_)
        results.update(feed_inputs)
        # step 0: initializers
        for k, v in self.evaluator.rt_inits_.items():
            yield ResultType.INITIALIZER, k, v
        # step 1: inputs
        for k, v in feed_inputs.items():
            yield ResultType.INPUT, k, v

        # step 2: execute nodes in topological order
        for node in self.evaluator.rt_nodes_:
            for i in node.input:
                if i not in results:
                    raise RuntimeError(
                        f"Unable to find input {i!r} in known results {sorted(results)}, "
                        f"self.rt_inits_ has {sorted(self.evaluator.rt_inits_)}, "
                        f"feed_inputs has {sorted(feed_inputs)}."
                    )
            inputs = [results[i] for i in node.input]
            linked_attributes = {}
            # Propagate attributes referenced by function bodies
            # (attributes is currently always empty at this level).
            if node.has_linked_attribute and attributes:
                linked_attributes["linked_attributes"] = attributes
            if node.need_context():
                # Some operators (Loop, Scan, ...) need access to every
                # known result, not only their declared inputs.
                outputs = node.run(*inputs, context=results, **linked_attributes)
            else:
                outputs = node.run(*inputs, **linked_attributes)
            for name, value in zip(node.output, outputs):
                yield ResultType.RESULT, name, value
                results[name] = value

        # step 3: outputs
        for name in output_names:
            if name not in results:
                # fix: the original referenced self.proto_, an attribute
                # this class never defines (that would raise AttributeError
                # instead of the intended RuntimeError).
                raise RuntimeError(
                    f"Unable to find output name {name!r} in {sorted(results)}, "
                    f"proto is\n{self.onnx_model}"
                )
            yield ResultType.OUTPUT, name, results[name]