Skip to content
Next Next commit
modify script
  • Loading branch information
Akuli committed Dec 16, 2021
commit 558faba999d271a9835713ef3907518da9cce60f
18 changes: 14 additions & 4 deletions tests/check_new_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def is_dotdotdot(node: ast.AST) -> bool:
def add_contextlib_alias_error(node: ast.ImportFrom | ast.Attribute, alias: str) -> None:
errors.append(f"{path}:{node.lineno}: Use `contextlib.{CONTEXT_MANAGER_ALIASES[alias]}` instead of `typing.{alias}`")

# TODO: Get rid of this class. It skips checking type aliases and base
# classes, but those can now be checked too, with new mypy version.
class OldSyntaxFinder(ast.NodeVisitor):
def __init__(self, *, set_from_collections_abc: bool) -> None:
self.set_from_collections_abc = set_from_collections_abc
Expand Down Expand Up @@ -67,10 +69,6 @@ def visit_Subscript(self, node: ast.Subscript) -> None:

self.generic_visit(node)

# This doesn't check type aliases (or type var bounds, etc), since those are not
# currently supported
#
# TODO: can use built-in generics in type aliases
class AnnotationFinder(ast.NodeVisitor):
def __init__(self) -> None:
self.set_from_collections_abc = False
Expand All @@ -79,6 +77,16 @@ def old_syntax_finder(self) -> OldSyntaxFinder:
"""Convenience method to create an `OldSyntaxFinder` instance with the correct state"""
return OldSyntaxFinder(set_from_collections_abc=self.set_from_collections_abc)

def visit_Subscript(self, node: ast.Subscript) -> None:
if isinstance(node.value, ast.Name):
if node.value.id == "Union" and isinstance(node.slice, ast.Tuple):
new_syntax = " | ".join(ast.unparse(x) for x in node.slice.elts)
errors.append(f"{path}:{node.lineno}: Use PEP 604 syntax for Union, e.g. `{new_syntax}`")
if node.value.id == "Optional":
new_syntax = f"{ast.unparse(node.slice)} | None"
errors.append(f"{path}:{node.lineno}: Use PEP 604 syntax for Optional, e.g. `{new_syntax}`")
self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module == "collections.abc":
imported_classes = node.names
Expand All @@ -102,10 +110,12 @@ def visit_Attribute(self, node: ast.Attribute) -> None:

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
self.old_syntax_finder().visit(node.annotation)
self.generic_visit(node)

def visit_arg(self, node: ast.arg) -> None:
if node.annotation is not None:
self.old_syntax_finder().visit(node.annotation)
self.generic_visit(node)

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
if node.returns is not None:
Expand Down