33It adds @unittest.expectedFailure to the test functions that are failing in RustPython, but not in CPython.
44As well as marking the test with a TODO comment.
55
6- How to use:
6+ Quick Import (recommended):
7+ python ./scripts/fix_test.py --quick-import cpython/Lib/test/test_foo.py
8+
9+ This will:
10+ 1. Copy cpython/Lib/test/test_foo.py to Lib/test/test_foo.py (if not exists)
11+ 2. Run the test with RustPython
12+ 3. Mark failing tests with @unittest.expectedFailure
13+
14+ Manual workflow:
7151. Copy a specific test from the CPython repository to the RustPython repository.
8162. Remove all unexpected failures from the test and skip the tests that hang.
9173. Build RustPython: cargo build --release
1523"""
1624
1725import argparse
18- import ast
19- import itertools
20- import platform
26+ import shutil
2127import sys
2228from pathlib import Path
2329
30+ from lib_updater import apply_patches , PatchSpec , UtMethod
31+
2432
2533def parse_args ():
2634 parser = argparse .ArgumentParser (description = "Fix test." )
27- parser .add_argument ("--path" , type = Path , help = "Path to test file" )
35+ group = parser .add_mutually_exclusive_group (required = True )
36+ group .add_argument ("--path" , type = Path , help = "Path to test file" )
37+ group .add_argument (
38+ "--quick-import" ,
39+ type = Path ,
40+ metavar = "PATH" ,
41+ help = "Import from path containing /Lib/ (e.g., cpython/Lib/test/foo.py)" ,
42+ )
2843 parser .add_argument ("--force" , action = "store_true" , help = "Force modification" )
2944 parser .add_argument (
3045 "--platform" , action = "store_true" , help = "Platform specific failure"
@@ -102,39 +117,16 @@ def path_to_test(path) -> list[str]:
102117 return parts [- 2 :] # Get class name and method name
103118
104119
105- def find_test_lineno (file : str , test : list [str ]) -> tuple [int , int ] | None :
106- """Find the line number and column offset of a test function.
107- Returns (lineno, col_offset) or None if not found.
108- """
109- a = ast .parse (file )
110- for key , node in ast .iter_fields (a ):
111- if key == "body" :
112- for n in node :
113- match n :
114- case ast .ClassDef ():
115- if len (test ) == 2 and test [0 ] == n .name :
116- for fn in n .body :
117- match fn :
118- case ast .FunctionDef () | ast .AsyncFunctionDef ():
119- if fn .name == test [- 1 ]:
120- return (fn .lineno , fn .col_offset )
121- case ast .FunctionDef () | ast .AsyncFunctionDef ():
122- if n .name == test [0 ] and len (test ) == 1 :
123- return (n .lineno , n .col_offset )
124- return None
125-
126-
127- def apply_modifications (file : str , modifications : list [tuple [int , int ]]) -> str :
128- """Apply all modifications in reverse order to avoid line number offset issues."""
129- lines = file .splitlines ()
130- fixture = "@unittest.expectedFailure"
131- # Sort by line number in descending order
132- modifications .sort (key = lambda x : x [0 ], reverse = True )
133- for lineno , col_offset in modifications :
134- indent = " " * col_offset
135- lines .insert (lineno - 1 , indent + fixture )
136- lines .insert (lineno - 1 , indent + "# TODO: RUSTPYTHON" )
137- return "\n " .join (lines )
120+ def build_patches (test_parts_set : set [tuple [str , str ]]) -> dict :
121+ """Convert failing tests to lib_updater patch format."""
122+ patches = {}
123+ for class_name , method_name in test_parts_set :
124+ if class_name not in patches :
125+ patches [class_name ] = {}
126+ patches [class_name ][method_name ] = [
127+ PatchSpec (UtMethod .ExpectedFailure , None , "" )
128+ ]
129+ return patches
138130
139131
140132def run_test (test_name ):
@@ -146,7 +138,7 @@ def run_test(test_name):
146138 import subprocess
147139
148140 result = subprocess .run (
149- [rustpython_location , "-m" , "test" , "-v" , test_name ],
141+ [rustpython_location , "-m" , "test" , "-v" , "-u" , "all" , "--slowest" , test_name ],
150142 capture_output = True ,
151143 text = True ,
152144 )
@@ -155,6 +147,33 @@ def run_test(test_name):
155147
156148if __name__ == "__main__" :
157149 args = parse_args ()
150+
151+ # Handle --quick-import: extract Lib/... path and copy if needed
152+ if args .quick_import is not None :
153+ src_str = str (args .quick_import )
154+ lib_marker = "/Lib/"
155+
156+ if lib_marker not in src_str :
157+ print (f"Error: --quick-import path must contain '/Lib/' (got: { src_str } )" )
158+ sys .exit (1 )
159+
160+ idx = src_str .index (lib_marker )
161+ lib_path = Path (src_str [idx + 1 :]) # Lib/test/foo.py
162+ src_path = args .quick_import
163+
164+ if not src_path .exists ():
165+ print (f"Error: Source file not found: { src_path } " )
166+ sys .exit (1 )
167+
168+ if not lib_path .exists ():
169+ print (f"Copying: { src_path } -> { lib_path } " )
170+ lib_path .parent .mkdir (parents = True , exist_ok = True )
171+ shutil .copy (src_path , lib_path )
172+ else :
173+ print (f"File already exists: { lib_path } " )
174+
175+ args .path = lib_path
176+
158177 test_path = args .path .resolve ()
159178 if not test_path .exists ():
160179 print (f"Error: File not found: { test_path } " )
@@ -167,26 +186,21 @@ def run_test(test_name):
167186 tests = run_test (test_name )
168187 f = test_path .read_text (encoding = "utf-8" )
169188
170- # Collect all modifications first (with deduplication for subtests)
171- modifications = []
189+ # Collect failing tests (with deduplication for subtests)
172190 seen_tests = set () # Track (class_name, method_name) to avoid duplicates
173191 for test in tests .tests :
174192 if test .result == "fail" or test .result == "error" :
175193 test_parts = path_to_test (test .path )
176- test_key = tuple (test_parts )
177- if test_key in seen_tests :
178- continue # Skip duplicate (same test, different subtest)
179- seen_tests .add (test_key )
180- location = find_test_lineno (f , test_parts )
181- if location :
182- print (f"Modifying test: { test .name } at line { location [0 ]} " )
183- modifications .append (location )
184- else :
185- print (f"Warning: Could not find test: { test .name } ({ test_parts } )" )
186-
187- # Apply all modifications in reverse order
188- if modifications :
189- f = apply_modifications (f , modifications )
194+ if len (test_parts ) == 2 :
195+ test_key = tuple (test_parts )
196+ if test_key not in seen_tests :
197+ seen_tests .add (test_key )
198+ print (f"Marking test: { test_parts [0 ]} .{ test_parts [1 ]} " )
199+
200+ # Apply patches using lib_updater
201+ if seen_tests :
202+ patches = build_patches (seen_tests )
203+ f = apply_patches (f , patches )
190204 test_path .write_text (f , encoding = "utf-8" )
191205
192- print (f"Modified { len (modifications )} tests" )
206+ print (f"Modified { len (seen_tests )} tests" )
0 commit comments