Skip to content

Commit d229036

Browse files
authored
Improves docassert with C++ functions (#18)
* Improves docassrt with C++ functions * catch more exceptions * fix one last issue
1 parent e04798d commit d229036

File tree

3 files changed

+150
-24
lines changed

3 files changed

+150
-24
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import unittest
2+
from sphinx_runpython.ext_test_case import ExtTestCase
3+
from sphinx_runpython.import_object_helper import import_object
4+
from sphinx_runpython.docassert.sphinx_docassert_extension import parse_signature
5+
6+
7+
class TestDocAssertCpp(ExtTestCase):
8+
def test_import_object1(self):
9+
name = "onnx_extended.ortcy.wrap.ortinf.OrtSession"
10+
try:
11+
obj, new_name = import_object(name, kind="class")
12+
except (ImportError, RuntimeError):
13+
return
14+
self.assertEqual(name.split(".")[-1], new_name)
15+
self.assertEqual(obj.__text_signature__, "($self, /, *args, **kwargs)")
16+
17+
def test_import_object2(self):
18+
name = "onnx_extended.validation.cpu._validation.benchmark_cache"
19+
try:
20+
obj, new_name = import_object(name, kind="function")
21+
except (ImportError, RuntimeError):
22+
return
23+
self.assertEqual(name.split(".")[-1], new_name)
24+
self.assertEmpty(obj.__text_signature__)
25+
sig = parse_signature(obj.__doc__)
26+
self.assertEqual(
27+
repr(sig), "benchmark_cache(size: int, verbose: bool = True) -> float"
28+
)
29+
self.assertIn("size", sig.param_names)
30+
31+
def test_import_object3(self):
32+
name = "onnx_extended.validation.cython.vector_function_cy.vector_add_c"
33+
try:
34+
obj, new_name = import_object(name, kind="function")
35+
except (ImportError, RuntimeError):
36+
return
37+
self.assertEqual(name.split(".")[-1], new_name)
38+
self.assertIn("vector_add_c(v1, v2)", obj.__doc__)
39+
sig = parse_signature(obj.__doc__)
40+
self.assertEqual(repr(sig), "vector_add_c(v1, v2)")
41+
self.assertIn("v1", sig.param_names)
42+
43+
def test_extract_signature(self):
44+
sig = (
45+
"benchmark_cache(size: int, verbose: bool = True) -> float\n\n "
46+
"Runs a benchmark to measure the cache performance.\nThe function "
47+
"measures the time for N random accesses in array of size N\nand "
48+
"returns the time divided by N.\nIt copies random elements taken "
49+
"from the array size to random\nposition in another of the same size. "
50+
"It does that *size* times\nand return the average time per move."
51+
"\nSee example :ref:`l-example-bench-cpu`.\n\n"
52+
":param size: array size\n:return: average time per move\n\n'"
53+
)
54+
res = parse_signature(sig)
55+
self.assertEqual(
56+
repr(res), "benchmark_cache(size: int, verbose: bool = True) -> float"
57+
)
58+
59+
60+
if __name__ == "__main__":
61+
unittest.main(verbosity=2)

sphinx_runpython/docassert/sphinx_docassert_extension.py

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sphinx
55
from sphinx.util import logging
66
from sphinx.util.docfields import DocFieldTransformer, _is_single_paragraph
7-
from ..import_object_helper import import_any_object
7+
from ..import_object_helper import import_any_object, import_object
88

99

1010
class Parameter:
@@ -15,6 +15,8 @@ def __init__(self, name: str, dtype: type):
1515
self.dtype = dtype
1616

1717
def __repr__(self):
18+
if self.dtype is None:
19+
return self.name
1820
return f"{self.name}: {self.dtype}"
1921

2022

@@ -35,23 +37,33 @@ def __repr__(self):
3537
for p in self.params:
3638
ps.append(repr(p))
3739
els.append(", ".join(ps))
38-
els.extend([")", " -> ", self.result_type])
40+
if self.result_type is None:
41+
els.append(")")
42+
else:
43+
els.extend([")", " -> ", self.result_type])
3944
return "".join(els)
4045

46+
@property
47+
def param_names(self):
48+
return set(el.name for el in self.params)
49+
4150

4251
def parse_signature(text: str) -> Signature:
43-
reg = re.compile("([_a-zA-Z][_a-zA-Z0-9]*?)[(](.*?)[)] -> ([a-zA-Z0-9]+)")
52+
reg = re.compile("([_a-zA-Z][_a-zA-Z0-9]*?)[(](.*?)[)]( -> ([a-zA-Z0-9]+))?")
4453
res = reg.search(text)
45-
name, params, result = res.groups(0)
54+
if res is None:
55+
return None
56+
name, params, _, result = res.groups()
4657
spl = [_.strip() for _ in params.split(",")]
47-
sig = Signature(name.strip(), result.strip())
58+
sig = Signature(name.strip(), result.strip() if result is not None else None)
4859
for p in spl:
49-
k, v = p.split(":", maxsplit=1)
50-
sig.append(Parameter(k.strip(), v.strip()))
60+
if ":" in p:
61+
k, v = p.split(":", maxsplit=1)
62+
sig.append(Parameter(k.strip(), v.strip()))
63+
else:
64+
sig.append(Parameter(p.strip(), None))
5165
return sig
5266

53-
print(params, result)
54-
5567

5668
def check_typed_make_field(
5769
self,
@@ -110,13 +122,25 @@ def check_item(fieldarg, content, logger):
110122
"local function"
111123
if fieldarg not in check_params:
112124
if function_name is not None:
113-
logger.warning(
114-
"[docassert] %r has no parameter %r (in %r)%s.",
115-
function_name,
116-
fieldarg,
117-
docname,
118-
" (no detected signature) " if parameters is None else "",
125+
idocname = (
126+
docname.replace(".PyCapsule.", ".")
127+
if ".PyCapsule." in docname
128+
else docname
119129
)
130+
if kind is None:
131+
obj = import_any_object(idocname)
132+
else:
133+
obj = import_object(idocname, kind=kind)
134+
tsig = getattr(obj[0], "__text_signature__")
135+
if tsig != "($self, /, *args, **kwargs)":
136+
logger.warning(
137+
"[docassert] %r has no parameter %r (in %r) [sig=%r]%s.",
138+
function_name,
139+
fieldarg,
140+
docname,
141+
tsig,
142+
" (no detected signature) " if parameters is None else "",
143+
)
120144
else:
121145
check_params[fieldarg] += 1
122146
if check_params[fieldarg] > 1:
@@ -139,12 +163,42 @@ def check_item(fieldarg, content, logger):
139163
# Behavior should be improved.
140164
pass
141165
else:
142-
logger.warning(
143-
"[docassert] %r has undocumented parameters %r (in %r).",
144-
function_name,
145-
", ".join(nodoc),
146-
docname,
166+
idocname = (
167+
docname.replace(".PyCapsule.", ".")
168+
if ".PyCapsule." in docname
169+
else docname
147170
)
171+
if kind is None:
172+
obj = import_any_object(idocname)
173+
else:
174+
obj = import_object(idocname, kind=kind)
175+
tsig = getattr(obj[0], "__text_signature__", None)
176+
if tsig != "($self, /, *args, **kwargs)":
177+
if tsig is None:
178+
alt_sig = parse_signature(obj[0].__doc__)
179+
if alt_sig is None:
180+
nodoc2 = nodoc
181+
else:
182+
ps = alt_sig.param_names
183+
nodoc2 = [n for n in nodoc if n not in ps]
184+
if len(nodoc2) > 0:
185+
logger.warning(
186+
"[docassert] %r has undocumented parameters (1) "
187+
"[%s] (in %r) [sig=%r].",
188+
function_name,
189+
", ".join(nodoc2),
190+
docname,
191+
tsig,
192+
)
193+
else:
194+
logger.warning(
195+
"[docassert] %r has undocumented parameters (2) "
196+
"[%s] (in %r) [sig=%r].",
197+
function_name,
198+
", ".join(nodoc),
199+
docname,
200+
tsig,
201+
)
148202
else:
149203
# Documentation related to the return.
150204
pass
@@ -278,7 +332,6 @@ def override_transform(self, other_self, node):
278332
docs,
279333
reasons,
280334
)
281-
myfunc = None
282335

283336
if myfunc is None:
284337
signature = None
@@ -290,9 +343,18 @@ def override_transform(self, other_self, node):
290343
except (TypeError, ValueError):
291344
# built-in function
292345
logger = logging.getLogger("docassert")
293-
logger.warning("[docassert] unable to get signature of %r", docs)
294-
signature = None
295-
parameters = None
346+
if myfunc.__text_signature__:
347+
logger.warning(
348+
"[docassert] unable to get signature (1) of %r: %s",
349+
docs,
350+
myfunc.__text_signature__,
351+
)
352+
signature = None
353+
parameters = None
354+
else:
355+
alt_sig = parse_signature(myfunc.__doc__)
356+
signature = alt_sig
357+
parameters = alt_sig.params
296358

297359
# grouped entries need to be collected in one entry, while others
298360
# get one entry per field

sphinx_runpython/import_object_helper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def import_object(docname, kind, use_init=True) -> Tuple[object, str]:
6262
not inspect.isfunction(myfunc)
6363
and "built-in function" not in str(myfunc)
6464
and "built-in method" not in str(myfunc)
65+
and (
66+
not hasattr(myfunc, "func_code") or ".pyx" not in str(myfunc.func_code)
67+
)
6568
):
6669
# inspect.isfunction fails for C functions.
6770
raise TypeError(f"'{docname}' is not a function")

0 commit comments

Comments
 (0)