Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
handle classes properly
  • Loading branch information
arihant2math committed Apr 24, 2025
commit 452e62917b486e67f66a08e1e31d651464576578
34 changes: 33 additions & 1 deletion scripts/fix_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import ast
import itertools
import platform
from pathlib import Path

Expand Down Expand Up @@ -71,6 +72,37 @@ def modify_test(file: str, test: list[str], for_platform: bool = False) -> str:
break
return "\n".join(lines)

def modify_test_v2(file: str, test: list[str], for_platform: bool = False) -> str:
a = ast.parse(file)
lines = file.splitlines()
fixture = "@unittest.expectedFailure"
for key, node in ast.iter_fields(a):
if key == "body":
for i, n in enumerate(node):
match n:
case ast.ClassDef():
if len(test) == 2 and test[0] == n.name:
# look through body for function def
for i, fn in enumerate(n.body):
match fn:
case ast.FunctionDef():
if fn.name == test[-1]:
assert not for_platform
indent = " " * fn.col_offset
lines.insert(fn.lineno - 1, indent + fixture)
lines.insert(fn.lineno - 1, indent + "# TODO: RUSTPYTHON")
break
case ast.FunctionDef():
if n.name == test[0] and len(test) == 1:
assert not for_platform
indent = " " * n.col_offset
lines.insert(n.lineno - 1, indent + fixture)
lines.insert(n.lineno - 1, indent + "# TODO: RUSTPYTHON")
break
if i > 500:
exit()
return "\n".join(lines)

def run_test(test_name):
print(f"Running test: {test_name}")
rustpython_location = "./target/release/rustpython"
Expand All @@ -87,7 +119,7 @@ def run_test(test_name):
for test in tests.tests:
if test.result == "fail" or test.result == "error":
print("Modifying test:", test.name)
f = modify_test(f, path_to_test(test.path), args.platform)
f = modify_test_v2(f, path_to_test(test.path), args.platform)
with open(args.path, "w") as file:
# TODO: Find validation method, and make --force override it
file.write(f)
Loading