11import argparse
22import ast
3+ import itertools
34import platform
45from pathlib import Path
56
@@ -71,6 +72,37 @@ def modify_test(file: str, test: list[str], for_platform: bool = False) -> str:
7172 break
7273 return "\n " .join (lines )
7374
75+ def modify_test_v2 (file : str , test : list [str ], for_platform : bool = False ) -> str :
76+ a = ast .parse (file )
77+ lines = file .splitlines ()
78+ fixture = "@unittest.expectedFailure"
79+ for key , node in ast .iter_fields (a ):
80+ if key == "body" :
81+ for i , n in enumerate (node ):
82+ match n :
83+ case ast .ClassDef ():
84+ if len (test ) == 2 and test [0 ] == n .name :
85+ # look through body for function def
86+ for i , fn in enumerate (n .body ):
87+ match fn :
88+ case ast .FunctionDef ():
89+ if fn .name == test [- 1 ]:
90+ assert not for_platform
91+ indent = " " * fn .col_offset
92+ lines .insert (fn .lineno - 1 , indent + fixture )
93+ lines .insert (fn .lineno - 1 , indent + "# TODO: RUSTPYTHON" )
94+ break
95+ case ast .FunctionDef ():
96+ if n .name == test [0 ] and len (test ) == 1 :
97+ assert not for_platform
98+ indent = " " * n .col_offset
99+ lines .insert (n .lineno - 1 , indent + fixture )
100+ lines .insert (n .lineno - 1 , indent + "# TODO: RUSTPYTHON" )
101+ break
102+ if i > 500 :
103+ exit ()
104+ return "\n " .join (lines )
105+
74106def run_test (test_name ):
75107 print (f"Running test: { test_name } " )
76108 rustpython_location = "./target/release/rustpython"
@@ -87,7 +119,7 @@ def run_test(test_name):
87119 for test in tests .tests :
88120 if test .result == "fail" or test .result == "error" :
89121 print ("Modifying test:" , test .name )
90- f = modify_test (f , path_to_test (test .path ), args .platform )
122+ f = modify_test_v2 (f , path_to_test (test .path ), args .platform )
91123 with open (args .path , "w" ) as file :
92124 # TODO: Find validation method, and make --force override it
93125 file .write (f )
0 commit comments