Skip to content

Commit 6c1c939

Browse files
authored
fix: add support for formatting broken argspec entities (#420)
* fix: add support for formatting broken argspec entities * test: add unit test * test: update unit test
1 parent 9c4c29c commit 6c1c939

2 files changed

Lines changed: 56 additions & 6 deletions

File tree

packages/gcp-sphinx-docfx-yaml/docfx_yaml/extension.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,34 @@ def is_valid_python_code(syntax: str) -> bool:
12801280
return False
12811281
return True
12821282

1283+
def _reformat_pattern(code: str, pattern: str) -> str:
1284+
"""Reformats the code for patterns found to remove for code formatting."""
1285+
# Patterns like retry=<google.api_core.retry.retry_unary.Retry object>
1286+
# need to be handled separately.
1287+
if "object" in pattern:
1288+
end_tag = " object>"
1289+
pattern_to_find = "<"
1290+
else:
1291+
end_tag = "\'>"
1292+
pattern_to_find = pattern
1293+
1294+
while pattern in code:
1295+
pattern_begin = code.find(pattern_to_find)
1296+
end_tag_index = code.find(end_tag)
1297+
# Check that the format is valid.
1298+
if (pattern_begin == -1 or end_tag_index == -1):
1299+
print(f"Could not reformat pattern: {pattern} for code: {code}.")
1300+
return code
1301+
1302+
pattern_end = pattern_begin + len(pattern_to_find)
1303+
code = ''.join([
1304+
code[:pattern_begin],
1305+
code[pattern_end:end_tag_index],
1306+
code[end_tag_index+len(end_tag):],
1307+
])
1308+
return code
1309+
1310+
12831311
def format_code(code: str) -> str:
12841312
"""Reformats code using black.format_str().
12851313
@@ -1292,6 +1320,15 @@ def format_code(code: str) -> str:
12921320
Formatted code with `black.format_str()`. May not format if there is
12931321
an error.
12941322
"""
1323+
patterns_to_reformat = (
1324+
"<class \'",
1325+
" object>",
1326+
)
1327+
for pattern in patterns_to_reformat:
1328+
if pattern not in code:
1329+
continue
1330+
code = _reformat_pattern(code, pattern)
1331+
12951332
# Signature code comes in raw text without formatting, to run black it
12961333
# requires the code to look like actual function declaration in code.
12971334
# Returns the original formatted code without the added bits.

packages/gcp-sphinx-docfx-yaml/tests/test_helpers.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,26 @@ def test_search_cross_references(self):
227227

228228
self.assertEqual(yaml_pre, yaml_post)
229229

230-
231-
def test_format_code(self):
230+
test_code_params = [
231+
[
232+
# Test formatting regular Python code.
233+
"batch_predict(*, gcs_source: Optional[Union[str, Sequence[str]]] = None, instances_format: str = \"jsonl\", gcs_destination_prefix: Optional[str] = None, predictions_format: str = \"jsonl\", model_parameters: Optional[Dict] = None, machine_type: Optional[str] = None, accelerator_type: Optional[str] = None, explanation_parameters: Optional[google.cloud.aiplatform_v1.types.explanation.ExplanationParameters] = None, labels: Optional[Dict[str, str]] = None, sync: bool = True,)",
234+
"batch_predict(\n *,\n gcs_source: Optional[Union[str, Sequence[str]]] = None,\n instances_format: str = \"jsonl\",\n gcs_destination_prefix: Optional[str] = None,\n predictions_format: str = \"jsonl\",\n model_parameters: Optional[Dict] = None,\n machine_type: Optional[str] = None,\n accelerator_type: Optional[str] = None,\n explanation_parameters: Optional[\n google.cloud.aiplatform_v1.types.explanation.ExplanationParameters\n ] = None,\n labels: Optional[Dict[str, str]] = None,\n sync: bool = True,\n)",
235+
],
236+
[
237+
# Test formatting code with <class ...> content.
238+
"TableAsync(client: google.cloud.bigtable.data._async.client.BigtableDataClientAsync, instance_id: str, table_id: str, app_profile_id: typing.Optional[str] = None, *, default_read_rows_operation_timeout: float = 600, default_read_rows_attempt_timeout: float | None = 20, default_mutate_rows_operation_timeout: float = 600, default_mutate_rows_attempt_timeout: float | None = 60, default_operation_timeout: float = 60, default_attempt_timeout: float | None = 20, default_read_rows_retryable_errors: typing.Sequence[type[Exception]] = (<class 'google.api_core.exceptions.DeadlineExceeded'>, <class 'google.api_core.exceptions.ServiceUnavailable'>, <class 'google.api_core.exceptions.Aborted'>), default_mutate_rows_retryable_errors: typing.Sequence[type[Exception]] = (<class 'google.api_core.exceptions.DeadlineExceeded'>, <class 'google.api_core.exceptions.ServiceUnavailable'>), default_retryable_errors: typing.Sequence[type[Exception]] = (<class 'google.api_core.exceptions.DeadlineExceeded'>, <class 'google.api_core.exceptions.ServiceUnavailable'>))",
239+
"TableAsync(\n client: google.cloud.bigtable.data._async.client.BigtableDataClientAsync,\n instance_id: str,\n table_id: str,\n app_profile_id: typing.Optional[str] = None,\n *,\n default_read_rows_operation_timeout: float = 600,\n default_read_rows_attempt_timeout: float | None = 20,\n default_mutate_rows_operation_timeout: float = 600,\n default_mutate_rows_attempt_timeout: float | None = 60,\n default_operation_timeout: float = 60,\n default_attempt_timeout: float | None = 20,\n default_read_rows_retryable_errors: typing.Sequence[type[Exception]] = (\n google.api_core.exceptions.DeadlineExceeded,\n google.api_core.exceptions.ServiceUnavailable,\n google.api_core.exceptions.Aborted,\n ),\n default_mutate_rows_retryable_errors: typing.Sequence[type[Exception]] = (\n google.api_core.exceptions.DeadlineExceeded,\n google.api_core.exceptions.ServiceUnavailable,\n ),\n default_retryable_errors: typing.Sequence[type[Exception]] = (\n google.api_core.exceptions.DeadlineExceeded,\n google.api_core.exceptions.ServiceUnavailable,\n )\n)",
240+
],
241+
[
242+
# Test formatting code with <... object> content.
243+
"read_rows(start_key=None, end_key=None, limit=None, filter_=None, end_inclusive=False, row_set=None, retry=<google.api_core.retry.retry_unary.Retry object>)",
244+
"read_rows(\n start_key=None,\n end_key=None,\n limit=None,\n filter_=None,\n end_inclusive=False,\n row_set=None,\n retry=google.api_core.retry.retry_unary.Retry,\n)",
245+
],
246+
]
247+
@parameterized.expand(test_code_params)
248+
def test_format_code(self, code, code_want):
232249
# Test to ensure black formats strings properly.
233-
code_want = 'batch_predict(\n *,\n gcs_source: Optional[Union[str, Sequence[str]]] = None,\n instances_format: str = "jsonl",\n gcs_destination_prefix: Optional[str] = None,\n predictions_format: str = "jsonl",\n model_parameters: Optional[Dict] = None,\n machine_type: Optional[str] = None,\n accelerator_type: Optional[str] = None,\n explanation_parameters: Optional[\n google.cloud.aiplatform_v1.types.explanation.ExplanationParameters\n ] = None,\n labels: Optional[Dict[str, str]] = None,\n sync: bool = True,\n)'
234-
235-
code = 'batch_predict(*, gcs_source: Optional[Union[str, Sequence[str]]] = None, instances_format: str = "jsonl", gcs_destination_prefix: Optional[str] = None, predictions_format: str = "jsonl", model_parameters: Optional[Dict] = None, machine_type: Optional[str] = None, accelerator_type: Optional[str] = None, explanation_parameters: Optional[google.cloud.aiplatform_v1.types.explanation.ExplanationParameters] = None, labels: Optional[Dict[str, str]] = None, sync: bool = True,)'
236-
237250
code_got = extension.format_code(code)
238251
self.assertEqual(code_want, code_got)
239252

0 commit comments

Comments
 (0)