-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__main__.py
More file actions
52 lines (39 loc) · 1.99 KB
/
__main__.py
File metadata and controls
52 lines (39 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from __future__ import annotations
import argparse
import ast
from collections.abc import Callable
from expr_simplifier import __version__
from expr_simplifier.transforms import apply_constant_folding, apply_cse, apply_logical_simplification
from expr_simplifier.typing import Pass
from expr_simplifier.utils import loop_until_stable
def create_pass_command(name: str, passes: list[Pass]) -> Callable[[argparse.Namespace], None]:
    """Build the handler for a simplification sub-command.

    Args:
        name: Sub-command name; also used as the handler's ``__name__`` and
            ``__qualname__``.
        passes: Transformation passes handed to ``loop_until_stable``.

    Returns:
        A callable taking the parsed CLI namespace. It reads ``args.input``
        and ``args.max_iter``, parses ``args.input`` as a Python expression,
        runs ``loop_until_stable`` with ``passes`` and prints the unparsed
        result.

    The returned callable raises ``SystemExit`` with an error message when
    ``args.input`` is not a valid Python expression.
    """

    def pass_command(args: argparse.Namespace) -> None:
        try:
            expr = ast.parse(args.input, mode="eval")
        except (SyntaxError, ValueError) as err:
            # Malformed user input is a usage problem, not a program crash:
            # report it on stderr instead of dumping a traceback. ValueError
            # covers NUL bytes in the source on Python < 3.12.
            raise SystemExit(f"error: invalid expression {args.input!r}: {err}") from err
        simplified_expr = loop_until_stable(expr, passes, args.max_iter)
        print(ast.unparse(simplified_expr))

    # Give the closure a meaningful identity (e.g. in reprs and tracebacks).
    pass_command.__name__ = name
    pass_command.__qualname__ = name
    return pass_command
def create_pass_parser(
    name: str,
    passes: list[Pass],
    description: str,
    subparser: argparse._SubParsersAction[argparse.ArgumentParser],  # pyright: ignore [reportPrivateUsage]
) -> None:
    """Register a simplification sub-command on ``subparser``.

    The new sub-command takes a positional ``input`` expression and an
    optional ``--max-iter`` integer (default 100). Its handler, built by
    ``create_pass_command(name, passes)``, is stored as ``func`` in the
    parsed namespace.

    Args:
        name: Sub-command name as typed on the command line.
        passes: Transformation passes the sub-command applies.
        description: Human-readable text for the sub-command.
        subparser: The sub-parsers action returned by ``add_subparsers``.
    """
    # ``help`` is the one-liner in the parent's sub-command listing;
    # ``description`` is what ``<prog> <name> --help`` shows.
    parser = subparser.add_parser(name, help=description, description=description)
    parser.add_argument("input", help="The expression to simplify")
    parser.add_argument("--max-iter", type=int, default=100, help="The maximum number of iterations")
    parser.set_defaults(func=create_pass_command(name, passes))
def main(argv: list[str] | None = None) -> None:
    """Run the command-line interface.

    Builds the top-level parser with one sub-command per simplification
    strategy (``cse``, ``constant_folding``, ``logical_simplification`` and
    ``auto``), parses the command line and dispatches to the selected
    sub-command's handler.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            previous behaviour; passing a list is useful for testing.
    """
    # NOTE(review): ``prog`` and ``description`` look like template leftovers
    # for an expression simplifier; confirm the intended CLI name and text.
    parser = argparse.ArgumentParser(prog="moelib", description="A moe moe project")
    parser.add_argument("-v", "--version", action="version", version=__version__)
    # ``required=True``: without a sub-command there is no ``func`` default in
    # the namespace, and ``args.func`` below would crash with AttributeError.
    # This way argparse prints a usage error instead.
    sub_parsers = parser.add_subparsers(help="sub-command help", dest="sub_command", required=True)
    create_pass_parser("cse", [apply_cse], "Common Subexpression Elimination", sub_parsers)
    create_pass_parser("constant_folding", [apply_constant_folding], "Constant Folding", sub_parsers)
    create_pass_parser("logical_simplification", [apply_logical_simplification], "Logical Simplification", sub_parsers)
    create_pass_parser(
        "auto", [apply_constant_folding, apply_logical_simplification, apply_cse], "Auto Simplification", sub_parsers
    )
    args = parser.parse_args(argv)
    args.func(args)
if __name__ == "__main__":
    # Start the CLI only when this module is executed (e.g. via ``python -m``),
    # not when it is imported.
    main()