diff --git a/.pre-commit-hooks.yaml b/.pre-commit-hooks.yaml index 15456a6..754c8ab 100644 --- a/.pre-commit-hooks.yaml +++ b/.pre-commit-hooks.yaml @@ -27,3 +27,11 @@ args: [--author-email] pass_filenames: false language: python +- id: check-no-force-push + name: check no force push + description: prevents force pushes to remote branches + entry: commit-check + args: [--no-force-push] + stages: [pre-push] + pass_filenames: false + language: python diff --git a/README.rst b/README.rst index e93e33e..2328413 100644 --- a/README.rst +++ b/README.rst @@ -187,6 +187,32 @@ For one-off checks or CI/CD pipelines, you can configure via CLI arguments or en See the `Configuration documentation `_ for all available options. +Check Push Safety +~~~~~~~~~~~~~~~~~ + +Use ``--no-force-push`` in a ``pre-push`` hook to inspect the ref updates Git +provides on stdin, or run it directly to compare ``HEAD`` with the current +branch's configured upstream: + +.. code-block:: bash + + # Standalone preflight check against the current branch's upstream + commit-check --no-force-push + +.. code-block:: yaml + + # In pre-commit hooks (.pre-commit-config.yaml) + repos: + - repo: https://github.com/commit-check/commit-check + rev: v2.6.0 + hooks: + - id: check-no-force-push + stages: [pre-push] + +Piping ``git push`` into ``commit-check`` is not a prevention mechanism. The +push has already been started, and standard ``git push`` output does not carry +the pre-push ref metadata that ``commit-check`` uses. + AI-Native Usage --------------- diff --git a/commit_check/__init__.py b/commit_check/__init__.py index 1e16960..aeecd10 100644 --- a/commit_check/__init__.py +++ b/commit_check/__init__.py @@ -44,6 +44,11 @@ # Additional allowed branch names (e.g., develop, staging) DEFAULT_BRANCH_NAMES: list[str] = [] +# Push-related defaults +DEFAULT_PUSH_RULES = { + "allow_force_push": True, +} + # Handle different default values for different rules DEFAULT_BOOLEAN_RULES = { "subject_capitalized": False, diff --git a/commit_check/api.py b/commit_check/api.py index b579b9b..eacf28d 100644 --- a/commit_check/api.py +++ b/commit_check/api.py @@ -225,6 +225,38 @@ def validate_author( return _run_checks(checks, context, cfg) +def validate_push( + push_refs: Optional[str] = None, + *, + config: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Validate that a push is not a force push. + + :param push_refs: Push ref information in the format produced by git's + pre-push hook: `` ``, + one entry per line. If *None*, the check is skipped (returns pass). + :param config: Optional configuration override dict. The push check is + always enabled when calling this function; force pushes detected here + will always return ``"fail"``. + :returns: A dict with ``"status"`` (``"pass"``/``"fail"``) and ``"checks"``. + + Example:: + + >>> from commit_check.api import validate_push + >>> zero = "0000000000000000000000000000000000000000" + >>> result = validate_push(f"refs/heads/main abc123 refs/heads/main {zero}") + >>> result["status"] + 'pass' + """ + cfg = _merge_config(config) + # Enable force push blocking in the config so the rule is built + if "push" not in cfg: + cfg["push"] = {} + cfg["push"]["allow_force_push"] = False + context = ValidationContext(stdin_text=push_refs, config=cfg) + return _run_checks(["no_force_push"], context, cfg) + + def validate_all( message: Optional[str] = None, branch: Optional[str] = None, diff --git a/commit_check/config_merger.py b/commit_check/config_merger.py index 0cbeff1..d8b4abd 100644 --- a/commit_check/config_merger.py +++ b/commit_check/config_merger.py @@ -11,6 +11,7 @@ DEFAULT_BRANCH_TYPES, DEFAULT_BRANCH_NAMES, DEFAULT_BOOLEAN_RULES, + DEFAULT_PUSH_RULES, ) @@ -81,6 +82,9 @@ def get_default_config() -> Dict[str, Any]: "require_rebase_target": "", "ignore_authors": [], }, + "push": { + "allow_force_push": DEFAULT_PUSH_RULES["allow_force_push"], + }, } @@ -119,6 +123,8 @@ class ConfigMerger: "CCHK_ALLOW_BRANCH_NAMES": ("branch", "allow_branch_names", parse_list), "CCHK_REQUIRE_REBASE_TARGET": ("branch", "require_rebase_target", str), "CCHK_BRANCH_IGNORE_AUTHORS": ("branch", "ignore_authors", parse_list), + # Push section + "CCHK_ALLOW_FORCE_PUSH": ("push", "allow_force_push", parse_bool), } # Mapping of CLI argument names to config keys @@ -144,12 +150,14 @@ class ConfigMerger: "allow_branch_names": ("branch", "allow_branch_names"), "require_rebase_target": ("branch", "require_rebase_target"), "branch_ignore_authors": ("branch", "ignore_authors"), + # Push section + "allow_force_push": ("push", "allow_force_push"), } @staticmethod def parse_env_vars() -> Dict[str, Any]: """Parse environment variables with CCHK_ prefix into config dict.""" - config: Dict[str, Any] = {"commit": {}, "branch": {}} + config: Dict[str, Any] = {"commit": {}, "branch": {}, "push": {}} for env_var, (section, key, parser) in ConfigMerger.ENV_VAR_MAPPING.items(): value = os.environ.get(env_var) @@ -168,7 +176,7 @@ def parse_env_vars() -> Dict[str, Any]: @staticmethod def parse_cli_args(args: argparse.Namespace) -> Dict[str, Any]: """Parse CLI arguments into config dict.""" - config: Dict[str, Any] = {"commit": {}, "branch": {}} + config: Dict[str, Any] = {"commit": {}, "branch": {}, "push": {}} for arg_name, (section, key) in ConfigMerger.CLI_ARG_MAPPING.items(): if hasattr(args, arg_name): diff --git a/commit_check/engine.py b/commit_check/engine.py index c6f4aaa..361d330 100644 --- a/commit_check/engine.py +++ b/commit_check/engine.py @@ -8,9 +8,12 @@ from commit_check.rule_builder import ValidationRule from commit_check.util import ( + fetch_upstream_ref, get_commit_info, get_git_config_value, get_branch_name, + get_upstream_branch, + get_upstream_remote_sha, has_commits, git_merge_base, ) @@ -31,6 +34,7 @@ class ValidationContext: stdin_text: Optional[str] = None commit_file: Optional[str] = None config: Dict = field(default_factory=dict) + push_upstream_fallback: bool = False @dataclass @@ -601,6 +605,85 @@ def _get_commit_message(self, context: ValidationContext) -> str: return f"{subject}\n\n{body}".strip() +class ForcePushValidator(BaseValidator): + """Validates that no force push is being performed. + + Reads pushed ref information from stdin (provided by git's pre-push hook) + in the format:: + + + + A force push is detected when the remote SHA is not an ancestor of the + local SHA, meaning local history would overwrite the remote. + """ + + ZERO_SHA = "0000000000000000000000000000000000000000" + + def validate(self, context: ValidationContext) -> ValidationResult: + if not context.stdin_text: + if context.push_upstream_fallback: + return self._check_current_branch_against_upstream() + return ValidationResult.PASS + + for line in context.stdin_text.splitlines(): + result = self._check_push_line(line.strip()) + if result == ValidationResult.FAIL: + return ValidationResult.FAIL + + return ValidationResult.PASS + + def _check_current_branch_against_upstream(self) -> ValidationResult: + """Check whether pushing HEAD to its upstream would require force.""" + upstream_ref = get_upstream_branch() + if not upstream_ref: + return ValidationResult.PASS + + target_ref = get_upstream_remote_sha(upstream_ref) or upstream_ref + returncode = git_merge_base(target_ref, "HEAD") + if ( + returncode == 128 + and target_ref != upstream_ref + and fetch_upstream_ref(upstream_ref) + ): + returncode = git_merge_base(target_ref, "HEAD") + if returncode == 1: + self._print_failure(f"{get_branch_name()} -> {upstream_ref}") + return ValidationResult.FAIL + + return ValidationResult.PASS + + def _check_push_line(self, line: str) -> ValidationResult: + """Check a single pushed ref line for force push.""" + if not line: + return ValidationResult.PASS + + parts = line.split() + if len(parts) < 4: + return ValidationResult.PASS + + local_ref, local_sha, remote_ref, remote_sha = ( + parts[0], + parts[1], + parts[2], + parts[3], + ) + + # Zero SHA for remote means a new branch push (not a force push) + if remote_sha == self.ZERO_SHA: + return ValidationResult.PASS + + # Check if the remote SHA is an ancestor of the local SHA. + # returncode 0 → remote is ancestor of local (fast-forward push, OK) + # returncode 1 → not an ancestor (force push detected) + # returncode 128 → git error / SHA unknown (cannot determine; allow) + returncode = git_merge_base(remote_sha, local_sha) + if returncode == 1: + self._print_failure(f"{local_ref} -> {remote_ref}") + return ValidationResult.FAIL + + return ValidationResult.PASS + + class ValidationEngine: """Main validation engine that orchestrates all validations.""" @@ -622,6 +705,7 @@ class ValidationEngine: "allow_fixup_commits": CommitTypeValidator, "allow_wip_commits": CommitTypeValidator, "ignore_authors": CommitTypeValidator, + "no_force_push": ForcePushValidator, } def __init__(self, rules: List[ValidationRule]): diff --git a/commit_check/main.py b/commit_check/main.py index a030996..5b32772 100644 --- a/commit_check/main.py +++ b/commit_check/main.py @@ -104,6 +104,14 @@ def _get_parser() -> argparse.ArgumentParser: required=False, ) + check_group.add_argument( + "-p", + "--no-force-push", + help="check that no force push is being performed (uses pre-push hook stdin when available, otherwise checks the current branch against its upstream)", + action="store_true", + required=False, + ) + check_group.add_argument( "--format", choices=["text", "json"], @@ -328,6 +336,11 @@ def main() -> int: # Load and merge configuration from all sources: CLI > Env > TOML > Defaults config_data = ConfigMerger.from_all_sources(args, args.config) + # When --no-force-push is explicitly passed, override allow_force_push to + # False so the rule is built even if the TOML config defaults to True. + if args.no_force_push: + config_data.setdefault("push", {})["allow_force_push"] = False + # Build validation rules from config rule_builder = RuleBuilder(config_data) all_rules = rule_builder.build_all_rules() @@ -366,6 +379,8 @@ def main() -> int: requested_checks.append("author_name") if args.author_email: requested_checks.append("author_email") + if args.no_force_push: + requested_checks.append("no_force_push") # If no specific checks requested, show help if not requested_checks: @@ -392,17 +407,20 @@ def main() -> int: if not stdin_content: # No stdin and no file - let validators get data from git themselves stdin_content = None - elif not any([args.branch, args.author_name, args.author_email]): + elif not any( + [args.branch, args.author_name, args.author_email, args.no_force_push] + ): # If no specific validation type is requested, don't read stdin pass else: - # For non-message validations (branch, author), check for stdin input + # For non-message validations (branch, author, push), check for stdin input stdin_content = stdin_reader.read_piped_input() context = ValidationContext( stdin_text=stdin_content, commit_file=commit_file_path, config=config_data, + push_upstream_fallback=args.no_force_push and stdin_content is None, ) # Run validation – choose output mode based on --format diff --git a/commit_check/rule_builder.py b/commit_check/rule_builder.py index 8850517..73603df 100644 --- a/commit_check/rule_builder.py +++ b/commit_check/rule_builder.py @@ -2,12 +2,18 @@ from typing import Dict, Any, List, Optional from dataclasses import dataclass -from commit_check.rules_catalog import COMMIT_RULES, BRANCH_RULES, RuleCatalogEntry +from commit_check.rules_catalog import ( + COMMIT_RULES, + BRANCH_RULES, + PUSH_RULES, + RuleCatalogEntry, +) from commit_check import ( DEFAULT_COMMIT_TYPES, DEFAULT_BRANCH_TYPES, DEFAULT_BRANCH_NAMES, DEFAULT_BOOLEAN_RULES, + DEFAULT_PUSH_RULES, ) @@ -48,12 +54,14 @@ def __init__(self, config: Dict[str, Any]): self.config = config self.commit_config = config.get("commit", {}) self.branch_config = config.get("branch", {}) + self.push_config = config.get("push", {}) def build_all_rules(self) -> List[ValidationRule]: """Build all validation rules from config.""" rules = [] rules.extend(self._build_commit_rules()) rules.extend(self._build_branch_rules()) + rules.extend(self._build_push_rules()) return rules def _build_commit_rules(self) -> List[ValidationRule]: @@ -78,6 +86,41 @@ def _build_branch_rules(self) -> List[ValidationRule]: return rules + def _build_push_rules(self) -> List[ValidationRule]: + """Build push-related validation rules.""" + rules = [] + + for catalog_entry in PUSH_RULES: + rule = self._build_push_rule(catalog_entry) + if rule: + rules.append(rule) + + return rules + + def _build_push_rule( + self, catalog_entry: RuleCatalogEntry + ) -> Optional[ValidationRule]: + """Build a single push validation rule from catalog entry and config.""" + check = catalog_entry.check + + if check == "no_force_push": + allow = self.push_config.get( + "allow_force_push", DEFAULT_PUSH_RULES["allow_force_push"] + ) + # When allow_force_push is True (default), force pushes are permitted + # so no blocking rule is needed. Only build the rule when it is + # False, i.e. the user has explicitly opted in to blocking. + if allow: + return None + return ValidationRule( + check=catalog_entry.check, + error=catalog_entry.error, + suggest=catalog_entry.suggest, + value=False, + ) + + return None + def _build_single_rule( self, catalog_entry: RuleCatalogEntry, section_config: Dict[str, Any] ) -> Optional[ValidationRule]: diff --git a/commit_check/rules_catalog.py b/commit_check/rules_catalog.py index 1c80682..af8eb79 100644 --- a/commit_check/rules_catalog.py +++ b/commit_check/rules_catalog.py @@ -106,6 +106,16 @@ class RuleCatalogEntry: ), ] +# Push rules +PUSH_RULES = [ + RuleCatalogEntry( + check="no_force_push", + regex=None, + error="Force push is not allowed", + suggest="Use a normal push instead of --force or --force-with-lease", + ), +] + # Branch rules BRANCH_RULES = [ RuleCatalogEntry( diff --git a/commit_check/util.py b/commit_check/util.py index 3363494..136d4e8 100644 --- a/commit_check/util.py +++ b/commit_check/util.py @@ -65,6 +65,54 @@ def get_branch_name() -> str: return branch_name.strip() +def get_upstream_branch() -> str: + """Return the configured upstream ref for the current branch.""" + result = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{upstream}"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + if result.returncode == 0 and result.stdout: + return result.stdout.strip() + return "" + + +def get_upstream_remote_sha(upstream_ref: str) -> str: + """Return the current remote SHA for an upstream ref when available.""" + parts = upstream_ref.split("/", 1) + if len(parts) != 2: + return "" + + remote_name, branch_name = parts + result = subprocess.run( + ["git", "ls-remote", "--exit-code", remote_name, f"refs/heads/{branch_name}"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + if result.returncode != 0 or not result.stdout: + return "" + + return result.stdout.split()[0].strip() + + +def fetch_upstream_ref(upstream_ref: str) -> bool: + """Fetch an upstream branch so its tip commit is available locally.""" + parts = upstream_ref.split("/", 1) + if len(parts) != 2: + return False + + remote_name, branch_name = parts + result = subprocess.run( + ["git", "fetch", "--quiet", "--no-tags", remote_name, branch_name], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + return result.returncode == 0 + + def has_commits() -> bool: """Check if there are any commits in the current branch. :returns: `True` if there are commits, `False` otherwise. diff --git a/docs/example.rst b/docs/example.rst index a064199..91b3bd6 100644 --- a/docs/example.rst +++ b/docs/example.rst @@ -153,6 +153,27 @@ Branch Validation Examples # - hotfix/security-patch # - release/v1.2.0 +Push Validation Examples +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + # Check whether pushing HEAD to its configured upstream would require force + commit-check --no-force-push + +.. code-block:: yaml + + # Configure the dedicated pre-push hook + - repo: https://github.com/commit-check/commit-check + rev: the tag or revision + hooks: + - id: check-no-force-push + stages: [pre-push] + +``git push | commit-check --no-force-push`` is not a prevention mechanism. The +push has already started, and normal ``git push`` output does not include the +pre-push ref lines that Git provides to hooks. + Author Validation Examples ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/engine_test.py b/tests/engine_test.py index 9e6ce21..b4e0423 100644 --- a/tests/engine_test.py +++ b/tests/engine_test.py @@ -36,10 +36,13 @@ class TestValidationContext: def test_validation_context_creation(self): """Test ValidationContext creation and properties.""" context = ValidationContext( - stdin_text="test message", commit_file="/path/to/commit" + stdin_text="test message", + commit_file="/path/to/commit", + push_upstream_fallback=True, ) assert context.stdin_text == "test message" assert context.commit_file == "/path/to/commit" + assert context.push_upstream_fallback is True @pytest.mark.benchmark def test_validation_context_defaults(self): @@ -47,6 +50,7 @@ def test_validation_context_defaults(self): context = ValidationContext() assert context.stdin_text is None assert context.commit_file is None + assert context.push_upstream_fallback is False class TestBaseValidator: @@ -1102,3 +1106,259 @@ def test_author_email_uses_git_config_when_available(self): ): result = validator.validate(context) assert result == ValidationResult.PASS + + +class TestForcePushValidator: + """Tests for the ForcePushValidator class.""" + + def _make_rule(self): + return ValidationRule( + check="no_force_push", + error="Force push is not allowed", + suggest="Use a normal push instead of --force or --force-with-lease", + value=False, + ) + + @pytest.mark.benchmark + def test_no_stdin_skips_validation(self): + """Validator passes when no stdin is provided (not a pre-push context).""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + context = ValidationContext() # stdin_text=None + + result = validator.validate(context) + assert result == ValidationResult.PASS + + @pytest.mark.benchmark + def test_no_stdin_with_upstream_fallback_passes_without_upstream(self): + """Standalone mode passes when the current branch has no upstream.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + context = ValidationContext(push_upstream_fallback=True) + + with patch("commit_check.engine.get_upstream_branch", return_value=""): + result = validator.validate(context) + + assert result == ValidationResult.PASS + + @pytest.mark.benchmark + def test_no_stdin_with_upstream_fallback_passes_fast_forward(self): + """Standalone mode passes when upstream is an ancestor of HEAD.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + context = ValidationContext(push_upstream_fallback=True) + + with patch( + "commit_check.engine.get_upstream_branch", return_value="origin/main" + ): + with patch( + "commit_check.engine.get_upstream_remote_sha", return_value="abc123" + ): + with patch("commit_check.engine.git_merge_base", return_value=0): + result = validator.validate(context) + + assert result == ValidationResult.PASS + + @pytest.mark.benchmark + def test_no_stdin_with_upstream_fallback_uses_tracking_ref_when_remote_sha_missing( + self, + ): + """Standalone mode falls back to the local tracking ref if lookup fails.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + context = ValidationContext(push_upstream_fallback=True) + + with patch( + "commit_check.engine.get_upstream_branch", return_value="origin/main" + ): + with patch("commit_check.engine.get_upstream_remote_sha", return_value=""): + with patch( + "commit_check.engine.git_merge_base", return_value=0 + ) as mock_merge: + result = validator.validate(context) + + mock_merge.assert_called_once_with("origin/main", "HEAD") + assert result == ValidationResult.PASS + + @pytest.mark.benchmark + def test_no_stdin_with_upstream_fallback_blocks_force_push(self): + """Standalone mode fails when pushing HEAD to upstream requires force.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + context = ValidationContext(push_upstream_fallback=True) + + with patch( + "commit_check.engine.get_upstream_branch", return_value="origin/main" + ): + with patch( + "commit_check.engine.get_upstream_remote_sha", return_value="deadbeef" + ): + with patch("commit_check.engine.get_branch_name", return_value="main"): + with patch( + "commit_check.engine.git_merge_base", return_value=1 + ) as mock_merge: + with patch("commit_check.util._print_failure"): + result = validator.validate(context) + + mock_merge.assert_called_once_with("deadbeef", "HEAD") + assert result == ValidationResult.FAIL + + @pytest.mark.benchmark + def test_no_stdin_with_upstream_fallback_fetches_remote_commit_when_needed(self): + """Standalone mode fetches the upstream commit if the SHA is not local yet.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + context = ValidationContext(push_upstream_fallback=True) + + with patch( + "commit_check.engine.get_upstream_branch", return_value="origin/main" + ): + with patch( + "commit_check.engine.get_upstream_remote_sha", return_value="deadbeef" + ): + with patch("commit_check.engine.get_branch_name", return_value="main"): + with patch( + "commit_check.engine.git_merge_base", side_effect=[128, 1] + ) as mock_merge: + with patch( + "commit_check.engine.fetch_upstream_ref", return_value=True + ) as mock_fetch: + with patch("commit_check.util._print_failure"): + result = validator.validate(context) + + mock_fetch.assert_called_once_with("origin/main") + assert mock_merge.call_count == 2 + assert result == ValidationResult.FAIL + + @pytest.mark.benchmark + def test_new_branch_push_is_allowed(self): + """A push to a new (non-existent) remote branch is not a force push.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + zero_sha = "0000000000000000000000000000000000000000" + push_info = f"refs/heads/feature/new abc123 refs/heads/feature/new {zero_sha}" + context = ValidationContext(stdin_text=push_info) + + result = validator.validate(context) + assert result == ValidationResult.PASS + + @pytest.mark.benchmark + def test_fast_forward_push_is_allowed(self): + """A normal fast-forward push (remote is ancestor of local) is allowed.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + push_info = "refs/heads/main abc123 refs/heads/main def456" + context = ValidationContext(stdin_text=push_info) + + with patch("commit_check.engine.git_merge_base", return_value=0): + result = validator.validate(context) + + assert result == ValidationResult.PASS + + @pytest.mark.benchmark + def test_force_push_is_blocked(self): + """A force push (remote is NOT ancestor of local) is blocked.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + push_info = "refs/heads/main abc123 refs/heads/main def456" + context = ValidationContext(stdin_text=push_info) + + with patch("commit_check.engine.git_merge_base", return_value=1): + with patch("commit_check.util._print_failure"): + result = validator.validate(context) + + assert result == ValidationResult.FAIL + + @pytest.mark.benchmark + def test_git_error_allows_push(self): + """When git cannot determine ancestry (exit 128), push is allowed.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + push_info = "refs/heads/main abc123 refs/heads/main def456" + context = ValidationContext(stdin_text=push_info) + + with patch("commit_check.engine.git_merge_base", return_value=128): + result = validator.validate(context) + + assert result == ValidationResult.PASS + + @pytest.mark.benchmark + def test_empty_lines_in_stdin_are_skipped(self): + """Empty lines in push info do not cause errors.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + # stdin has blank lines mixed in + push_info = "\n\nrefs/heads/main abc123 refs/heads/main def456\n\n" + context = ValidationContext(stdin_text=push_info) + + with patch("commit_check.engine.git_merge_base", return_value=0): + result = validator.validate(context) + + assert result == ValidationResult.PASS + + @pytest.mark.benchmark + def test_malformed_push_line_is_skipped(self): + """Lines that do not have 4 fields are silently skipped.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + push_info = "only two fields" + context = ValidationContext(stdin_text=push_info) + + result = validator.validate(context) + assert result == ValidationResult.PASS + + @pytest.mark.benchmark + def test_multiple_refs_one_force_push_blocks(self): + """If any pushed ref is a force push, the whole check fails.""" + from commit_check.engine import ForcePushValidator + + rule = self._make_rule() + validator = ForcePushValidator(rule) + zero_sha = "0000000000000000000000000000000000000000" + push_info = ( + f"refs/heads/feature/ok abc1 refs/heads/feature/ok {zero_sha}\n" + "refs/heads/main abc2 refs/heads/main def2" + ) + context = ValidationContext(stdin_text=push_info) + + # First line is a new branch (allowed); second is a force push + def side_effect(remote_sha, local_sha): + return 1 # force push + + with patch("commit_check.engine.git_merge_base", side_effect=side_effect): + with patch("commit_check.util._print_failure"): + result = validator.validate(context) + + assert result == ValidationResult.FAIL + + @pytest.mark.benchmark + def test_validation_engine_includes_force_push_validator(self): + """ValidationEngine maps 'no_force_push' to ForcePushValidator.""" + from commit_check.engine import ForcePushValidator, ValidationEngine + + assert "no_force_push" in ValidationEngine.VALIDATOR_MAP + assert ValidationEngine.VALIDATOR_MAP["no_force_push"] is ForcePushValidator diff --git a/tests/main_test.py b/tests/main_test.py index aa16443..f90bad1 100644 --- a/tests/main_test.py +++ b/tests/main_test.py @@ -661,3 +661,87 @@ def test_json_format_exit_code_matches_status(self, mocker, capsys): out, _ = capsys.readouterr() assert rc_fail == 1 assert json.loads(out)["status"] == "fail" + + +class TestNoForcePushFlag: + """Tests for the --no-force-push CLI flag.""" + + ZERO_SHA = "0000000000000000000000000000000000000000" + + @pytest.mark.benchmark + def test_no_force_push_new_branch_passes(self, mocker): + """Push to a new remote branch (zero SHA) always passes.""" + push_info = ( + f"refs/heads/feature/new abc123 refs/heads/feature/new {self.ZERO_SHA}" + ) + mocker.patch("sys.stdin.isatty", return_value=False) + mocker.patch("sys.stdin.read", return_value=push_info) + + sys.argv = [CMD, "--no-force-push"] + assert main() == 0 + + @pytest.mark.benchmark + def test_no_force_push_fast_forward_passes(self, mocker): + """Fast-forward push (remote is ancestor of local) passes.""" + push_info = "refs/heads/main abc123 refs/heads/main def456" + mocker.patch("sys.stdin.isatty", return_value=False) + mocker.patch("sys.stdin.read", return_value=push_info) + mocker.patch("commit_check.engine.git_merge_base", return_value=0) + + sys.argv = [CMD, "--no-force-push"] + assert main() == 0 + + @pytest.mark.benchmark + def test_no_force_push_force_push_fails(self, mocker): + """Force push (remote is not ancestor of local) fails.""" + push_info = "refs/heads/main abc123 refs/heads/main def456" + mocker.patch("sys.stdin.isatty", return_value=False) + mocker.patch("sys.stdin.read", return_value=push_info) + mocker.patch("commit_check.engine.git_merge_base", return_value=1) + + sys.argv = [CMD, "--no-force-push"] + assert main() == 1 + + @pytest.mark.benchmark + def test_no_force_push_no_stdin_passes(self, mocker): + """When no stdin and no upstream are available, the check is skipped.""" + mocker.patch("sys.stdin.isatty", return_value=True) + mocker.patch("commit_check.engine.get_upstream_branch", return_value="") + + sys.argv = [CMD, "--no-force-push"] + assert main() == 0 + + @pytest.mark.benchmark + def test_no_force_push_no_stdin_uses_upstream_fallback(self, mocker): + """Without stdin, the CLI falls back to checking the current upstream.""" + mocker.patch("sys.stdin.isatty", return_value=True) + mocker.patch( + "commit_check.engine.get_upstream_branch", return_value="origin/main" + ) + mocker.patch("commit_check.engine.git_merge_base", return_value=0) + + sys.argv = [CMD, "--no-force-push"] + assert main() == 0 + + @pytest.mark.benchmark + def test_no_force_push_no_stdin_blocks_non_fast_forward_upstream(self, mocker): + """Without stdin, a non-fast-forward upstream relationship fails.""" + mocker.patch("sys.stdin.isatty", return_value=True) + mocker.patch( + "commit_check.engine.get_upstream_branch", return_value="origin/main" + ) + mocker.patch("commit_check.engine.get_branch_name", return_value="main") + mocker.patch("commit_check.engine.git_merge_base", return_value=1) + + sys.argv = [CMD, "--no-force-push"] + assert main() == 1 + + @pytest.mark.benchmark + def test_no_force_push_flag_in_help(self, capfd): + """The --no-force-push flag appears in help output.""" + sys.argv = [CMD, "--help"] + with pytest.raises(SystemExit): + main() + out, _ = capfd.readouterr() + assert "--no-force-push" in out + assert "current branch against its upstream" in out diff --git a/tests/rule_builder_test.py b/tests/rule_builder_test.py index 460e471..5886bee 100644 --- a/tests/rule_builder_test.py +++ b/tests/rule_builder_test.py @@ -226,3 +226,35 @@ def test_rule_builder_allow_branch_names_with_duplicates(self): allowed_names = builder._get_allowed_branch_names() # Should deduplicate while preserving order assert allowed_names == ["develop", "staging"] + + +class TestPushRuleBuilder: + """Tests for push rule building.""" + + @pytest.mark.benchmark + def test_push_rule_not_built_when_force_push_allowed(self): + """No rule is built when allow_force_push is True (default).""" + config = {"push": {"allow_force_push": True}} + builder = RuleBuilder(config) + rules = builder.build_all_rules() + push_rules = [r for r in rules if r.check == "no_force_push"] + assert len(push_rules) == 0 + + @pytest.mark.benchmark + def test_push_rule_built_when_force_push_disabled(self): + """A rule is built when allow_force_push is False.""" + config = {"push": {"allow_force_push": False}} + builder = RuleBuilder(config) + rules = builder.build_all_rules() + push_rules = [r for r in rules if r.check == "no_force_push"] + assert len(push_rules) == 1 + assert push_rules[0].error == "Force push is not allowed" + assert push_rules[0].suggest is not None + + @pytest.mark.benchmark + def test_push_rule_not_built_by_default(self): + """No rule is built with empty config (default: allow force push).""" + builder = RuleBuilder({}) + rules = builder.build_all_rules() + push_rules = [r for r in rules if r.check == "no_force_push"] + assert len(push_rules) == 0 diff --git a/tests/util_test.py b/tests/util_test.py index cfa4201..5aab961 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -4,7 +4,10 @@ import os from pathlib import Path, PurePath from commit_check.util import ( + fetch_upstream_ref, get_branch_name, + get_upstream_branch, + get_upstream_remote_sha, has_commits, git_merge_base, get_commit_info, @@ -129,6 +132,150 @@ def test_has_commits_false(self, mocker): } assert retval is False + class TestGetUpstreamBranch: + @pytest.mark.benchmark + def test_get_upstream_branch(self, mocker): + mock_run = mocker.patch( + "subprocess.run", + return_value=type( + "MockResult", + (), + {"stdout": "origin/main\n", "stderr": "", "returncode": 0}, + )(), + ) + + result = get_upstream_branch() + + mock_run.assert_called_once_with( + [ + "git", + "rev-parse", + "--abbrev-ref", + "--symbolic-full-name", + "@{upstream}", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + assert result == "origin/main" + + @pytest.mark.benchmark + def test_get_upstream_branch_missing(self, mocker): + mocker.patch( + "subprocess.run", + return_value=type( + "MockResult", + (), + {"stdout": "", "stderr": "fatal: no upstream", "returncode": 128}, + )(), + ) + + assert get_upstream_branch() == "" + + class TestGetUpstreamRemoteSha: + @pytest.mark.benchmark + def test_get_upstream_remote_sha(self, mocker): + mock_run = mocker.patch( + "subprocess.run", + return_value=type( + "MockResult", + (), + { + "stdout": "abc123\trefs/heads/main\n", + "stderr": "", + "returncode": 0, + }, + )(), + ) + + result = get_upstream_remote_sha("origin/main") + + mock_run.assert_called_once_with( + ["git", "ls-remote", "--exit-code", "origin", "refs/heads/main"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + assert result == "abc123" + + @pytest.mark.benchmark + def test_get_upstream_remote_sha_with_nested_branch(self, mocker): + mocker.patch( + "subprocess.run", + return_value=type( + "MockResult", + (), + { + "stdout": "def456\trefs/heads/feature/topic\n", + "stderr": "", + "returncode": 0, + }, + )(), + ) + + assert get_upstream_remote_sha("origin/feature/topic") == "def456" + + @pytest.mark.benchmark + def test_get_upstream_remote_sha_missing(self, mocker): + mocker.patch( + "subprocess.run", + return_value=type( + "MockResult", + (), + {"stdout": "", "stderr": "fatal", "returncode": 2}, + )(), + ) + + assert get_upstream_remote_sha("origin/main") == "" + + @pytest.mark.benchmark + def test_get_upstream_remote_sha_invalid_ref(self, mocker): + mock_run = mocker.patch("subprocess.run") + + assert get_upstream_remote_sha("main") == "" + mock_run.assert_not_called() + + class TestFetchUpstreamRef: + @pytest.mark.benchmark + def test_fetch_upstream_ref(self, mocker): + mock_run = mocker.patch( + "subprocess.run", + return_value=type( + "MockResult", + (), + {"stdout": "", "stderr": "", "returncode": 0}, + )(), + ) + + assert fetch_upstream_ref("origin/main") is True + mock_run.assert_called_once_with( + ["git", "fetch", "--quiet", "--no-tags", "origin", "main"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + + @pytest.mark.benchmark + def test_fetch_upstream_ref_failure(self, mocker): + mocker.patch( + "subprocess.run", + return_value=type( + "MockResult", + (), + {"stdout": "", "stderr": "fatal", "returncode": 1}, + )(), + ) + + assert fetch_upstream_ref("origin/main") is False + + @pytest.mark.benchmark + def test_fetch_upstream_ref_invalid_ref(self, mocker): + mock_run = mocker.patch("subprocess.run") + + assert fetch_upstream_ref("main") is False + mock_run.assert_not_called() + class TestGitMergeBase: @pytest.mark.benchmark @pytest.mark.parametrize(