11import argparse
2+ import platform
23
34def parse_args ():
45 parser = argparse .ArgumentParser (description = "Fix test." )
56 parser .add_argument ("--test" , type = str , help = "Name of test" )
67 parser .add_argument ("--path" , type = str , help = "Path to test file" )
78 parser .add_argument ("--force" , action = "store_true" , help = "Force modification" )
9+ parser .add_argument ("--platform" , action = "store_true" , help = "Platform specific failure" )
810
911 args = parser .parse_args ()
1012 return args
@@ -20,6 +22,7 @@ def __str__(self):
2022class TestResult :
2123 tests_result : str = ""
2224 tests = []
25+ stdout = ""
2326
2427 def __str__ (self ):
2528 return f"TestResult(tests_result={ self .tests_result } ,tests={ len (self .tests )} )"
@@ -28,6 +31,7 @@ def __str__(self):
2831def parse_results (result ):
2932 lines = result .stdout .splitlines ()
3033 test_results = TestResult ()
34+ test_results .stdout = result .stdout
3135 in_test_results = False
3236 for line in lines :
3337 if line == "Run tests sequentially" :
@@ -52,14 +56,20 @@ def parse_results(result):
5256def path_to_test (path ) -> list [str ]:
5357 return path .split ("." )[2 :]
5458
55- def modify_test (file : str , test : list [str ]) -> str :
59+ def modify_test (file : str , test : list [str ], for_platform : bool = False ) -> str :
5660 lines = file .splitlines ()
5761 result = []
62+ failure_fixture = "expectedFailure"
63+ if for_platform :
64+ if platform .system () == "Windows" :
65+ failure_fixture = "expectedFailureIfWindows(\" TODO: RUSTPYTHON: Generated by fix_test script\" )"
66+ else :
67+ raise Exception ("Platform not supported" )
5868 for line in lines :
5969 if line .lstrip (" " ).startswith ("def " + test [- 1 ]):
6070 whitespace = line [:line .index ("def " )]
6171 result .append (whitespace + "# TODO: RUSTPYTHON" )
62- result .append (whitespace + "@unittest.expectedFailure " )
72+ result .append (whitespace + f "@unittest.{ failure_fixture } " )
6373 result .append (line )
6474 return "\n " .join (result )
6575
@@ -77,11 +87,9 @@ def run_test(test_name):
7787 tests = run_test (test_name )
7888 f = open (args .path ).read ()
7989 for test in tests .tests :
80- if test .result == "fail" :
90+ if test .result == "fail" or test . result == "error" :
8191 print ("Modifying test:" , test .name )
82- f = modify_test (f , path_to_test (test .path ))
92+ f = modify_test (f , path_to_test (test .path ), args . platform )
8393 with open (args .path , "w" ) as file :
84- if args .force or run_test ().tests_result == "ok" :
85- file .write (f )
86- else :
87- raise Exception ("Test failed after modification" )
94+ # TODO: Find validation method, and make --force override it
95+ file .write (f )
0 commit comments