Skip to content

Commit 62e7caa

Browse files
committed
style, unit test
1 parent 5af717a commit 62e7caa

23 files changed

+658
-767
lines changed

_doc/conf.py

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,41 @@
1-
#!/usr/bin/env python3
2-
# -*- coding: utf-8 -*-
3-
41
import os
52
import sys
6-
sys.path.insert(
7-
0,
8-
os.path.abspath(
9-
os.path.join(
10-
os.path.dirname(__file__),
11-
'..')))
12-
from td3a_cpp_deep import __version__ # noqa
3+
4+
from onnx_array_api import __version__
135

146
extensions = [
15-
'sphinx.ext.autodoc',
16-
'sphinx.ext.intersphinx',
17-
'sphinx.ext.todo',
18-
'sphinx.ext.coverage',
19-
'sphinx.ext.mathjax',
20-
'sphinx.ext.ifconfig',
21-
'sphinx.ext.viewcode',
22-
'sphinx.ext.githubpages',
23-
'sphinx_gallery.gen_gallery',
24-
'alabaster',
25-
'matplotlib.sphinxext.plot_directive',
26-
'pyquickhelper.sphinxext.sphinx_runpython_extension',
27-
'pyquickhelper.sphinxext.sphinx_epkg_extension',
7+
"sphinx.ext.autodoc",
8+
"sphinx.ext.intersphinx",
9+
"sphinx.ext.todo",
10+
"sphinx.ext.coverage",
11+
"sphinx.ext.mathjax",
12+
"sphinx.ext.ifconfig",
13+
"sphinx.ext.viewcode",
14+
"sphinx.ext.githubpages",
15+
"sphinx_gallery.gen_gallery",
16+
"matplotlib.sphinxext.plot_directive",
17+
"pyquickhelper.sphinxext.sphinx_runpython_extension",
18+
"pyquickhelper.sphinxext.sphinx_epkg_extension",
2819
]
2920

30-
templates_path = ['_templates']
31-
html_logo = '_static/logo.png'
32-
source_suffix = '.rst'
33-
master_doc = 'index'
34-
project = 'onnx-array-api'
35-
copyright = '2023, Xavier Dupré'
36-
author = 'Xavier Dupré'
21+
templates_path = ["_templates"]
22+
html_logo = "_static/logo.png"
23+
source_suffix = ".rst"
24+
master_doc = "index"
25+
project = "onnx-array-api"
26+
copyright = "2023, Xavier Dupré"
27+
author = "Xavier Dupré"
3728
version = __version__
3829
release = __version__
39-
language = 'en'
30+
language = "en"
4031
exclude_patterns = []
41-
pygments_style = 'sphinx'
32+
pygments_style = "sphinx"
4233
todo_include_todos = True
4334

4435
html_theme = "furo"
45-
html_theme_path = [alabaster.get_path()]
46-
36+
html_theme_path = ["_static"]
4737
html_theme_options = {}
48-
html_static_path = ['_static']
38+
html_static_path = ["_static"]
4939

5040

5141
intersphinx_mapping = {
@@ -60,20 +50,20 @@
6050

6151
sphinx_gallery_conf = {
6252
# path to your examples scripts
63-
'examples_dirs': os.path.join(os.path.dirname(__file__), 'examples'),
53+
"examples_dirs": os.path.join(os.path.dirname(__file__), "examples"),
6454
# path where to save gallery generated examples
65-
'gallery_dirs': 'auto_examples'
55+
"gallery_dirs": "auto_examples",
6656
}
6757

6858
epkg_dictionary = {
69-
'JIT': 'https://en.wikipedia.org/wiki/Just-in-time_compilation',
70-
'onnx': 'https://onnx.ai/onnx/',
71-
'ONNX': 'https://onnx.ai/',
72-
'onnxruntime': 'https://onnxruntime.ai/',
73-
'numpy': 'https://numpy.org/',
74-
'numba': 'https://numba.pydata.org/',
75-
'python': 'https://www.python.org/',
76-
'scikit-learn': 'https://scikit-learn.org/stable/',
77-
'sphinx-gallery': 'https://github.com/sphinx-gallery/sphinx-gallery',
78-
'torch': 'https://pytorch.org/docs/stable/torch.html',
59+
"JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation",
60+
"onnx": "https://onnx.ai/onnx/",
61+
"ONNX": "https://onnx.ai/",
62+
"onnxruntime": "https://onnxruntime.ai/",
63+
"numpy": "https://numpy.org/",
64+
"numba": "https://numba.pydata.org/",
65+
"python": "https://www.python.org/",
66+
"scikit-learn": "https://scikit-learn.org/stable/",
67+
"sphinx-gallery": "https://github.com/sphinx-gallery/sphinx-gallery",
68+
"torch": "https://pytorch.org/docs/stable/torch.html",
7969
}
Lines changed: 4 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,15 @@
11
"""
22
3-
.. _l-example-dot-profile:
3+
.. _l-onnx-array-api-example:
44
55
Compares implementations for a Piecewise Linear
66
===============================================
77
8-
A pieceise linear function is implemented and trained
9-
following the tutorial :epkg:`Custom C++ and Cuda Extensions`.
8+
First example.
109
1110
.. contents::
1211
:local:
1312
14-
Piecewise linear regression
15-
+++++++++++++++++++++++++++
13+
One function
14+
++++++++++++
1615
"""
17-
import time
18-
import pandas
19-
import matplotlib.pyplot as plt
20-
import torch
21-
from td3a_cpp_deep.fcts.piecewise_linear import (
22-
PiecewiseLinearFunction,
23-
PiecewiseLinearFunctionC,
24-
PiecewiseLinearFunctionCBetter)
25-
26-
27-
def train_piecewise_linear(x, y, device, cls,
28-
max_iter=400, learning_rate=1e-4):
29-
30-
alpha_pos = torch.tensor([1], dtype=torch.float32).to(device)
31-
alpha_neg = torch.tensor([0.5], dtype=torch.float32).to(device)
32-
alpha_pos.requires_grad_()
33-
alpha_neg.requires_grad_()
34-
35-
losses = []
36-
fct = cls.apply
37-
38-
for t in range(max_iter):
39-
40-
y_pred = fct(x, alpha_neg, alpha_pos)
41-
loss = (y_pred - y).pow(2).sum()
42-
loss.backward()
43-
losses.append(loss)
44-
45-
with torch.no_grad():
46-
alpha_pos -= learning_rate * alpha_pos.grad
47-
alpha_neg -= learning_rate * alpha_neg.grad
48-
49-
# Manually zero the gradients after updating weights
50-
alpha_pos.grad.zero_()
51-
alpha_neg.grad.zero_()
52-
53-
return losses, alpha_neg, alpha_pos
54-
55-
56-
################################
57-
# Python implementation
58-
# +++++++++++++++++++++
59-
60-
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
61-
print("device:", device)
62-
x = torch.randn(100, 1, dtype=torch.float32)
63-
y = x * 0.2 + (x > 0).to(torch.float32) * x * 1.5 + torch.randn(100, 1) / 5
64-
x = x.to(device).requires_grad_()
65-
y = y.to(device).requires_grad_()
66-
67-
begin = time.perf_counter()
68-
losses, alpha_neg, alpha_pos = train_piecewise_linear(
69-
x, y, device, PiecewiseLinearFunction)
70-
end = time.perf_counter()
71-
print(f"duration={end - begin}, alpha_neg={alpha_neg} "
72-
f"alpha_pos={alpha_pos}")
73-
74-
################################
75-
# C++ implementation
76-
# ++++++++++++++++++
77-
78-
begin = time.perf_counter()
79-
losses, alpha_neg, alpha_pos = train_piecewise_linear(
80-
x, y, device, PiecewiseLinearFunctionC)
81-
end = time.perf_counter()
82-
print(f"duration={end - begin}, alpha_neg={alpha_neg} "
83-
f"alpha_pos={alpha_pos}")
84-
85-
################################
86-
# C++ implementation, second try
87-
# ++++++++++++++++++++++++++++++
88-
89-
begin = time.perf_counter()
90-
losses, alpha_neg, alpha_pos = train_piecewise_linear(
91-
x, y, device, PiecewiseLinearFunctionCBetter)
92-
end = time.perf_counter()
93-
print(f"duration={end - begin}, alpha_neg={alpha_neg} "
94-
f"alpha_pos={alpha_pos}")
95-
96-
#################################
97-
# The C++ implementation is very close to the python code.
98-
# The second implementation in C++ is faster because
99-
# it reuses created tensors.
100-
101-
##################################
102-
# Graphs
103-
# ++++++
104-
105-
df = pandas.DataFrame()
106-
df['x'] = x.cpu().detach().numpy().ravel()
107-
df['y'] = y.cpu().detach().numpy().ravel()
108-
df['yp'] = PiecewiseLinearFunction.apply(
109-
x, alpha_neg, alpha_pos).cpu().detach().numpy()
110-
111-
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
112-
df.plot.scatter(x="x", y='y', label="y", color="blue", ax=ax[0])
113-
df.plot.scatter(x="x", y='yp', ax=ax[0], label="yp", color="orange")
114-
ax[1].plot([float(lo.detach()) for lo in losses], label="loss")
115-
ax[1].legend()
116-
117-
118-
# plt.show()

_doc/tutorial/overview.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ is available with class `OrtTensor
106106
from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice
107107
from onnxruntime.capi._pybind_state import OrtMemType
108108
from onnxruntime.capi._pybind_state import (
109-
OrtValue as C_OrtValue, # pylint: disable=E0611
109+
OrtValue as C_OrtValue,
110110
)
111111
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument
112112

0 commit comments

Comments
 (0)