-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluator.py
More file actions
123 lines (109 loc) · 3.99 KB
/
evaluator.py
File metadata and controls
123 lines (109 loc) · 3.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from logging import getLogger
from typing import Any, Dict, List, Optional, Union
from onnx import FunctionProto, ModelProto
from onnx.defs import get_schema
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
from .ops.op_cast_like import CastLike_15, CastLike_19
from .ops.op_concat import Concat
from .ops.op_constant_of_shape import ConstantOfShape
from .ops.op_fused_matmul import FusedMatMul
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
from .ops.op_quick_gelu import QuickGelu
from .ops.op_scatter_elements import ScatterElements
# Module-level logger: ``ExtendedReferenceEvaluator._log`` sends its messages
# here (at DEBUG level) whenever they are not printed to stdout.
logger = getLogger("onnx-array-api-eval")
class ExtendedReferenceEvaluator(ReferenceEvaluator):
    """
    Reference evaluator which substitutes custom kernels for some of the
    python implementations shipped with onnx.

    The Array API extends many operators to types the onnx specifications
    do not support. This evaluator makes it possible to test scenarios
    that a backend bound to the official onnx operator definitions could
    not run.

    The kernels listed in :attr:`default_ops` are always registered, after
    any kernel given by the caller through ``new_ops``. Custom kernels are
    passed to the base evaluator the same way:

    ::

        from onnx.reference import ReferenceEvaluator
        from onnx.reference.c_ops import Conv
        ref = ReferenceEvaluator(..., new_ops=[Conv])
    """

    # Kernels registered for every instance. A name ending in ``_<int>``
    # denotes one version of an operator; see ``filter_ops``.
    default_ops = [
        Concat,
        CastLike_15,
        CastLike_19,
        ConstantOfShape,
        FusedMatMul,
        MemcpyFromHost,
        MemcpyToHost,
        QuickGelu,
        ScatterElements,
    ]

    @staticmethod
    def filter_ops(proto, new_ops, opsets):
        """
        Resolves versioned kernel classes against the requested opsets.

        A class whose name ends with ``_<int>`` (``CastLike_15`` for
        example) is considered as one version of the operator named by the
        part before the last underscore. For every pair
        ``(domain, operator)``, the highest version which does not exceed
        the opset of its domain is selected and exposed through a new
        subclass named after the operator alone.

        :param proto: model being evaluated; if *opsets* is None and
            *proto* is a ``ModelProto`` or a ``FunctionProto``, the opsets
            are read from ``proto.opset_import``
        :param new_ops: kernel classes to filter
        :param opsets: mapping ``domain -> version`` or None; a domain
            missing from the mapping is treated as version 1, and no
            version filtering happens if no opset can be determined
        :return: list of classes: first the classes which were not
            selected as a versioned candidate (unversioned names and
            versions newer than the opset), in their original order, then
            one generated class per selected operator
        """
        if opsets is None and isinstance(proto, (ModelProto, FunctionProto)):
            opsets = {d.domain: d.version for d in proto.opset_import}

        latest = {}  # (domain, operator name) -> (version, class)
        versioned_names = set()
        for op_class in new_ops:
            stem, sep, suffix = op_class.__name__.rpartition("_")
            if not sep:
                # no underscore, the class is not versioned
                continue
            try:
                version = int(suffix)
            except ValueError:
                # the suffix is not a version
                continue
            if opsets is not None and version > opsets.get(op_class.op_domain, 1):
                # too recent for the requested opset
                continue
            versioned_names.add(op_class.__name__)
            key = op_class.op_domain, stem
            current = latest.get(key)
            if current is None or current[0] < version:
                latest[key] = (version, op_class)

        # Classes which did not take part in the version selection are
        # returned unchanged.
        kept = [
            op_class
            for op_class in new_ops
            if op_class.__name__ not in versioned_names
        ]

        # One subclass per operator, named without the version suffix.
        for (domain, op_name), (version, op_class) in latest.items():
            attributes = {"domain": domain}
            if not hasattr(op_class, "op_schema"):
                attributes["op_schema"] = get_schema(
                    op_name, version, domain=op_class.op_domain
                )
            kept.append(type(op_name, (op_class,), attributes))
        return kept

    def __init__(
        self,
        proto: Any,
        opsets: Optional[Dict[str, int]] = None,
        functions: Optional[List[Union[ReferenceEvaluator, FunctionProto]]] = None,
        verbose: int = 0,
        new_ops: Optional[List[OpRun]] = None,
        **kwargs,
    ):
        """
        Builds the evaluator and registers the default kernels.

        :param proto: forwarded to the base class constructor
        :param opsets: forwarded to the base class constructor, also used
            by :meth:`filter_ops` to pick a version of every kernel
        :param functions: forwarded to the base class constructor
        :param verbose: forwarded to the base class constructor
        :param new_ops: additional kernels, placed before
            :attr:`default_ops`; the list is copied, not modified
        :param kwargs: forwarded to the base class constructor
        """
        defaults = ExtendedReferenceEvaluator.default_ops
        if new_ops is not None:
            candidates = new_ops.copy()
            candidates.extend(defaults)
        else:
            candidates = defaults
        selected = ExtendedReferenceEvaluator.filter_ops(proto, candidates, opsets)
        ReferenceEvaluator.__init__(
            self,
            proto,
            opsets=opsets,
            functions=functions,
            verbose=verbose,
            new_ops=selected,
            **kwargs,
        )

    def _log(self, level: int, pattern: str, *args: List[Any]) -> None:
        """
        Prints the message if *level* is strictly below ``self.verbose``,
        otherwise sends it to the module logger at DEBUG level.

        :param level: verbosity level of the message
        :param pattern: ``%``-style format string
        :param args: values for *pattern*; they go through
            ``self._log_arg`` before being printed and are given as is to
            the logger
        """
        if level < self.verbose:
            rendered = tuple(self._log_arg(a) for a in args)
            print(pattern % rendered)
            return
        logger.debug(pattern, *args)

    def run(self, *args, **kwargs):
        """
        See :meth:`onnx.reference.ReferenceEvaluator.run`.

        As a shortcut, a single positional argument which is a list is
        zipped with ``self.input_names`` to build the feeds and the call
        becomes ``run(None, feeds, **kwargs)``.
        """
        single_list = len(args) == 1 and isinstance(args[0], list)
        if not single_list:
            return ReferenceEvaluator.run(self, *args, **kwargs)
        feeds = dict(zip(self.input_names, args[0]))
        return self.run(None, feeds, **kwargs)