Skip to content

Commit 0edee0f

Browse files
committed
add to_dot
1 parent 8fa1944 commit 0edee0f

File tree

11 files changed

+812
-226
lines changed

11 files changed

+812
-226
lines changed

_doc/api/npx.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,16 @@ Var
116116
+++
117117

118118
.. autoclass:: onnx_array_api.npx.npx_var.Var
119-
:class:
119+
:members:
120120

121121
Cst, Input
122122
++++++++++
123123

124124
.. autoclass:: onnx_array_api.npx.npx_var.Cst
125-
:class:
125+
:members:
126126

127127
.. autoclass:: onnx_array_api.npx.npx_var.Input
128-
:class:
128+
:members:
129129

130130
API
131131
+++
@@ -150,16 +150,16 @@ JIT, Eager
150150
++++++++++
151151

152152
.. autoclass:: onnx_array_api.npx.npx_jit_eager.JitEager
153-
:class:
153+
:members:
154154

155155
.. autoclass:: onnx_array_api.npx.npx_jit_eager.JitOnnx
156-
:class:
156+
:members:
157157

158158
Tensors
159159
+++++++
160160

161161
.. autoclass:: onnx_array_api.npx.npx_tensors.NumpyTensor
162-
:class:
162+
:members:
163163

164164
Annotations
165165
+++++++++++

_doc/api/plotting.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
plotting
44
========
55

6+
Dot
7+
+++
8+
9+
.. autofunction:: onnx_array_api.plotting.dot_plot.to_dot
10+
611
Text
712
++++
813

_doc/index.rst

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,3 @@ onnx-array-api: (Numpy) Array API for ONNX
3030
Sources available on
3131
`github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_.
3232

33-
.. toctree::
34-
35-
:maxdepth: 1
36-
37-
tutorial/index
38-
api/index

_unittests/ut__main/test_profiling.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ def simple():
2626
rootrem = os.path.normpath(
2727
os.path.abspath(os.path.join(os.path.dirname(rootfile), ".."))
2828
)
29-
ps, res = profile(simple, rootrem=rootrem) # pylint: disable=W0632
29+
ps, res = profile(simple, rootrem=rootrem)
3030
res = res.replace("\\", "/")
3131
self.assertIn("function calls", res)
3232
self.assertNotEmpty(ps)
3333

34-
ps, res = profile(simple) # pylint: disable=W0632
34+
ps, res = profile(simple)
3535
res = res.replace("\\", "/")
3636
self.assertIn("function calls", res)
3737
self.assertNotEmpty(ps)
@@ -53,7 +53,7 @@ def simple2():
5353
rootrem = os.path.normpath(
5454
os.path.abspath(os.path.join(os.path.dirname(rootfile), ".."))
5555
)
56-
ps, df = profile(simple, rootrem=rootrem, as_df=True) # pylint: disable=W0632
56+
ps, df = profile(simple, rootrem=rootrem, as_df=True)
5757
self.assertIsInstance(df, pandas.DataFrame)
5858
self.assertEqual(df.loc[0, "namefct"].split("-")[-1], "simple")
5959
self.assertNotEmpty(ps)
@@ -89,7 +89,7 @@ def f4():
8989
f2()
9090
f3()
9191

92-
ps = profile(f4)[0] # pylint: disable=W0632
92+
ps = profile(f4)[0]
9393
df = self.capture(lambda: profile2df(ps, verbose=True, fLOG=print))[0]
9494
dfi = df.set_index("fct")
9595
self.assertEqual(dfi.loc["f4", "ncalls1"], 1)
@@ -122,7 +122,7 @@ def f4():
122122
f2()
123123
f3()
124124

125-
ps = profile(f4)[0] # pylint: disable=W0632
125+
ps = profile(f4)[0]
126126
profile2df(ps, verbose=False, clean_text=lambda x: x.split("/")[-1])
127127
root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1])
128128
self.assertEqual(len(nodes), 6)
@@ -162,7 +162,7 @@ def f1(t):
162162
def f4():
163163
f1(0.3)
164164

165-
ps = profile(f4)[0] # pylint: disable=W0632
165+
ps = profile(f4)[0]
166166
profile2df(ps, verbose=False, clean_text=lambda x: x.split("/")[-1])
167167
root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1])
168168
self.assertEqual(len(nodes), 4)
@@ -181,7 +181,7 @@ def f0(t):
181181
def f4():
182182
f0(0.15)
183183

184-
ps = profile(f4)[0] # pylint: disable=W0632
184+
ps = profile(f4)[0]
185185
profile2df(ps, verbose=False, clean_text=lambda x: x.split("/")[-1])
186186
root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1])
187187
self.assertEqual(len(nodes), 3)
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
@brief test log(time=2s)
4+
"""
5+
import os
6+
import unittest
7+
8+
import numpy
9+
from onnx import TensorProto, load
10+
from onnx.helper import (
11+
make_function,
12+
make_graph,
13+
make_model,
14+
make_node,
15+
make_opsetid,
16+
make_tensor_value_info,
17+
)
18+
from skl2onnx import to_onnx
19+
from skl2onnx.algebra.onnx_ops import (
20+
OnnxAdd,
21+
OnnxGreater,
22+
OnnxIf,
23+
OnnxLeakyRelu,
24+
OnnxReduceSum,
25+
OnnxSub,
26+
)
27+
from skl2onnx.common.data_types import FloatTensorType
28+
from sklearn.cluster import KMeans
29+
from sklearn.datasets import load_iris
30+
from sklearn.neighbors import RadiusNeighborsRegressor
31+
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
32+
33+
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
34+
from onnx_array_api.plotting.dot_plot import to_dot
35+
36+
TARGET_OPSET = 18
37+
38+
39+
class TestDotPlot(ExtTestCase):
40+
def test_onnx_text_plot_tree_reg(self):
41+
iris = load_iris()
42+
X, y = iris.data.astype(numpy.float32), iris.target
43+
clr = DecisionTreeRegressor(max_depth=3)
44+
clr.fit(X, y)
45+
onx = to_onnx(clr, X)
46+
dot = to_dot(onx)
47+
self.assertIn("X -> TreeEnsembleRegressor;", dot)
48+
49+
def test_onnx_text_plot_tree_cls(self):
50+
iris = load_iris()
51+
X, y = iris.data.astype(numpy.float32), iris.target
52+
clr = DecisionTreeClassifier(max_depth=3)
53+
clr.fit(X, y)
54+
onx = to_onnx(clr, X)
55+
dot = to_dot(onx)
56+
self.assertIn("X -> TreeEnsembleClassifier;", dot)
57+
58+
@ignore_warnings((UserWarning, FutureWarning))
59+
def test_to_dot_kmeans(self):
60+
x = numpy.random.randn(10, 3)
61+
model = KMeans(3)
62+
model.fit(x)
63+
onx = to_onnx(model, x.astype(numpy.float32), target_opset=15)
64+
dot = to_dot(onx)
65+
self.assertIn("Sq_Sqrt -> scores;", dot)
66+
67+
def test_to_dot_knnr(self):
68+
x = numpy.random.randn(10, 3)
69+
y = numpy.random.randn(10)
70+
model = RadiusNeighborsRegressor(3)
71+
model.fit(x, y)
72+
onx = to_onnx(model, x.astype(numpy.float32), target_opset=15)
73+
dot = to_dot(onx)
74+
self.assertIn("Di_Div -> Di_C0;", dot)
75+
76+
def test_to_dot_leaky(self):
77+
x = OnnxLeakyRelu("X", alpha=0.5, op_version=15, output_names=["Y"])
78+
onx = x.to_onnx(
79+
{"X": FloatTensorType()}, outputs={"Y": FloatTensorType()}, target_opset=15
80+
)
81+
dot = to_dot(onx)
82+
self.assertIn("Le_LeakyRelu -> Y;", dot)
83+
84+
def test_to_dot_if(self):
85+
86+
opv = TARGET_OPSET
87+
x1 = numpy.array([[0, 3], [7, 0]], dtype=numpy.float32)
88+
x2 = numpy.array([[1, 0], [2, 0]], dtype=numpy.float32)
89+
90+
node = OnnxAdd("x1", "x2", output_names=["absxythen"], op_version=opv)
91+
then_body = node.to_onnx(
92+
{"x1": x1, "x2": x2},
93+
target_opset=opv,
94+
outputs=[("absxythen", FloatTensorType())],
95+
)
96+
node = OnnxSub("x1", "x2", output_names=["absxyelse"], op_version=opv)
97+
else_body = node.to_onnx(
98+
{"x1": x1, "x2": x2},
99+
target_opset=opv,
100+
outputs=[("absxyelse", FloatTensorType())],
101+
)
102+
del else_body.graph.input[:]
103+
del then_body.graph.input[:]
104+
105+
cond = OnnxGreater(
106+
OnnxReduceSum("x1", op_version=opv),
107+
OnnxReduceSum("x2", op_version=opv),
108+
op_version=opv,
109+
)
110+
ifnode = OnnxIf(
111+
cond,
112+
then_branch=then_body.graph,
113+
else_branch=else_body.graph,
114+
op_version=opv,
115+
output_names=["y"],
116+
)
117+
model_def = ifnode.to_onnx(
118+
{"x1": x1, "x2": x2}, target_opset=opv, outputs=[("y", FloatTensorType())]
119+
)
120+
dot = to_dot(model_def)
121+
self.assertIn("If_If -> y;", dot)
122+
123+
def test_to_dot_if_recursive(self):
124+
125+
opv = TARGET_OPSET
126+
x1 = numpy.array([[0, 3], [7, 0]], dtype=numpy.float32)
127+
x2 = numpy.array([[1, 0], [2, 0]], dtype=numpy.float32)
128+
129+
node = OnnxAdd("x1", "x2", output_names=["absxythen"], op_version=opv)
130+
then_body = node.to_onnx(
131+
{"x1": x1, "x2": x2},
132+
target_opset=opv,
133+
outputs=[("absxythen", FloatTensorType())],
134+
)
135+
node = OnnxSub("x1", "x2", output_names=["absxyelse"], op_version=opv)
136+
else_body = node.to_onnx(
137+
{"x1": x1, "x2": x2},
138+
target_opset=opv,
139+
outputs=[("absxyelse", FloatTensorType())],
140+
)
141+
del else_body.graph.input[:]
142+
del then_body.graph.input[:]
143+
144+
cond = OnnxGreater(
145+
OnnxReduceSum("x1", op_version=opv),
146+
OnnxReduceSum("x2", op_version=opv),
147+
op_version=opv,
148+
)
149+
ifnode = OnnxIf(
150+
cond,
151+
then_branch=then_body.graph,
152+
else_branch=else_body.graph,
153+
op_version=opv,
154+
output_names=["y"],
155+
)
156+
model_def = ifnode.to_onnx(
157+
{"x1": x1, "x2": x2}, target_opset=opv, outputs=[("y", FloatTensorType())]
158+
)
159+
dot = to_dot(model_def, recursive=True)
160+
self.assertIn("If_If -> y;", dot)
161+
162+
@ignore_warnings((UserWarning, FutureWarning))
163+
def test_to_dot_kmeans_links(self):
164+
x = numpy.random.randn(10, 3)
165+
model = KMeans(3)
166+
model.fit(x)
167+
onx = to_onnx(model, x.astype(numpy.float32), target_opset=15)
168+
dot = to_dot(onx, recursive=True)
169+
self.assertIn("Sq_Sqrt -> scores;", dot)
170+
171+
def test_function_plot(self):
172+
new_domain = "custom"
173+
opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)]
174+
175+
node1 = make_node("MatMul", ["X", "A"], ["XA"])
176+
node2 = make_node("Add", ["XA", "B"], ["Y"])
177+
178+
linear_regression = make_function(
179+
new_domain, # domain name
180+
"LinearRegression", # function name
181+
["X", "A", "B"], # input names
182+
["Y"], # output names
183+
[node1, node2], # nodes
184+
opset_imports, # opsets
185+
[],
186+
) # attribute names
187+
188+
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
189+
A = make_tensor_value_info("A", TensorProto.FLOAT, [None, None])
190+
B = make_tensor_value_info("B", TensorProto.FLOAT, [None, None])
191+
Y = make_tensor_value_info("Y", TensorProto.FLOAT, None)
192+
193+
graph = make_graph(
194+
[
195+
make_node(
196+
"LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
197+
),
198+
make_node("Abs", ["Y1"], ["Y"]),
199+
],
200+
"example",
201+
[X, A, B],
202+
[Y],
203+
)
204+
205+
onnx_model = make_model(
206+
graph, opset_imports=opset_imports, functions=[linear_regression]
207+
) # functions to add)
208+
dot = to_dot(onnx_model, add_functions=True, recursive=True)
209+
self.assertIn("LinearRegression -> Y1;", dot)
210+
211+
def test_onnx_text_plot_tree_simple(self):
212+
iris = load_iris()
213+
X, y = iris.data.astype(numpy.float32), iris.target
214+
clr = DecisionTreeRegressor(max_depth=3)
215+
clr.fit(X, y)
216+
onx = to_onnx(clr, X)
217+
dot = to_dot(onx)
218+
self.assertIn("TreeEnsembleRegressor -> variable;", dot)
219+
220+
def test_simple_text_plot_bug(self):
221+
data = os.path.join(os.path.dirname(__file__), "data")
222+
onx_file = os.path.join(data, "tree_torch.onnx")
223+
onx = load(onx_file)
224+
dot = to_dot(onx)
225+
self.assertIn("onnx____ReduceSum_140 [shape=box", dot)
226+
227+
def test_simple_text_plot_ref_attr_name(self):
228+
data = os.path.join(os.path.dirname(__file__), "data")
229+
onx_file = os.path.join(data, "bug_Hardmax.onnx")
230+
onx = load(onx_file)
231+
dot = to_dot(onx)
232+
self.assertIn("Hardmax -> y;", dot)
233+
234+
235+
if __name__ == "__main__":
236+
unittest.main(verbosity=2)

_unittests/ut_plotting/test_text_plotting.py renamed to _unittests/ut_plotting/test_text_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
TARGET_OPSET = 18
4444

4545

46-
class TestPlotTextPlotting(ExtTestCase):
46+
class TestTextPlot(ExtTestCase):
4747
def test_onnx_text_plot_tree_reg(self):
4848
iris = load_iris()
4949
X, y = iris.data.astype(numpy.float32), iris.target

0 commit comments

Comments
 (0)