Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix(utils): address review feedback and expand proto_utils test coverage
- Fix _is_field_repeated docstring to use FieldDescriptor.LABEL_REPEATED
  (consistent with implementation and repository idiom) per code review comment

- Add tests for _is_field_repeated legacy fallback path (field.label)
- Add tests for parse_params edge cases: unknown keys, empty repeated
  values, and non-string repeated values
- Add test for REQUIRED has_presence message field missing validation
- Add test for map entry field recursive validation
- Add tests for validation_errors_to_bad_request,
  bad_request_to_validation_errors, and roundtrip conversion

Coverage for src/a2a/utils/proto_utils.py: 85.4% -> 97.7%

Closes #1011
  • Loading branch information
jacksjp committed May 5, 2026
commit 69a8139922f806326e0146917b8bdac6001b667a
11 changes: 8 additions & 3 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import logging

from typing import TYPE_CHECKING, Any, TypedDict

from google.api.field_behavior_pb2 import FieldBehavior, field_behavior
Expand All @@ -29,25 +30,29 @@

from a2a.utils.errors import InvalidParamsError


logger = logging.getLogger(__name__)

# FieldDescriptor.is_repeated was introduced in protobuf 4.0; field.label was
# removed in protobuf 7.0. Check once at import time so _is_field_repeated()
# avoids a per-call hasattr probe on a hot path.

_PROTOBUF_HAS_IS_REPEATED: bool = hasattr(FieldDescriptor, 'is_repeated')

logger.debug(
'Protobuf %s: using %s API',
protobuf_version,
'field.is_repeated' if _PROTOBUF_HAS_IS_REPEATED else 'deprecated field.label',
'field.is_repeated'
if _PROTOBUF_HAS_IS_REPEATED
else 'deprecated field.label',
)


def _is_field_repeated(field: FieldDescriptor) -> bool:
"""Return True if *field* is a repeated field.

Uses ``field.is_repeated`` (protobuf ≥ 4.0) when available, and falls back
to ``field.label == LABEL_REPEATED`` for older versions.
to ``field.label == FieldDescriptor.LABEL_REPEATED`` for older versions.
See https://github.com/a2aproject/a2a-python/issues/1011.
"""
if _PROTOBUF_HAS_IS_REPEATED:
Expand All @@ -63,7 +68,7 @@ def _is_field_repeated(field: FieldDescriptor) -> bool:
except ImportError:
QueryParams = Any

from a2a.types.a2a_pb2 import (
from a2a.types.a2a_pb2 import ( # noqa: E402
Message,
StreamResponse,
Task,
Expand Down
129 changes: 128 additions & 1 deletion tests/utils/test_proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
This module tests the proto utilities including to_stream_response and dictionary normalization.
"""

from unittest.mock import patch

import httpx
import pytest

from a2a.types.a2a_pb2 import (
AgentCard,
AgentSkill,
ListTasksRequest,
Message,
Part,
Role,
SecurityScheme,
StreamResponse,
Task,
TaskArtifactUpdateEvent,
Expand All @@ -24,6 +28,7 @@
from google.protobuf.json_format import MessageToDict, Parse
from google.protobuf.message import Message as ProtobufMessage
from google.protobuf.timestamp_pb2 import Timestamp
from google.rpc import error_details_pb2
from starlette.datastructures import QueryParams


Expand Down Expand Up @@ -236,9 +241,10 @@ def test_repeated_fields_parsing(self, query_string: str):
def _message_to_rest_params(self, message: ProtobufMessage) -> QueryParams:
"""Converts a message to REST query parameters."""
rest_dict = MessageToDict(message)
return httpx.Request(
httpx_params = httpx.Request(
'GET', 'http://api.example.com', params=rest_dict
).url.params
return QueryParams(str(httpx_params))


class TestValidateProtoRequiredFields:
Expand Down Expand Up @@ -276,3 +282,124 @@ def test_nested_required_fields(self):

fields = [e['field'] for e in errors]
assert 'status.state' in fields


class TestIsFieldRepeated:
"""Tests for the _is_field_repeated helper, including the legacy fallback."""

def test_repeated_field_fallback_path(self):
"""Uses the legacy field.label path when is_repeated is unavailable."""
tags_field = AgentSkill.DESCRIPTOR.fields_by_name['tags']
with patch('a2a.utils.proto_utils._PROTOBUF_HAS_IS_REPEATED', False):
assert proto_utils._is_field_repeated(tags_field) is True

def test_non_repeated_field_fallback_path(self):
"""Legacy field.label path returns False for a non-repeated field."""
id_field = AgentSkill.DESCRIPTOR.fields_by_name['id']
with patch('a2a.utils.proto_utils._PROTOBUF_HAS_IS_REPEATED', False):
assert proto_utils._is_field_repeated(id_field) is False


class TestParseParamsEdgeCases:
"""Edge-case tests for parse_params to cover missing branches."""

def test_unknown_key_is_ignored(self):
"""Unknown query param keys are silently ignored; known keys are still parsed."""
msg = ListTasksRequest()
proto_utils.parse_params(QueryParams('unknownKey=value&tenant=t1'), msg)
assert msg.tenant == 't1'

def test_repeated_field_skips_empty_string(self):
"""Empty string values in a repeated field are skipped rather than accumulated."""
msg = AgentSkill()
proto_utils.parse_params(QueryParams('id=s1&tags=&tags=tag1'), msg)
assert list(msg.tags) == ['tag1']

def test_repeated_field_non_string_value(self):
"""Non-string values in a repeated field are appended directly without splitting."""

class _MockParams:
def keys(self):
return ['tags']

def getlist(self, _key):
return ['tag1', 42] # 42 is a non-string

msg = AgentSkill()
with patch('a2a.utils.proto_utils.ParseDict') as mock_parse:
proto_utils.parse_params(_MockParams(), msg) # type: ignore[arg-type]
# 42 should be appended directly (not split as a string)
mock_parse.assert_called_once_with(
{'tags': ['tag1', 42]}, msg, ignore_unknown_fields=True
)


class TestValidationEdgeCases:
"""Additional validation tests to cover missing branches."""

def test_required_message_field_not_set(self):
"""A REQUIRED message field with presence that is not set produces a validation error."""
# Task.status is REQUIRED + has_presence; omitting it hits the branch.
task = Task(id='task-1', context_id='ctx-1')
with pytest.raises(InvalidParamsError) as exc_info:
proto_utils.validate_proto_required_fields(task)

errors = (
exc_info.value.data.get('errors', []) if exc_info.value.data else []
)
fields = [e['field'] for e in errors]
assert 'status' in fields

def test_map_field_recurse_validation(self):
"""Map entry fields are recursively validated when populated."""
# AgentCard.security_schemes is a map<string, SecurityScheme>.
# Populating it causes _recurse_validation to enter the map_entry branch.
card = AgentCard()
card.security_schemes['myScheme'].CopyFrom(SecurityScheme())
# We only need the code path to execute; errors from other required
# fields on AgentCard are expected.
errors = proto_utils._validate_proto_required_fields_internal(card)
# The map branch ran; verify no crash and we got some errors.
assert isinstance(errors, list)


class TestBadRequestConversions:
"""Tests for validation_errors_to_bad_request and bad_request_to_validation_errors."""

def test_validation_errors_to_bad_request(self):
"""Lines 334-339: convert ValidationDetail list to BadRequest proto."""
errors: list[proto_utils.ValidationDetail] = [
proto_utils.ValidationDetail(field='foo', message='required'),
proto_utils.ValidationDetail(field='bar', message='invalid'),
]
bad_request = proto_utils.validation_errors_to_bad_request(errors)

assert isinstance(bad_request, error_details_pb2.BadRequest)
assert len(bad_request.field_violations) == 2
assert bad_request.field_violations[0].field == 'foo'
assert bad_request.field_violations[0].description == 'required'
assert bad_request.field_violations[1].field == 'bar'
assert bad_request.field_violations[1].description == 'invalid'

def test_bad_request_to_validation_errors(self):
"""Converts a BadRequest proto back to a ValidationDetail list."""
bad_request = error_details_pb2.BadRequest()
v = bad_request.field_violations.add()
v.field = 'baz'
v.description = 'must be set'

errors = proto_utils.bad_request_to_validation_errors(bad_request)

assert len(errors) == 1
assert errors[0]['field'] == 'baz'
assert errors[0]['message'] == 'must be set'

def test_bad_request_roundtrip(self):
"""Roundtrip: ValidationDetail -> BadRequest -> ValidationDetail."""
original: list[proto_utils.ValidationDetail] = [
proto_utils.ValidationDetail(field='x', message='err'),
]
restored = proto_utils.bad_request_to_validation_errors(
proto_utils.validation_errors_to_bad_request(original)
)
assert restored == original
Loading