Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improve initializer
  • Loading branch information
xadupre committed Feb 21, 2024
commit 3b10ea8547428075035edc0e3a4c76b588b6e2ac
2 changes: 1 addition & 1 deletion _unittests/ut_reference/test_evaluator_yield.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def test_distance_sequence_str(self):
002=|INPUTfloat322x2ABCDB|INPUTfloat322x2ABCDB
003~|INPUTfloat322x3ABCDX|INPUTfloat322x2ABCDX
004-|RESULTfloat322x2CEIOExpH|
005=|RESULTfloat322x2CEIOLinearRegresY1|RESULTfloat322x2CEIOLinearRegresY1
005=|RESULTfloat322x2CEIOLinearRegressioY1|RESULTfloat322x2CEIOLinearRegressioY1
006~|RESULTfloat322x2CEIOAbsY|RESULTfloat322x3CEIPAbsZ
007~|OUTPUTfloat322x2CEIOY|OUTPUTfloat322x2CEIPY
""".replace(
Expand Down
2 changes: 1 addition & 1 deletion onnx_array_api/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _cmd_compare(argv: List[Any]):
res1, res2, align, dc = compare_onnx_execution(
onx1, onx2, verbose=args.verbose, mode=args.mode
)
text = dc.to_str(res1, res2, align, column_size=args.column_size)
text = dc.to_str(res1, res2, align, column_size=int(args.column_size))
print(text)


Expand Down
12 changes: 9 additions & 3 deletions onnx_array_api/reference/evaluator_yield.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from onnx import ModelProto, TensorProto, ValueInfoProto, load
from onnx.helper import tensor_dtype_to_np_dtype
from onnx.shape_inference import infer_shapes
from . import to_array_extended
from .evaluator import ExtendedReferenceEvaluator


Expand Down Expand Up @@ -66,7 +67,7 @@ def __str__(self):
_align(str(dtype).replace("dtype(", "").replace(")", ""), 8),
_align("x".join("" if self.shape is None else map(str, self.shape)), 15),
self.summary,
_align(self.op_type or "", 12),
_align(self.op_type or "", 15),
self.name or "",
]
return " ".join(els)
Expand Down Expand Up @@ -496,7 +497,12 @@ def _enumerate_result_no_execution(model: ModelProto) -> Iterator[ResultType]:
itype, shape = type_shape.get(i.name, (0, None))
dtype = tensor_dtype_to_np_dtype(itype)
yield ResultExecution(
ResultType.INITIALIZER, dtype, shape, "????", "INIT", i.name
ResultType.INITIALIZER,
dtype,
shape,
make_summary(to_array_extended(i)),
"INIT",
i.name,
)
for i in model.graph.input:
itype, shape = type_shape.get(i.name, (0, None))
Expand All @@ -506,7 +512,7 @@ def _enumerate_result_no_execution(model: ModelProto) -> Iterator[ResultType]:
yield ResultExecution(ResultType.NODE, 0, None, "????", node.op_type, node.name)
for o in node.output:
itype, shape = type_shape.get(o, (0, None))
dtype = tensor_dtype_to_np_dtype(itype)
dtype = 0 if itype == 0 else tensor_dtype_to_np_dtype(itype)
yield ResultExecution(
ResultType.RESULT, dtype, shape, "????", node.op_type, o
)
Expand Down