|
5 | 5 |
|
6 | 6 | from codeclone.cfg import CFG, CFGBuilder |
7 | 7 | from codeclone.cfg_model import CFG as CFGModel |
| 8 | +from codeclone.cfg_model import Block |
8 | 9 | from codeclone.extractor import get_cfg_fingerprint |
9 | 10 | from codeclone.meta_markers import CFG_META_PREFIX |
10 | 11 | from codeclone.normalize import NormalizationConfig |
@@ -44,6 +45,47 @@ def _const_meta_value(stmt: ast.stmt) -> str | None: |
44 | 45 | return stmt.value.id |
45 | 46 |
|
46 | 47 |
|
| 48 | +def _parse_function( |
| 49 | + source: str, *, skip_reason: str | None = None |
| 50 | +) -> ast.FunctionDef | ast.AsyncFunctionDef: |
| 51 | + try: |
| 52 | + module = ast.parse(dedent(source)) |
| 53 | + except SyntaxError: |
| 54 | + if skip_reason: |
| 55 | + pytest.skip(skip_reason) |
| 56 | + raise |
| 57 | + for node in ast.walk(module): |
| 58 | + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): |
| 59 | + return node |
| 60 | + raise AssertionError("Expected at least one function in source") |
| 61 | + |
| 62 | + |
| 63 | +def _cfg_fingerprint( |
| 64 | + source: str, qualname: str, *, skip_reason: str | None = None |
| 65 | +) -> str: |
| 66 | + func = _parse_function(source, skip_reason=skip_reason) |
| 67 | + cfg = NormalizationConfig() |
| 68 | + return get_cfg_fingerprint(func, cfg, qualname) |
| 69 | + |
| 70 | + |
| 71 | +def _assert_fingerprint_diff( |
| 72 | + source_a: str, source_b: str, *, skip_reason: str | None = None |
| 73 | +) -> None: |
| 74 | + fp_a = _cfg_fingerprint(source_a, "m:f", skip_reason=skip_reason) |
| 75 | + fp_b = _cfg_fingerprint(source_b, "m:g", skip_reason=skip_reason) |
| 76 | + assert fp_a != fp_b |
| 77 | + |
| 78 | + |
| 79 | +def _single_return_block(cfg: CFG) -> Block: |
| 80 | + return_blocks = [ |
| 81 | + block |
| 82 | + for block in cfg.blocks |
| 83 | + if any(isinstance(stmt, ast.Return) for stmt in block.statements) |
| 84 | + ] |
| 85 | + assert len(return_blocks) == 1 |
| 86 | + return return_blocks[0] |
| 87 | + |
| 88 | + |
47 | 89 | def test_cfg_if_else() -> None: |
48 | 90 | source = """ |
49 | 91 | def f(a): |
@@ -581,178 +623,137 @@ def f(x): |
581 | 623 | assert "MatchMapping" in patterns_found[1] |
582 | 624 |
|
583 | 625 |
|
584 | | -def test_cfg_match_guard_affects_fingerprint() -> None: |
585 | | - code_with_guard = """ |
| 626 | +@pytest.mark.parametrize( |
| 627 | + ("source_a", "source_b", "skip_reason"), |
| 628 | + [ |
| 629 | + ( |
| 630 | + """ |
586 | 631 | def f(x): |
587 | 632 | match x: |
588 | 633 | case 1 if cond(): |
589 | 634 | return 1 |
590 | 635 | case _: |
591 | 636 | return 2 |
592 | | - """ |
593 | | - code_without_guard = """ |
| 637 | + """, |
| 638 | + """ |
594 | 639 | def f(x): |
595 | 640 | match x: |
596 | 641 | case 1: |
597 | 642 | return 1 |
598 | 643 | case _: |
599 | 644 | return 2 |
600 | | - """ |
601 | | - try: |
602 | | - func_with_guard = ast.parse(dedent(code_with_guard)).body[0] |
603 | | - func_without_guard = ast.parse(dedent(code_without_guard)).body[0] |
604 | | - except SyntaxError: |
605 | | - pytest.skip("Match syntax is unavailable") |
606 | | - |
607 | | - assert isinstance(func_with_guard, (ast.FunctionDef, ast.AsyncFunctionDef)) |
608 | | - assert isinstance(func_without_guard, (ast.FunctionDef, ast.AsyncFunctionDef)) |
609 | | - cfg = NormalizationConfig() |
610 | | - fp_with_guard = get_cfg_fingerprint(func_with_guard, cfg, "m:f") |
611 | | - fp_without_guard = get_cfg_fingerprint(func_without_guard, cfg, "m:g") |
612 | | - assert fp_with_guard != fp_without_guard |
613 | | - |
614 | | - |
615 | | -def test_cfg_match_case_order_affects_fingerprint() -> None: |
616 | | - source_a = """ |
| 645 | + """, |
| 646 | + "Match syntax is unavailable", |
| 647 | + ), |
| 648 | + ( |
| 649 | + """ |
617 | 650 | def f(x): |
618 | 651 | match x: |
619 | 652 | case 1: |
620 | 653 | return 1 |
621 | 654 | case _: |
622 | 655 | return 2 |
623 | | - """ |
624 | | - source_b = """ |
| 656 | + """, |
| 657 | + """ |
625 | 658 | def g(x): |
626 | 659 | match x: |
627 | 660 | case _: |
628 | 661 | return 2 |
629 | 662 | case 1: |
630 | 663 | return 1 |
631 | | - """ |
632 | | - try: |
633 | | - func_a = ast.parse(dedent(source_a)).body[0] |
634 | | - func_b = ast.parse(dedent(source_b)).body[0] |
635 | | - except SyntaxError: |
636 | | - pytest.skip("Match syntax is unavailable") |
637 | | - assert isinstance(func_a, (ast.FunctionDef, ast.AsyncFunctionDef)) |
638 | | - assert isinstance(func_b, (ast.FunctionDef, ast.AsyncFunctionDef)) |
639 | | - cfg = NormalizationConfig() |
640 | | - fp_a = get_cfg_fingerprint(func_a, cfg, "m:f") |
641 | | - fp_b = get_cfg_fingerprint(func_b, cfg, "m:g") |
642 | | - assert fp_a != fp_b |
643 | | - |
644 | | - |
645 | | -def test_cfg_try_handler_order_affects_fingerprint() -> None: |
646 | | - source_a = """ |
| 664 | + """, |
| 665 | + "Match syntax is unavailable", |
| 666 | + ), |
| 667 | + ( |
| 668 | + """ |
647 | 669 | def f(x): |
648 | 670 | try: |
649 | 671 | return risky(x) |
650 | 672 | except ValueError: |
651 | 673 | return 1 |
652 | 674 | except Exception: |
653 | 675 | return 2 |
654 | | - """ |
655 | | - source_b = """ |
| 676 | + """, |
| 677 | + """ |
656 | 678 | def g(x): |
657 | 679 | try: |
658 | 680 | return risky(x) |
659 | 681 | except Exception: |
660 | 682 | return 2 |
661 | 683 | except ValueError: |
662 | 684 | return 1 |
663 | | - """ |
664 | | - func_a = ast.parse(dedent(source_a)).body[0] |
665 | | - func_b = ast.parse(dedent(source_b)).body[0] |
666 | | - assert isinstance(func_a, (ast.FunctionDef, ast.AsyncFunctionDef)) |
667 | | - assert isinstance(func_b, (ast.FunctionDef, ast.AsyncFunctionDef)) |
668 | | - cfg = NormalizationConfig() |
669 | | - fp_a = get_cfg_fingerprint(func_a, cfg, "m:f") |
670 | | - fp_b = get_cfg_fingerprint(func_b, cfg, "m:g") |
671 | | - assert fp_a != fp_b |
672 | | - |
673 | | - |
674 | | -def test_cfg_for_else_affects_fingerprint() -> None: |
675 | | - with_else = """ |
| 685 | + """, |
| 686 | + None, |
| 687 | + ), |
| 688 | + ( |
| 689 | + """ |
676 | 690 | def f(xs): |
677 | 691 | for x in xs: |
678 | 692 | pass |
679 | 693 | else: |
680 | 694 | y = 1 |
681 | | - """ |
682 | | - without_else = """ |
| 695 | + """, |
| 696 | + """ |
683 | 697 | def f(xs): |
684 | 698 | for x in xs: |
685 | 699 | pass |
686 | | - """ |
687 | | - func_with_else = ast.parse(dedent(with_else)).body[0] |
688 | | - func_without_else = ast.parse(dedent(without_else)).body[0] |
689 | | - assert isinstance(func_with_else, (ast.FunctionDef, ast.AsyncFunctionDef)) |
690 | | - assert isinstance(func_without_else, (ast.FunctionDef, ast.AsyncFunctionDef)) |
691 | | - cfg = NormalizationConfig() |
692 | | - fp_with_else = get_cfg_fingerprint(func_with_else, cfg, "m:f") |
693 | | - fp_without_else = get_cfg_fingerprint(func_without_else, cfg, "m:g") |
694 | | - assert fp_with_else != fp_without_else |
695 | | - |
696 | | - |
697 | | -def test_cfg_while_else_affects_fingerprint() -> None: |
698 | | - with_else = """ |
| 700 | + """, |
| 701 | + None, |
| 702 | + ), |
| 703 | + ( |
| 704 | + """ |
699 | 705 | def f(flag): |
700 | 706 | while flag: |
701 | 707 | flag = False |
702 | 708 | else: |
703 | 709 | x = 1 |
704 | | - """ |
705 | | - without_else = """ |
| 710 | + """, |
| 711 | + """ |
706 | 712 | def f(flag): |
707 | 713 | while flag: |
708 | 714 | flag = False |
709 | | - """ |
710 | | - func_with_else = ast.parse(dedent(with_else)).body[0] |
711 | | - func_without_else = ast.parse(dedent(without_else)).body[0] |
712 | | - assert isinstance(func_with_else, (ast.FunctionDef, ast.AsyncFunctionDef)) |
713 | | - assert isinstance(func_without_else, (ast.FunctionDef, ast.AsyncFunctionDef)) |
714 | | - cfg = NormalizationConfig() |
715 | | - fp_with_else = get_cfg_fingerprint(func_with_else, cfg, "m:f") |
716 | | - fp_without_else = get_cfg_fingerprint(func_without_else, cfg, "m:g") |
717 | | - assert fp_with_else != fp_without_else |
718 | | - |
719 | | - |
720 | | -def test_cfg_break_terminates_block() -> None: |
721 | | - source = """ |
722 | | - def f(xs): |
723 | | - for x in xs: |
724 | | - break |
725 | | - y = 1 |
726 | | - """ |
727 | | - cfg = build_cfg_from_source(source) |
728 | | - break_blocks = [ |
729 | | - block |
730 | | - for block in cfg.blocks |
731 | | - if any(isinstance(stmt, ast.Break) for stmt in block.statements) |
732 | | - ] |
733 | | - assert len(break_blocks) == 1 |
734 | | - break_block = break_blocks[0] |
735 | | - assert break_block.is_terminated is True |
736 | | - assert all(not isinstance(stmt, ast.Assign) for stmt in break_block.statements) |
737 | | - |
738 | | - |
739 | | -def test_cfg_continue_terminates_block() -> None: |
740 | | - source = """ |
| 715 | + """, |
| 716 | + None, |
| 717 | + ), |
| 718 | + ], |
| 719 | + ids=[ |
| 720 | + "match_guard", |
| 721 | + "match_case_order", |
| 722 | + "try_handler_order", |
| 723 | + "for_else", |
| 724 | + "while_else", |
| 725 | + ], |
| 726 | +) |
| 727 | +def test_cfg_fingerprint_variants( |
| 728 | + source_a: str, source_b: str, skip_reason: str | None |
| 729 | +) -> None: |
| 730 | + _assert_fingerprint_diff(source_a, source_b, skip_reason=skip_reason) |
| 731 | + |
| 732 | + |
| 733 | +@pytest.mark.parametrize( |
| 734 | + ("keyword", "stmt_type"), |
| 735 | + [("break", ast.Break), ("continue", ast.Continue)], |
| 736 | + ids=["break", "continue"], |
| 737 | +) |
| 738 | +def test_cfg_loop_control_terminates_block( |
| 739 | + keyword: str, stmt_type: type[ast.stmt] |
| 740 | +) -> None: |
| 741 | + source = f""" |
741 | 742 | def f(xs): |
742 | 743 | for x in xs: |
743 | | - continue |
| 744 | + {keyword} |
744 | 745 | y = 1 |
745 | 746 | """ |
746 | 747 | cfg = build_cfg_from_source(source) |
747 | | - continue_blocks = [ |
| 748 | + control_blocks = [ |
748 | 749 | block |
749 | 750 | for block in cfg.blocks |
750 | | - if any(isinstance(stmt, ast.Continue) for stmt in block.statements) |
| 751 | + if any(isinstance(stmt, stmt_type) for stmt in block.statements) |
751 | 752 | ] |
752 | | - assert len(continue_blocks) == 1 |
753 | | - continue_block = continue_blocks[0] |
754 | | - assert continue_block.is_terminated is True |
755 | | - assert all(not isinstance(stmt, ast.Assign) for stmt in continue_block.statements) |
| 753 | + assert len(control_blocks) == 1 |
| 754 | + control_block = control_blocks[0] |
| 755 | + assert control_block.is_terminated is True |
| 756 | + assert all(not isinstance(stmt, ast.Assign) for stmt in control_block.statements) |
756 | 757 |
|
757 | 758 |
|
758 | 759 | def test_cfg_break_skips_for_else_block() -> None: |
@@ -783,42 +784,31 @@ def f(xs): |
783 | 784 | assert else_blocks[0] not in break_blocks[0].successors |
784 | 785 |
|
785 | 786 |
|
786 | | -def test_cfg_while_else_terminated_branch() -> None: |
787 | | - source = """ |
| 787 | +@pytest.mark.parametrize( |
| 788 | + "source", |
| 789 | + [ |
| 790 | + """ |
788 | 791 | def f(flag): |
789 | 792 | while flag: |
790 | 793 | flag = False |
791 | 794 | else: |
792 | 795 | return 1 |
793 | | - """ |
794 | | - cfg = build_cfg_from_source(source) |
795 | | - return_blocks = [ |
796 | | - block |
797 | | - for block in cfg.blocks |
798 | | - if any(isinstance(stmt, ast.Return) for stmt in block.statements) |
799 | | - ] |
800 | | - assert len(return_blocks) == 1 |
801 | | - assert return_blocks[0].is_terminated is True |
802 | | - assert cfg.exit in return_blocks[0].successors |
803 | | - |
804 | | - |
805 | | -def test_cfg_for_else_terminated_branch() -> None: |
806 | | - source = """ |
| 796 | + """, |
| 797 | + """ |
807 | 798 | def f(xs): |
808 | 799 | for x in xs: |
809 | 800 | pass |
810 | 801 | else: |
811 | 802 | return 1 |
812 | | - """ |
| 803 | + """, |
| 804 | + ], |
| 805 | + ids=["while_else", "for_else"], |
| 806 | +) |
| 807 | +def test_cfg_loop_else_terminated_branch(source: str) -> None: |
813 | 808 | cfg = build_cfg_from_source(source) |
814 | | - return_blocks = [ |
815 | | - block |
816 | | - for block in cfg.blocks |
817 | | - if any(isinstance(stmt, ast.Return) for stmt in block.statements) |
818 | | - ] |
819 | | - assert len(return_blocks) == 1 |
820 | | - assert return_blocks[0].is_terminated is True |
821 | | - assert cfg.exit in return_blocks[0].successors |
| 809 | + return_block = _single_return_block(cfg) |
| 810 | + assert return_block.is_terminated is True |
| 811 | + assert cfg.exit in return_block.successors |
822 | 812 |
|
823 | 813 |
|
824 | 814 | def test_cfg_break_outside_loop_falls_back_to_exit() -> None: |
|
0 commit comments