Skip to content

Commit 180de12

Browse files
author
Steve Canny
committed
style: add Styles.get_style_id()
1 parent 2d909ca commit 180de12

2 files changed

Lines changed: 62 additions & 2 deletions

File tree

docx/styles/styles.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010

1111
from ..shared import ElementProxy
12-
from .style import StyleFactory
12+
from .style import BaseStyle, StyleFactory
1313

1414

1515
class Styles(ElementProxy):
@@ -67,7 +67,12 @@ def get_style_id(self, style_or_name, style_type):
6767
defined. Raises |ValueError| if the target style is not of
6868
*style_type*.
6969
"""
70-
raise NotImplementedError
70+
if style_or_name is None:
71+
return None
72+
elif isinstance(style_or_name, BaseStyle):
73+
return self._get_style_id_from_style(style_or_name, style_type)
74+
else:
75+
return self._get_style_id_from_name(style_or_name, style_type)
7176

7277
def _get_by_id(self, style_id, style_type):
7378
"""
@@ -80,6 +85,22 @@ def _get_by_id(self, style_id, style_type):
8085
return self.default(style_type)
8186
return StyleFactory(style)
8287

88+
def _get_style_id_from_name(self, style_name, style_type):
89+
"""
90+
Return the id of the style of *style_type* corresponding to
91+
*style_name*. Returns |None| if that style is the default style for
92+
*style_type*. Raises |ValueError| if the named style is not found in
93+
the document or does not match *style_type*.
94+
"""
95+
raise NotImplementedError
96+
97+
def _get_style_id_from_style(self, style, style_type):
98+
"""
99+
Return the id of *style*, or |None| if it is the default style of
100+
*style_type*. Raises |ValueError| if style is not of *style_type*.
101+
"""
102+
raise NotImplementedError
103+
83104
@staticmethod
84105
def _translate_special_case_names(name):
85106
"""

tests/styles/test_styles.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ def it_can_get_a_style_of_type_by_id(self, get_by_id_fixture):
6969
assert styles._get_by_id.call_args_list == _get_by_id_calls
7070
assert style is style_
7171

72+
def it_can_get_a_style_id(self, get_style_id_fixture):
73+
styles, style_or_name, style_type = get_style_id_fixture[:3]
74+
style_calls, name_calls, style_id_ = get_style_id_fixture[3:]
75+
76+
style_id = styles.get_style_id(style_or_name, style_type)
77+
78+
assert styles._get_style_id_from_style.call_args_list == style_calls
79+
assert styles._get_style_id_from_name.call_args_list == name_calls
80+
assert style_id is style_id_
81+
7282
def it_gets_a_style_by_id_to_help(self, _get_by_id_fixture):
7383
styles, style_id, style_type, default_calls = _get_by_id_fixture[:4]
7484
StyleFactory_, StyleFactory_calls, style_ = _get_by_id_fixture[4:]
@@ -116,6 +126,27 @@ def get_by_id_fixture(self, request, default_, _get_by_id_, style_):
116126
style_
117127
)
118128

129+
@pytest.fixture(params=[None, BaseStyle(None), 'Style Name'])
130+
def get_style_id_fixture(self, request, _get_style_id_from_style_,
131+
_get_style_id_from_name_):
132+
style_or_name, style_type = request.param, 1
133+
styles = Styles(None)
134+
style_calls = (
135+
[call(style_or_name, style_type)]
136+
if isinstance(style_or_name, BaseStyle) else []
137+
)
138+
name_calls = (
139+
[call(style_or_name, style_type)]
140+
if style_or_name == 'Style Name' else []
141+
)
142+
style_id_ = None if style_or_name is None else 'StyleName'
143+
_get_style_id_from_style_.return_value = style_id_
144+
_get_style_id_from_name_.return_value = style_id_
145+
return (
146+
styles, style_or_name, style_type, style_calls, name_calls,
147+
style_id_
148+
)
149+
119150
@pytest.fixture(params=[
120151
('w:styles/w:style{w:type=paragraph,w:styleId=Foo}', 'Foo',
121152
WD_STYLE_TYPE.PARAGRAPH),
@@ -205,6 +236,14 @@ def default_(self, request):
205236
def _get_by_id_(self, request):
206237
return method_mock(request, Styles, '_get_by_id')
207238

239+
@pytest.fixture
240+
def _get_style_id_from_name_(self, request):
241+
return method_mock(request, Styles, '_get_style_id_from_name')
242+
243+
@pytest.fixture
244+
def _get_style_id_from_style_(self, request):
245+
return method_mock(request, Styles, '_get_style_id_from_style')
246+
208247
@pytest.fixture
209248
def style_(self, request):
210249
return instance_mock(request, BaseStyle)

0 commit comments

Comments
 (0)