11from __future__ import annotations
22
3- import re
3+ import functools
44import textwrap
5+ from typing import Callable
56
67import cfgv
78import yaml
9+ from yaml .nodes import ScalarNode
810
911from pre_commit .clientlib import InvalidConfigError
12+ from pre_commit .yaml import yaml_compose
1013from pre_commit .yaml import yaml_load
14+ from pre_commit .yaml_rewrite import MappingKey
15+ from pre_commit .yaml_rewrite import MappingValue
16+ from pre_commit .yaml_rewrite import match
17+ from pre_commit .yaml_rewrite import SequenceItem
1118
1219
1320def _is_header_line (line : str ) -> bool :
@@ -38,16 +45,48 @@ def _migrate_map(contents: str) -> str:
3845 return contents
3946
4047
41- def _migrate_sha_to_rev ( contents : str ) -> str :
42- return re . sub ( r'(\n\s+)sha:' , r'\1rev:' , contents )
48+ def _preserve_style ( n : ScalarNode , * , s : str ) -> str :
49+ return f' { n . style } { s } { n . style } '
4350
4451
45- def _migrate_python_venv (contents : str ) -> str :
46- return re .sub (
47- r'(\n\s+)language: python_venv\b' ,
48- r'\1language: python' ,
49- contents ,
52+ def _migrate_composed (contents : str ) -> str :
53+ tree = yaml_compose (contents )
54+ rewrites : list [tuple [ScalarNode , Callable [[ScalarNode ], str ]]] = []
55+
56+ # sha -> rev
57+ sha_to_rev_replace = functools .partial (_preserve_style , s = 'rev' )
58+ sha_to_rev_matcher = (
59+ MappingValue ('repos' ),
60+ SequenceItem (),
61+ MappingKey ('sha' ),
62+ )
63+ for node in match (tree , sha_to_rev_matcher ):
64+ rewrites .append ((node , sha_to_rev_replace ))
65+
66+ # python_venv -> python
67+ language_matcher = (
68+ MappingValue ('repos' ),
69+ SequenceItem (),
70+ MappingValue ('hooks' ),
71+ SequenceItem (),
72+ MappingValue ('language' ),
5073 )
74+ python_venv_replace = functools .partial (_preserve_style , s = 'python' )
75+ for node in match (tree , language_matcher ):
76+ if node .value == 'python_venv' :
77+ rewrites .append ((node , python_venv_replace ))
78+
79+ rewrites .sort (reverse = True , key = lambda nf : nf [0 ].start_mark .index )
80+
81+ src_parts = []
82+ end : int | None = None
83+ for node , func in rewrites :
84+ src_parts .append (contents [node .end_mark .index :end ])
85+ src_parts .append (func (node ))
86+ end = node .start_mark .index
87+ src_parts .append (contents [:end ])
88+ src_parts .reverse ()
89+ return '' .join (src_parts )
5190
5291
5392def migrate_config (config_file : str , quiet : bool = False ) -> int :
@@ -62,8 +101,7 @@ def migrate_config(config_file: str, quiet: bool = False) -> int:
62101 raise cfgv .ValidationError (str (e ))
63102
64103 contents = _migrate_map (contents )
65- contents = _migrate_sha_to_rev (contents )
66- contents = _migrate_python_venv (contents )
104+ contents = _migrate_composed (contents )
67105
68106 if contents != orig_contents :
69107 with open (config_file , 'w' ) as f :
0 commit comments