Skip to content
Merged
Prev Previous commit
extend unit test copverage
  • Loading branch information
xadupre committed Feb 14, 2024
commit 5f37a5997f6ee6755733ee13f3c307d0ddb0e0c3
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
* :pr:`61`: adds function to plot onnx model as graphs
* :pr:`60`: supports translation of local functions
Expand Down
19 changes: 19 additions & 0 deletions _unittests/ut_reference/test_reference_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,25 @@ def test_fused_matmul11(self):
got = ref.run(None, {"X": a, "Y": a})
self.assertEqualArray(a.T @ a.T, got[0])

def test_memcpy(self):
model = make_model(
make_graph(
[
make_node("MemcpyToHost", ["X"], ["Z"]),
make_node("MemcpyFromHost", ["X"], ["Z"]),
],
"name",
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
ir_version=9,
)
a = np.arange(4).reshape(-1, 2).astype(np.float32)
ref = ExtendedReferenceEvaluator(model)
got = ref.run(None, {"X": a})
self.assertEqualArray(a, got[0])

def test_quick_gelu(self):
from onnxruntime import InferenceSession

Expand Down
5 changes: 4 additions & 1 deletion onnx_array_api/reference/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from .ops.op_concat import Concat
from .ops.op_constant_of_shape import ConstantOfShape
from .ops.op_fused_matmul import FusedMatMul
from .ops.op_scatter_elements import ScatterElements
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
from .ops.op_quick_gelu import QuickGelu
from .ops.op_scatter_elements import ScatterElements


logger = getLogger("onnx-array-api-eval")
Expand All @@ -36,6 +37,8 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
CastLike_19,
ConstantOfShape,
FusedMatMul,
MemcpyFromHost,
MemcpyToHost,
QuickGelu,
ScatterElements,
]
Expand Down
11 changes: 11 additions & 0 deletions onnx_array_api/reference/ops/op_memcpy_host.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from onnx.reference.op_run import OpRun


class MemcpyFromHost(OpRun):
def _run(self, x):
return (x,)


class MemcpyToHost(OpRun):
def _run(self, x):
return (x,)