Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Show affected tests for given module names
  • Loading branch information
moreal authored and youknowone committed Jan 22, 2026
commit 4154c5e3f1de021af4ece01bde31760506e44524
144 changes: 144 additions & 0 deletions scripts/update_lib/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,147 @@ def resolve_all_paths(
result["data"].append(data_path)

return result


def _build_import_graph(lib_prefix: str = "Lib") -> dict[str, set[str]]:
"""Build a graph of module imports from lib_prefix directory.

Args:
lib_prefix: RustPython Lib directory (default: "Lib")

Returns:
Dict mapping module_name -> set of modules it imports
"""
lib_dir = pathlib.Path(lib_prefix)
if not lib_dir.exists():
return {}

import_graph: dict[str, set[str]] = {}

# Scan all .py files in lib_prefix (excluding test/ directory for module imports)
for entry in lib_dir.iterdir():
if entry.name.startswith(("_", ".")):
continue
if entry.name == "test":
continue

module_name = None
if entry.is_file() and entry.suffix == ".py":
module_name = entry.stem
elif entry.is_dir() and (entry / "__init__.py").exists():
module_name = entry.name

if module_name:
# Parse imports from this module
imports = set()
for _, content in read_python_files(entry):
imports.update(parse_lib_imports(content))
# Remove self-imports
imports.discard(module_name)
import_graph[module_name] = imports

return import_graph


def _build_reverse_graph(import_graph: dict[str, set[str]]) -> dict[str, set[str]]:
"""Build reverse dependency graph (who imports this module).

Args:
import_graph: Forward import graph (module -> imports)

Returns:
Reverse graph (module -> imported_by)
"""
reverse_graph: dict[str, set[str]] = {}

for module, imports in import_graph.items():
for imported in imports:
if imported not in reverse_graph:
reverse_graph[imported] = set()
reverse_graph[imported].add(module)

return reverse_graph


@functools.cache
def get_transitive_imports(
module_name: str,
lib_prefix: str = "Lib",
) -> frozenset[str]:
"""Get all modules that transitively depend on module_name.

Args:
module_name: Target module
lib_prefix: RustPython Lib directory (default: "Lib")

Returns:
Frozenset of module names that import module_name (directly or indirectly)
"""
import_graph = _build_import_graph(lib_prefix)
reverse_graph = _build_reverse_graph(import_graph)

# BFS from module_name following reverse edges
visited: set[str] = set()
queue = list(reverse_graph.get(module_name, set()))

while queue:
current = queue.pop(0)
if current in visited:
continue
visited.add(current)
# Add modules that import current module
for importer in reverse_graph.get(current, set()):
if importer not in visited:
queue.append(importer)

return frozenset(visited)


@functools.cache
def find_tests_importing_module(
module_name: str,
lib_prefix: str = "Lib",
include_transitive: bool = True,
) -> frozenset[pathlib.Path]:
"""Find all test files that import the given module (directly or transitively).

Args:
module_name: Module to search for (e.g., "datetime")
lib_prefix: RustPython Lib directory (default: "Lib")
include_transitive: Whether to include transitive dependencies

Returns:
Frozenset of test file paths that depend on this module
"""
lib_dir = pathlib.Path(lib_prefix)
test_dir = lib_dir / "test"

if not test_dir.exists():
return frozenset()

# Build set of modules to search for
target_modules = {module_name}
if include_transitive:
# Add all modules that transitively depend on module_name
target_modules.update(get_transitive_imports(module_name, lib_prefix))

# Excluded test file for this module (test_<module>.py)
excluded_test = f"test_{module_name}.py"

# Scan test directory for files that import any of the target modules
result: set[pathlib.Path] = set()

for test_file in test_dir.glob("*.py"):
if test_file.name == excluded_test:
continue

content = safe_read_text(test_file)
if content is None:
continue

imports = parse_lib_imports(content)
# Check if any target module is imported
if imports & target_modules:
result.add(test_file)

return frozenset(result)
43 changes: 41 additions & 2 deletions scripts/update_lib/show_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def format_deps(
lib_prefix: str = "Lib",
max_depth: int = 10,
_visited: set[str] | None = None,
show_impact: bool = False,
) -> list[str]:
"""Format all dependency information for a module.

Expand All @@ -154,14 +155,17 @@ def format_deps(
lib_prefix: Local Lib directory prefix
max_depth: Maximum recursion depth
_visited: Shared visited set for deduplication across modules
show_impact: Whether to show reverse dependencies (tests that import this module)

Returns:
List of formatted lines
"""
from update_lib.deps import (
DEPENDENCIES,
find_tests_importing_module,
get_lib_paths,
get_test_paths,
get_transitive_imports,
)

if _visited is None:
Expand Down Expand Up @@ -194,6 +198,33 @@ def format_deps(
)
)

# Show impact (reverse dependencies) if requested
if show_impact:
impacted_tests = find_tests_importing_module(name, lib_prefix)
transitive_importers = get_transitive_imports(name, lib_prefix)

if impacted_tests:
lines.append(f"[+] impact: ({len(impacted_tests)} tests depend on {name})")
# Sort tests and show with dependency info
for test_path in sorted(impacted_tests, key=lambda p: p.name):
# Determine if direct or via which module
test_content = test_path.read_text(errors="ignore")
from update_lib.deps import parse_lib_imports

test_imports = parse_lib_imports(test_content)
if name in test_imports:
lines.append(f" - {test_path.name} (direct)")
else:
# Find which transitive module is imported
via_modules = test_imports & transitive_importers
if via_modules:
via_str = ", ".join(sorted(via_modules))
lines.append(f" - {test_path.name} (via {via_str})")
else:
lines.append(f" - {test_path.name}")
else:
lines.append(f"[+] impact: (no tests depend on {name})")

return lines


Expand All @@ -202,6 +233,7 @@ def show_deps(
cpython_prefix: str = "cpython",
lib_prefix: str = "Lib",
max_depth: int = 10,
show_impact: bool = False,
) -> None:
"""Show all dependency information for modules."""
# Expand "all" to all module names
Expand All @@ -218,7 +250,9 @@ def show_deps(
for i, name in enumerate(expanded_names):
if i > 0:
print() # blank line between modules
for line in format_deps(name, cpython_prefix, lib_prefix, max_depth, visited):
for line in format_deps(
name, cpython_prefix, lib_prefix, max_depth, visited, show_impact
):
print(line)


Expand Down Expand Up @@ -248,11 +282,16 @@ def main(argv: list[str] | None = None) -> int:
default=10,
help="Maximum recursion depth for soft_deps tree (default: 10)",
)
parser.add_argument(
"--impact",
action="store_true",
help="Show tests that import this module (reverse dependencies)",
)

args = parser.parse_args(argv)

try:
show_deps(args.names, args.cpython, args.lib, args.depth)
show_deps(args.names, args.cpython, args.lib, args.depth, args.impact)
return 0
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
Expand Down
130 changes: 130 additions & 0 deletions scripts/update_lib/tests/test_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import unittest

from update_lib.deps import (
find_tests_importing_module,
get_data_paths,
get_lib_paths,
get_soft_deps,
get_test_dependencies,
get_test_paths,
get_transitive_imports,
parse_lib_imports,
parse_test_imports,
resolve_all_paths,
Expand Down Expand Up @@ -422,5 +424,133 @@ def test_nested_different(self):
self.assertFalse(_dircmp_is_same(dcmp))


class TestGetTransitiveImports(unittest.TestCase):
"""Tests for get_transitive_imports function."""

def test_direct_dependency(self):
"""A imports B → B's transitive importers include A."""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
lib_dir = tmpdir / "Lib"
lib_dir.mkdir()

(lib_dir / "a.py").write_text("import b\n")
(lib_dir / "b.py").write_text("# b module")

get_transitive_imports.cache_clear()
result = get_transitive_imports("b", lib_prefix=str(lib_dir))
self.assertIn("a", result)

def test_chain_dependency(self):
"""A imports B, B imports C → C's transitive importers include A and B."""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
lib_dir = tmpdir / "Lib"
lib_dir.mkdir()

(lib_dir / "a.py").write_text("import b\n")
(lib_dir / "b.py").write_text("import c\n")
(lib_dir / "c.py").write_text("# c module")

get_transitive_imports.cache_clear()
result = get_transitive_imports("c", lib_prefix=str(lib_dir))
self.assertIn("a", result)
self.assertIn("b", result)

def test_cycle_handling(self):
"""Handle circular imports without infinite loop."""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
lib_dir = tmpdir / "Lib"
lib_dir.mkdir()

(lib_dir / "a.py").write_text("import b\n")
(lib_dir / "b.py").write_text("import a\n") # cycle

get_transitive_imports.cache_clear()
# Should not hang or raise
result = get_transitive_imports("a", lib_prefix=str(lib_dir))
self.assertIn("b", result)


class TestFindTestsImportingModule(unittest.TestCase):
"""Tests for find_tests_importing_module function."""

def test_direct_import(self):
"""Test finding tests that directly import a module."""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
lib_dir = tmpdir / "Lib"
test_dir = lib_dir / "test"
test_dir.mkdir(parents=True)

# Create target module
(lib_dir / "bar.py").write_text("# bar module")

# Create test that imports bar
(test_dir / "test_foo.py").write_text("import bar\n")

get_transitive_imports.cache_clear()
find_tests_importing_module.cache_clear()
result = find_tests_importing_module("bar", lib_prefix=str(lib_dir))
self.assertIn(test_dir / "test_foo.py", result)

def test_excludes_test_module_itself(self):
"""Test that test_<module>.py is excluded from results."""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
lib_dir = tmpdir / "Lib"
test_dir = lib_dir / "test"
test_dir.mkdir(parents=True)

(lib_dir / "bar.py").write_text("# bar module")
(test_dir / "test_bar.py").write_text("import bar\n")

get_transitive_imports.cache_clear()
find_tests_importing_module.cache_clear()
result = find_tests_importing_module("bar", lib_prefix=str(lib_dir))
# test_bar.py should NOT be in results (it's the primary test)
self.assertNotIn(test_dir / "test_bar.py", result)

def test_transitive_import(self):
"""Test finding tests with transitive (indirect) imports."""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
lib_dir = tmpdir / "Lib"
test_dir = lib_dir / "test"
test_dir.mkdir(parents=True)

# bar.py (target module)
(lib_dir / "bar.py").write_text("# bar module")

# baz.py imports bar
(lib_dir / "baz.py").write_text("import bar\n")

# test_foo.py imports baz (not bar directly)
(test_dir / "test_foo.py").write_text("import baz\n")

get_transitive_imports.cache_clear()
find_tests_importing_module.cache_clear()
result = find_tests_importing_module("bar", lib_prefix=str(lib_dir))
# test_foo.py should be found via transitive dependency
self.assertIn(test_dir / "test_foo.py", result)

def test_empty_when_no_importers(self):
"""Test returns empty when no tests import the module."""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
lib_dir = tmpdir / "Lib"
test_dir = lib_dir / "test"
test_dir.mkdir(parents=True)

(lib_dir / "bar.py").write_text("# bar module")
(test_dir / "test_unrelated.py").write_text("import os\n")

get_transitive_imports.cache_clear()
find_tests_importing_module.cache_clear()
result = find_tests_importing_module("bar", lib_prefix=str(lib_dir))
self.assertEqual(result, frozenset())


if __name__ == "__main__":
unittest.main()