Skip to content

Commit 18d9d9b

Browse files
committed
change migrate-config to use yaml parse tree instead
1 parent 504149d commit 18d9d9b

File tree

4 files changed

+147
-10
lines changed

4 files changed

+147
-10
lines changed

pre_commit/commands/migrate_config.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
from __future__ import annotations
22

3-
import re
3+
import functools
44
import textwrap
5+
from typing import Callable
56

67
import cfgv
78
import yaml
9+
from yaml.nodes import ScalarNode
810

911
from pre_commit.clientlib import InvalidConfigError
12+
from pre_commit.yaml import yaml_compose
1013
from pre_commit.yaml import yaml_load
14+
from pre_commit.yaml_rewrite import MappingKey
15+
from pre_commit.yaml_rewrite import MappingValue
16+
from pre_commit.yaml_rewrite import match
17+
from pre_commit.yaml_rewrite import SequenceItem
1118

1219

1320
def _is_header_line(line: str) -> bool:
@@ -38,16 +45,48 @@ def _migrate_map(contents: str) -> str:
3845
return contents
3946

4047

41-
def _migrate_sha_to_rev(contents: str) -> str:
42-
return re.sub(r'(\n\s+)sha:', r'\1rev:', contents)
48+
def _preserve_style(n: ScalarNode, *, s: str) -> str:
    """Return the replacement text ``s`` wrapped in the quoting of ``n``.

    :param n: the scalar node being replaced; only its ``style`` is read
    :param s: the new (unquoted) text for the scalar
    :returns: ``s`` surrounded on both sides by the node's style character,
        or ``s`` unchanged when the node has no style character
    """
    # PyYAML documents `None` as a possible style for a scalar node;
    # interpolating it directly would emit the literal text "None" around
    # the replacement, so treat any falsy style as "no quoting".
    style = n.style or ''
    return f'{style}{s}{style}'
4350

4451

45-
def _migrate_python_venv(contents: str) -> str:
46-
return re.sub(
47-
r'(\n\s+)language: python_venv\b',
48-
r'\1language: python',
49-
contents,
52+
def _migrate_composed(contents: str) -> str:
    """Apply the node-based migrations to ``contents`` and return the result.

    ``contents`` is composed into a yaml node tree, the scalar nodes which
    need to change are collected together with a callable producing their
    replacement text, and the replacements are spliced into the original
    text by source offset so that all other text is carried over verbatim.
    """
    tree = yaml_compose(contents)

    # sha -> rev: the `sha` *key* of every item under `repos`
    to_rev = functools.partial(_preserve_style, s='rev')
    sha_path = (MappingValue('repos'), SequenceItem(), MappingKey('sha'))
    edits: list[tuple[ScalarNode, Callable[[ScalarNode], str]]] = [
        (node, to_rev) for node in match(tree, sha_path)
    ]

    # python_venv -> python: the `language` *value* of every hook
    language_path = (
        MappingValue('repos'),
        SequenceItem(),
        MappingValue('hooks'),
        SequenceItem(),
        MappingValue('language'),
    )
    to_python = functools.partial(_preserve_style, s='python')
    edits.extend(
        (node, to_python)
        for node in match(tree, language_path)
        if node.value == 'python_venv'
    )

    # process the edits from the end of the text towards the beginning
    edits.sort(key=lambda edit: edit[0].start_mark.index, reverse=True)

    pieces: list[str] = []
    upper: int | None = None  # start of the previously handled node
    for node, render in edits:
        # untouched text between this node and the one handled before it,
        # followed by this node's replacement (pieces are in reverse order)
        pieces += [contents[node.end_mark.index:upper], render(node)]
        upper = node.start_mark.index
    # everything before the earliest edit (the whole text if there are none)
    pieces.append(contents[:upper])
    return ''.join(reversed(pieces))
5190

5291

5392
def migrate_config(config_file: str, quiet: bool = False) -> int:
@@ -62,8 +101,7 @@ def migrate_config(config_file: str, quiet: bool = False) -> int:
62101
raise cfgv.ValidationError(str(e))
63102

64103
contents = _migrate_map(contents)
65-
contents = _migrate_sha_to_rev(contents)
66-
contents = _migrate_python_venv(contents)
104+
contents = _migrate_composed(contents)
67105

68106
if contents != orig_contents:
69107
with open(config_file, 'w') as f:

pre_commit/yaml.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import yaml
77

88
Loader = getattr(yaml, 'CSafeLoader', yaml.SafeLoader)
9+
yaml_compose = functools.partial(yaml.compose, Loader=Loader)
910
yaml_load = functools.partial(yaml.load, Loader=Loader)
1011
Dumper = getattr(yaml, 'CSafeDumper', yaml.SafeDumper)
1112

pre_commit/yaml_rewrite.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Generator
4+
from collections.abc import Iterable
5+
from typing import NamedTuple
6+
from typing import Protocol
7+
8+
from yaml.nodes import MappingNode
9+
from yaml.nodes import Node
10+
from yaml.nodes import ScalarNode
11+
from yaml.nodes import SequenceNode
12+
13+
14+
class _Matcher(Protocol):
    """Structural type for one step of a node path: anything with ``match``."""

    def match(self, n: Node) -> Generator[Node]:
        """Produce the nodes selected by this step when applied to ``n``."""
        ...
16+
17+
18+
class MappingKey(NamedTuple):
    """Select the *key* nodes of a mapping whose value equals ``k``."""

    # the key text to look for
    k: str

    def match(self, n: Node) -> Generator[Node]:
        """Yield every key node of ``n`` equal to ``k``.

        Nothing is yielded when ``n`` is not a mapping node.
        """
        if not isinstance(n, MappingNode):
            return
        for key_node, _ in n.value:
            if key_node.value != self.k:
                continue
            yield key_node
26+
27+
28+
class MappingValue(NamedTuple):
    """Select the *value* nodes of a mapping stored under the key ``k``."""

    # the key text whose value is wanted
    k: str

    def match(self, n: Node) -> Generator[Node]:
        """Yield the value node of every entry of ``n`` whose key is ``k``.

        Nothing is yielded when ``n`` is not a mapping node.
        """
        if not isinstance(n, MappingNode):
            return
        for key_node, value_node in n.value:
            if key_node.value != self.k:
                continue
            yield value_node
36+
37+
38+
class SequenceItem(NamedTuple):
    """Select every item of a sequence node."""

    def match(self, n: Node) -> Generator[Node]:
        """Yield each item of ``n`` in order.

        Nothing is yielded when ``n`` is not a sequence node.
        """
        if not isinstance(n, SequenceNode):
            return
        for item in n.value:
            yield item
42+
43+
44+
def _match(gen: Iterable[Node], m: _Matcher) -> Iterable[Node]:
    """Lazily apply the matcher ``m`` to each node of ``gen``.

    The results of ``m.match`` for each source node are flattened into a
    single stream, in source order.
    """
    # take the iterator up front; the expansion itself happens on demand
    sources = iter(gen)

    def _expand() -> Iterable[Node]:
        for parent in sources:
            for hit in m.match(parent):
                yield hit

    return _expand()
46+
47+
48+
def match(n: Node, matcher: tuple[_Matcher, ...]) -> Generator[ScalarNode]:
    """Follow the path ``matcher`` from ``n`` and return the scalars reached.

    Each element of ``matcher`` is applied in turn to every node selected by
    the previous element, starting from ``n`` itself.  Only the final nodes
    which are scalar nodes are produced; the result is evaluated lazily.
    """
    candidates: Iterable[Node] = (n,)
    for step in matcher:
        candidates = _match(candidates, step)
    return (node for node in candidates if isinstance(node, ScalarNode))

tests/commands/migrate_config_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,27 @@ def test_migrate_config_sha_to_rev(tmpdir):
134134
)
135135

136136

137+
def test_migrate_config_sha_to_rev_json(tmp_path):
138+
contents = """\
139+
{"repos": [{
140+
"repo": "https://github.com/pre-commit/pre-commit-hooks",
141+
"sha": "v1.2.0",
142+
"hooks": []
143+
}]}
144+
"""
145+
expected = """\
146+
{"repos": [{
147+
"repo": "https://github.com/pre-commit/pre-commit-hooks",
148+
"rev": "v1.2.0",
149+
"hooks": []
150+
}]}
151+
"""
152+
cfg = tmp_path.joinpath('cfg.yaml')
153+
cfg.write_text(contents)
154+
assert not migrate_config(str(cfg))
155+
assert cfg.read_text() == expected
156+
157+
137158
def test_migrate_config_language_python_venv(tmp_path):
138159
src = '''\
139160
repos:
@@ -167,6 +188,31 @@ def test_migrate_config_language_python_venv(tmp_path):
167188
assert cfg.read_text() == expected
168189

169190

191+
def test_migrate_config_quoted_python_venv(tmp_path):
192+
src = '''\
193+
repos:
194+
- repo: local
195+
hooks:
196+
- id: example
197+
name: example
198+
entry: example
199+
language: "python_venv"
200+
'''
201+
expected = '''\
202+
repos:
203+
- repo: local
204+
hooks:
205+
- id: example
206+
name: example
207+
entry: example
208+
language: "python"
209+
'''
210+
cfg = tmp_path.joinpath('cfg.yaml')
211+
cfg.write_text(src)
212+
assert migrate_config(str(cfg)) == 0
213+
assert cfg.read_text() == expected
214+
215+
170216
def test_migrate_config_invalid_yaml(tmpdir):
171217
contents = '['
172218
cfg = tmpdir.join(C.CONFIG_FILE)

0 commit comments

Comments
 (0)