-
Notifications
You must be signed in to change notification settings - Fork 284
Expand file tree
/
Copy pathgenerate_expr_type_tests.py
More file actions
320 lines (274 loc) · 11 KB
/
Copy pathgenerate_expr_type_tests.py
File metadata and controls
320 lines (274 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
"""
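This script generates test cases for the static type annotations of expression arithmetic.
It evaluates, at runtime, every combination of unary, binary, and in-place operations (arithmetic and comparisons)
between values of different types, and generates test cases that check whether the static type annotations
match the actual runtime types of the results. Operations that raise at runtime are emitted with a
`# type: ignore` comment, since the type checker is expected to reject them as well.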
"""
import argparse
import logging
import itertools
import operator
from pathlib import Path
import pyscipopt
# Module logger; its level is set from the command line when run as a script.
logger = logging.getLogger(__name__)

# One indentation level inside generated test functions.
INDENT = "    "

# Lines written verbatim at the top of the generated test file. They are also
# executed before evaluating the expressions, so the runtime scope contains the
# same imports and module-level names as the generated file.
GLOBAL_STATEMENTS = [
    "# @generated by scripts/generate_expr_type_tests.py - do not edit manually",
    "",
    # standard library
    "import decimal",
    "import random",
    "",
    # third party
    "import numpy",
    "from typing_extensions import assert_type",
    "",
    # package under test
    "import pyscipopt.scip",
    "",
    "",
    "model = pyscipopt.scip.Model()",
]
# Source code for every operand under test, keyed by the name it is bound to.
# Together the entries should cover every type that is interesting for
# arithmetic. Entries are evaluated in insertion order, so an expression may use
# any name bound by an earlier entry (for example `expr` uses `var`).
EXPRESSIONS: dict[str, str] = {
    # -- variables --
    "var": "model.addVar()",
    "mvar1d": "model.addMatrixVar(3)",
    "mvar2d": "model.addMatrixVar((3, 3))",
    "term": "pyscipopt.scip.Term(var)",
    # -- expressions --
    "constant": "pyscipopt.scip.Constant(-2.0)",
    "expr": "var + 1",
    "matrix_expr": "mvar2d * 2",
    "sum_expr": "var + constant",
    "prod_expr": "var * constant",
    "pow_expr": "prod_expr**2",
    "unary_expr": "abs(var)",
    "var_expr": "pyscipopt.scip.VarExpr(var)",
    # -- constraints --
    "exprcons": "var <= 3",
    "matrixexprcons": "mvar1d <= 3",
    # -- builtin numbers --
    "integer": "random.randint(1, 10)",
    "floating_point": "random.random()",
    "dec": 'decimal.Decimal("1.0")',
    # -- numpy scalars and arrays --
    "np_float": "numpy.float64(3.0)",
    "array0d": "numpy.array(1)",
    "array1d": "numpy.array([1, 2, 3])",
    "array2d": "numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
}
# Binary operators, keyed by the exact text placed between the two operand names
# in generated code. The keys carry their own surrounding spaces because none
# are inserted by the generator; `**` is deliberately left unspaced.
BINARY_OPERATORS = {
    # arithmetic
    " + ": operator.add,
    " - ": operator.sub,
    " * ": operator.mul,
    " / ": operator.truediv,
    "**": operator.pow,
    # comparisons (these build constraints for pyscipopt operands)
    " < ": operator.lt,
    " <= ": operator.le,
    " > ": operator.gt,
    " >= ": operator.ge,
    " == ": operator.eq,
    " != ": operator.ne,
    # matrix product
    " @ ": operator.matmul,
}

# Augmented-assignment operators, keyed by the statement's operator token.
INPLACE_BINARY_OPERATORS = {
    "+=": operator.iadd,
    "-=": operator.isub,
    "*=": operator.imul,
    "/=": operator.itruediv,
    "**=": operator.ipow,
    "@=": operator.imatmul,
}
# Unary operations as (format string, callable) pairs. The format string has a
# single `{}` placeholder that receives the operand's name in generated code;
# the callable applies the same operation to the operand's runtime value.
UNARY_OPERATORS = [
    ("+{}", operator.pos),
    ("-{}", operator.neg),
    ("abs({})", abs),
    ("pyscipopt.exp({})", pyscipopt.exp),
    ("pyscipopt.log({})", pyscipopt.log),
    ("pyscipopt.sqrt({})", pyscipopt.sqrt),
    ("pyscipopt.sin({})", pyscipopt.sin),
    ("pyscipopt.cos({})", pyscipopt.cos),
    ("{}.sum()", operator.methodcaller("sum")),
    ("{}.sum(axis=-1)", operator.methodcaller("sum", axis=-1)),
]
def build_runtime_values(
    expressions: dict[str, str],
    global_statements: "list[str] | None" = None,
) -> dict[str, object]:
    """Evaluate the expressions and return the resulting runtime scope.

    The global statements are executed first, so the scope holds the same
    imports and module-level names (e.g. ``model``) as the generated test file.
    The expressions are then evaluated in insertion order and each result is
    bound to its name, so later expressions can refer to earlier ones.

    Args:
        expressions: Mapping from a name to the source of the expression that
            defines it.
        global_statements: Statements to execute before evaluating the
            expressions. Defaults to ``GLOBAL_STATEMENTS``.

    Returns:
        The evaluation scope. It maps every expression name to its runtime value
        and also contains the names bound by the global statements (imports,
        ``model``, ...), which callers need when executing further statements
        against a copy of this scope.
    """
    if global_statements is None:
        global_statements = GLOBAL_STATEMENTS
    # NOTE: exec/eval run trusted, script-defined source only; never feed
    # external input through here.
    eval_scope: dict[str, object] = {}
    for statement in global_statements:
        logger.debug("Executing statement: %s", statement)
        exec(statement, {}, eval_scope)
    for name, expr in expressions.items():
        logger.debug("Evaluating expression for %s: %s", name, expr)
        eval_scope[name] = eval(expr, {}, eval_scope)
    return eval_scope
def generate_erroring_line(
    expr: str,
    error: Exception,
    indent: str = "",
    inplace: bool = False,
) -> str:
    """Generate a line in the test file for an expression that raises at runtime.

    The line carries a ``# type: ignore`` comment, because an expression that
    errors at runtime is expected to produce a type error as well. The runtime
    exception is appended to the comment for reference.

    Args:
        expr: Source of the expression or, with ``inplace``, of the in-place
            assignment statement.
        error: The exception raised when the expression was evaluated.
        indent: Prefix prepended to the generated line.
        inplace: If true, ``expr`` is a statement and is emitted as-is;
            otherwise it is assigned to ``_``.

    Returns:
        The generated line, without a trailing newline.
    """
    # Assign plain expressions to `_` to prevent "unused expression" errors.
    code = f"{indent}{expr}" if inplace else f"{indent}_ = {expr}"
    # The message must fit on one comment line. Collapse every whitespace run
    # (newlines, carriage returns, tabs) into a single space rather than deleting
    # newlines, which would glue adjacent words together.
    error_message = " ".join(str(error).split())
    comment = type(error).__name__
    if error_message:
        # Only add the separator when there is a message, to avoid trailing whitespace.
        comment = f"{comment}: {error_message}"
    return f"{code} # type: ignore # {comment}"
def type_name(obj: object) -> str:
    """Return how the type of ``obj`` is spelled in the generated test file.

    Builtin types are returned unqualified (``int``). Other types are qualified
    by their module (``decimal.Decimal``) and keep any enclosing class
    (``module.Outer.Inner``). For ``None`` the result is ``"None"``, the
    spelling type checkers accept in ``assert_type``, because ``NoneType`` is
    not a builtin name and would be undefined in the generated file.
    """
    if obj is None:
        return "None"
    cls = type(obj)
    # __qualname__ (unlike __name__) keeps the enclosing class of nested classes.
    name = cls.__qualname__
    if cls.__module__ == "builtins":
        return name
    return f"{cls.__module__}.{name}"
def generate_result_expectation(expr: str, result: object, indent: str = "") -> str:
    """Build an ``assert_type`` line for an expression that evaluated without error.

    The expected type is the one observed at runtime (as spelled by
    ``type_name``), so the type checker verifies that the static annotations
    agree with the actual type of ``result``.
    """
    expected = type_name(result)
    return indent + "assert_type(" + expr + ", " + expected + ")"
def no_pyscipopt_objs(*objs: object) -> bool:
    """Check if none of the given objects has a type from the ``pyscipopt`` package.

    When this is the case, the operation involves no ``pyscipopt`` types, so no
    test cases need to be generated for it. Returns ``True`` for no arguments.
    """
    for obj in objs:
        module = type(obj).__module__
        # Match the package and its submodules (e.g. ``pyscipopt.scip``), but not
        # unrelated modules that merely share the prefix (``pyscipopt_extras``).
        if module == "pyscipopt" or module.startswith("pyscipopt."):
            return False
    return True
def _append_grouped(
    lines: list[str], success_lines: list[str], failure_lines: list[str]
) -> None:
    """Append the successful cases, then the failing ones, each non-empty group followed by a blank line."""
    if success_lines:
        lines.extend([*success_lines, ""])
    if failure_lines:
        lines.extend([*failure_lines, ""])


def _definition_lines(runtime_values: dict[str, object]) -> list[str]:
    """Phase 1 output: bind each name from `EXPRESSIONS` and assert its runtime type."""
    lines: list[str] = []
    for name, expr in EXPRESSIONS.items():
        lines.append(f"{name} = {expr}")
        lines.append(generate_result_expectation(name, runtime_values[name]))
    return lines


def _unary_lines(runtime_values: dict[str, object]) -> list[str]:
    """Phase 2 output: apply every unary operator to each value that has a pyscipopt type."""
    lines = [
        "",
        "###################",
        "# Unary operators #",
        "###################",
        "",
    ]
    for name in EXPRESSIONS:
        operand = runtime_values[name]
        if no_pyscipopt_objs(operand):
            continue
        lines.extend([f"# Unary operators for {name}", ""])
        success_lines: list[str] = []
        failure_lines: list[str] = []
        for op_repr, op_func in UNARY_OPERATORS:
            expr = op_repr.format(name)
            logger.debug("Evaluating unary operator %s", expr)
            # Deliberately broad: any runtime failure becomes a type-ignore case.
            try:
                result = op_func(operand)
            except Exception as e:
                failure_lines.append(generate_erroring_line(expr, e))
            else:
                success_lines.append(generate_result_expectation(expr, result))
        _append_grouped(lines, success_lines, failure_lines)
    return lines


def _binary_lines(runtime_values: dict[str, object]) -> list[str]:
    """Phase 3 output: apply every binary operator to each ordered pair involving a pyscipopt type."""
    lines = [
        "####################",
        "# Binary operators #",
        "####################",
        "",
    ]
    for left, right in itertools.product(EXPRESSIONS, repeat=2):
        left_value = runtime_values[left]
        right_value = runtime_values[right]
        if no_pyscipopt_objs(left_value, right_value):
            continue
        lines.extend([f"# Binary operators for {left} and {right}", ""])
        success_lines: list[str] = []
        failure_lines: list[str] = []
        for op_symbol, op_func in BINARY_OPERATORS.items():
            logger.debug(
                "Evaluating binary operator %s for %s and %s", op_symbol, left, right
            )
            expr = f"{left}{op_symbol}{right}"
            # Deliberately broad: any runtime failure becomes a type-ignore case.
            try:
                result = op_func(left_value, right_value)
            except Exception as e:
                failure_lines.append(generate_erroring_line(expr, e))
            else:
                success_lines.append(generate_result_expectation(expr, result))
        _append_grouped(lines, success_lines, failure_lines)
    return lines


def _inplace_lines(runtime_values: dict[str, object]) -> list[str]:
    """Phase 4 output: one test function per in-place operator and ordered pair involving a pyscipopt type."""
    lines = [
        "#####################",
        "# Inplace operators #",
        "#####################",
    ]
    for left, right in itertools.product(EXPRESSIONS, repeat=2):
        if no_pyscipopt_objs(runtime_values[left], runtime_values[right]):
            continue
        lines.extend(["", f"# Inplace operators for {left} and {right}", ""])
        for op_symbol, op_func in INPLACE_BINARY_OPERATORS.items():
            logger.debug(
                "Evaluating inplace binary operator %s for %s and %s",
                op_symbol,
                left,
                right,
            )
            # An in-place operation modifies its target and may change its type.
            # To keep cases independent, each one gets its own function that
            # creates a fresh target, applies the operator, and checks the type
            # the target has afterwards.
            target_name = f"{left}_{op_func.__name__}_{right}"
            stmt = f"{target_name} {op_symbol} {right}"
            lines.extend(
                [
                    "",
                    f"def test_inplace_{target_name}() -> None:",
                    f"{INDENT}{target_name} = {EXPRESSIONS[left]}",
                ]
            )
            # Mirror the function at runtime: copy the scope, re-evaluate the
            # left expression into a fresh target, then apply the operator.
            scope = runtime_values.copy()
            exec(f"{target_name} = {EXPRESSIONS[left]}", {}, scope)
            # Deliberately broad: any runtime failure becomes a type-ignore case.
            try:
                exec(stmt, {}, scope)
            except Exception as e:
                lines.append(
                    generate_erroring_line(stmt, e, indent=INDENT, inplace=True)
                )
            else:
                lines.append(f"{INDENT}{stmt}")
                lines.append(
                    generate_result_expectation(
                        target_name, scope[target_name], indent=INDENT
                    )
                )
            # Together with the leading "" of the next function, this separates
            # consecutive generated functions by two blank lines.
            lines.append("")
    return lines


def generate_test_cases() -> str:
    """Build the test file content by evaluating the expressions and generating test cases for their results.

    There are 4 phases:
    1. Evaluate all the expressions in `EXPRESSIONS` and store their runtime values.
    2. Generate test cases for unary operators applied to each expression in `EXPRESSIONS`.
    3. Generate test cases for binary operators applied to all pairs of expressions in `EXPRESSIONS`.
    4. Generate test cases for inplace binary operators applied to all pairs of expressions in `EXPRESSIONS`.

    Values or pairs without any pyscipopt-typed operand are skipped in phases 2-4.

    Returns:
        The complete source of the generated test file.
    """
    runtime_values = build_runtime_values(EXPRESSIONS)
    lines = [*GLOBAL_STATEMENTS, "", ""]
    lines.extend(_definition_lines(runtime_values))
    lines.extend(_unary_lines(runtime_values))
    lines.extend(_binary_lines(runtime_values))
    lines.extend(_inplace_lines(runtime_values))
    return "\n".join(lines)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate static type tests for expression arithmetic."
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=Path(__file__).parent.parent / "tests" / "@types" / "expr.py",
        help="path of the generated test file (default: %(default)s)",
    )
    # store_true already defaults to False; the previous `default=0` was misleading.
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="log every statement and operation as it is evaluated",
    )
    args = parser.parse_args()

    logging.basicConfig()
    logger.setLevel(logging.DEBUG if args.verbose else logging.WARNING)

    test_cases = generate_test_cases()

    # `type=Path` above already converts the argument (and the default is a Path).
    target: Path = args.output
    target.parent.mkdir(parents=True, exist_ok=True)
    # Use an explicit encoding and "\n" line endings so the generated file's
    # bytes do not depend on the platform's defaults.
    with target.open("w", encoding="utf-8", newline="\n") as f:
        f.write(test_cases)