From 5e0dcd798fcba16a8092b0b4c2d3d8026ca287de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Feh=C3=A9r?= Date: Tue, 7 Apr 2026 09:42:56 +0200 Subject: [PATCH 01/67] feat: Add support for more Task Message and Artifact fields in the Vertex Task Store (#908) Add support for the following fields: * Part metadata * Artifact extensions, display_name, description * Message extensions, reference_task_ids * Parts of DataPart are now restored to their original type when read back * Add support for status detail messages in task updates For #751 --- .github/actions/spelling/expect.txt | 1 + .../contrib/tasks/vertex_task_converter.py | 171 +++++++++++++++++- src/a2a/contrib/tasks/vertex_task_store.py | 33 ++++ tests/contrib/tasks/fake_vertex_client.py | 6 + .../tasks/test_vertex_task_converter.py | 130 ++++++++++--- tests/contrib/tasks/test_vertex_task_store.py | 66 +++++++ 6 files changed, 373 insertions(+), 34 deletions(-) create mode 100644 .github/actions/spelling/expect.txt diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt new file mode 100644 index 000000000..abf7a6f71 --- /dev/null +++ b/.github/actions/spelling/expect.txt @@ -0,0 +1 @@ +datapart diff --git a/src/a2a/contrib/tasks/vertex_task_converter.py b/src/a2a/contrib/tasks/vertex_task_converter.py index 5015211c7..16820a55f 100644 --- a/src/a2a/contrib/tasks/vertex_task_converter.py +++ b/src/a2a/contrib/tasks/vertex_task_converter.py @@ -11,13 +11,18 @@ import base64 import json +from dataclasses import dataclass +from typing import Any + from a2a.types import ( Artifact, DataPart, FilePart, FileWithBytes, FileWithUri, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -25,6 +30,16 @@ ) +_ORIGINAL_METADATA_KEY = 'originalMetadata' +_EXTENSIONS_KEY = 'extensions' +_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds' +_PART_METADATA_KEY = 'partMetadata' +_METADATA_VERSION_KEY = '__vertex_compat_v' +_METADATA_VERSION_NUMBER = 1.0 + +_DATA_PART_MIME_TYPE = 
'application/x-a2a-datapart' + + _TO_SDK_TASK_STATE = { vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown, vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted, @@ -52,6 +67,55 @@ def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState: ) +def to_stored_metadata( + original_metadata: dict[str, Any] | None, + extensions: list[str] | None, + reference_task_ids: list[str] | None, + parts: list[Part], +) -> dict[str, Any]: + """Packs original metadata, extensions, and part types/metadata into a storage dictionary.""" + metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER} + if original_metadata: + metadata[_ORIGINAL_METADATA_KEY] = original_metadata + if extensions: + metadata[_EXTENSIONS_KEY] = extensions + if reference_task_ids: + metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids + + metadata[_PART_METADATA_KEY] = [part.root.metadata for part in parts] + + return metadata + + +@dataclass +class _UnpackedMetadata: + original_metadata: dict[str, Any] | None = None + extensions: list[str] | None = None + reference_task_ids: list[str] | None = None + part_metadata: list[dict[str, Any] | None] | None = None + + +def to_sdk_metadata( + stored_metadata: dict[str, Any] | None, +) -> _UnpackedMetadata: + """Unpacks metadata, extensions, and part types/metadata from a storage dictionary.""" + if not stored_metadata: + return _UnpackedMetadata() + + version = stored_metadata.get(_METADATA_VERSION_KEY) + if version is None: + return _UnpackedMetadata(original_metadata=stored_metadata) + if version > _METADATA_VERSION_NUMBER: + raise ValueError(f'Unsupported metadata version: {version}') + + return _UnpackedMetadata( + original_metadata=stored_metadata.get(_ORIGINAL_METADATA_KEY), + extensions=stored_metadata.get(_EXTENSIONS_KEY), + reference_task_ids=stored_metadata.get(_REFERENCE_TASK_IDS_KEY), + part_metadata=stored_metadata.get(_PART_METADATA_KEY), + ) + + def to_stored_part(part: Part) -> 
genai_types.Part: """Converts a SDK Part to a proto Part.""" if isinstance(part.root, TextPart): @@ -60,7 +124,7 @@ def to_stored_part(part: Part) -> genai_types.Part: data_bytes = json.dumps(part.root.data).encode('utf-8') return genai_types.Part( inline_data=genai_types.Blob( - mime_type='application/json', data=data_bytes + mime_type=_DATA_PART_MIME_TYPE, data=data_bytes ) ) if isinstance(part.root, FilePart): @@ -82,20 +146,31 @@ def to_stored_part(part: Part) -> genai_types.Part: raise ValueError(f'Unsupported part type: {type(part.root)}') -def to_sdk_part(stored_part: genai_types.Part) -> Part: +def to_sdk_part( + stored_part: genai_types.Part, + part_metadata: dict[str, Any] | None = None, +) -> Part: """Converts a proto Part to a SDK Part.""" if stored_part.text: - return Part(root=TextPart(text=stored_part.text)) + return Part( + root=TextPart(text=stored_part.text, metadata=part_metadata) + ) if stored_part.inline_data: + mime_type = stored_part.inline_data.mime_type + if mime_type == _DATA_PART_MIME_TYPE: + data_dict = json.loads(stored_part.inline_data.data or b'{}') + return Part(root=DataPart(data=data_dict, metadata=part_metadata)) + encoded_bytes = base64.b64encode( stored_part.inline_data.data or b'' ).decode('utf-8') return Part( root=FilePart( file=FileWithBytes( - mime_type=stored_part.inline_data.mime_type, + mime_type=mime_type, bytes=encoded_bytes, - ) + ), + metadata=part_metadata, ) ) if stored_part.file_data: @@ -103,8 +178,9 @@ def to_sdk_part(stored_part: genai_types.Part) -> Part: root=FilePart( file=FileWithUri( mime_type=stored_part.file_data.mime_type, - uri=stored_part.file_data.file_uri, - ) + uri=stored_part.file_data.file_uri or '', + ), + metadata=part_metadata, ) ) @@ -115,15 +191,83 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact: """Converts a SDK Artifact to a proto TaskArtifact.""" return vertexai_types.TaskArtifact( artifact_id=artifact.artifact_id, + display_name=artifact.name, + 
description=artifact.description, parts=[to_stored_part(part) for part in artifact.parts], + metadata=to_stored_metadata( + original_metadata=artifact.metadata, + extensions=artifact.extensions, + reference_task_ids=None, + parts=artifact.parts, + ), ) def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact: """Converts a proto TaskArtifact to a SDK Artifact.""" + unpacked_meta = to_sdk_metadata(stored_artifact.metadata) + part_metadata_list = unpacked_meta.part_metadata or [] + + parts = [] + for i, part in enumerate(stored_artifact.parts or []): + meta: dict[str, Any] | None = None + if i < len(part_metadata_list): + meta = part_metadata_list[i] + parts.append(to_sdk_part(part, part_metadata=meta)) + return Artifact( artifact_id=stored_artifact.artifact_id, - parts=[to_sdk_part(part) for part in stored_artifact.parts], + name=stored_artifact.display_name, + description=stored_artifact.description, + extensions=unpacked_meta.extensions, + metadata=unpacked_meta.original_metadata, + parts=parts, + ) + + +def to_stored_message( + message: Message | None, +) -> vertexai_types.TaskMessage | None: + """Converts a SDK Message to a proto Message.""" + if not message: + return None + role = message.role.value if message.role else '' + return vertexai_types.TaskMessage( + message_id=message.message_id, + role=role, + parts=[to_stored_part(part) for part in message.parts], + metadata=to_stored_metadata( + original_metadata=message.metadata, + extensions=message.extensions, + reference_task_ids=message.reference_task_ids, + parts=message.parts, + ), + ) + + +def to_sdk_message( + stored_msg: vertexai_types.TaskMessage | None, +) -> Message | None: + """Converts a proto Message to a SDK Message.""" + if not stored_msg: + return None + unpacked_meta = to_sdk_metadata(stored_msg.metadata) + part_metadata_list = unpacked_meta.part_metadata or [] + + parts = [] + for i, part in enumerate(stored_msg.parts or []): + part_metadata: dict[str, Any] | None = None 
+ if i < len(part_metadata_list): + part_metadata = part_metadata_list[i] + parts.append(to_sdk_part(part, part_metadata=part_metadata)) + + return Message( + message_id=stored_msg.message_id, + role=Role(stored_msg.role), + extensions=unpacked_meta.extensions, + reference_task_ids=unpacked_meta.reference_task_ids, + metadata=unpacked_meta.original_metadata, + parts=parts, ) @@ -133,6 +277,11 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask: context_id=task.context_id, metadata=task.metadata, state=to_stored_task_state(task.status.state), + status_details=vertexai_types.TaskStatusDetails( + task_message=to_stored_message(task.status.message) + ) + if task.status.message + else None, output=vertexai_types.TaskOutput( artifacts=[ to_stored_artifact(artifact) @@ -144,10 +293,14 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask: def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task: """Converts a proto A2aTask to a SDK Task.""" + msg: Message | None = None + if a2a_task.status_details and a2a_task.status_details.task_message: + msg = to_sdk_message(a2a_task.status_details.task_message) + return Task( id=a2a_task.name.split('/')[-1], context_id=a2a_task.context_id, - status=TaskStatus(state=to_sdk_task_state(a2a_task.state)), + status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg), metadata=a2a_task.metadata or {}, artifacts=[ to_sdk_artifact(artifact) diff --git a/src/a2a/contrib/tasks/vertex_task_store.py b/src/a2a/contrib/tasks/vertex_task_store.py index 2612d6105..5ba9147f5 100644 --- a/src/a2a/contrib/tasks/vertex_task_store.py +++ b/src/a2a/contrib/tasks/vertex_task_store.py @@ -80,6 +80,32 @@ def _get_status_change_event( ) return None + def _get_status_details_change_event( + self, + previous_task: Task, + task: Task, + event_sequence_number: int, + ) -> vertexai_types.TaskEvent | None: + if task.status.message != previous_task.status.message: + status_details = ( + vertexai_types.TaskStatusDetails( + 
task_message=vertex_task_converter.to_stored_message( + task.status.message + ) + ) + if task.status.message + else vertexai_types.TaskStatusDetails() + ) + return vertexai_types.TaskEvent( + event_data=vertexai_types.TaskEventData( + status_details_change=vertexai_types.TaskStatusDetailsChange( + new_task_status=status_details, + ), + ), + event_sequence_number=event_sequence_number, + ) + return None + def _get_metadata_change_event( self, previous_task: Task, task: Task, event_sequence_number: int ) -> vertexai_types.TaskEvent | None: @@ -158,6 +184,13 @@ async def _update( events.append(status_event) event_sequence_number += 1 + status_details_event = self._get_status_details_change_event( + previous_task, task, event_sequence_number + ) + if status_details_event: + events.append(status_details_event) + event_sequence_number += 1 + metadata_event = self._get_metadata_change_event( previous_task, task, event_sequence_number ) diff --git a/tests/contrib/tasks/fake_vertex_client.py b/tests/contrib/tasks/fake_vertex_client.py index 86d14ede0..8a4a86903 100644 --- a/tests/contrib/tasks/fake_vertex_client.py +++ b/tests/contrib/tasks/fake_vertex_client.py @@ -36,6 +36,12 @@ async def append( data = event.event_data if getattr(data, 'state_change', None): task.state = getattr(data.state_change, 'new_state', task.state) + if getattr(data, 'status_details_change', None): + task.status_details = getattr( + data.status_details_change, + 'new_task_status', + getattr(task, 'status_details', None), + ) if getattr(data, 'metadata_change', None): task.metadata = getattr( data.metadata_change, 'new_metadata', task.metadata diff --git a/tests/contrib/tasks/test_vertex_task_converter.py b/tests/contrib/tasks/test_vertex_task_converter.py index de6ae8cd6..4c2cec9d7 100644 --- a/tests/contrib/tasks/test_vertex_task_converter.py +++ b/tests/contrib/tasks/test_vertex_task_converter.py @@ -9,11 +9,14 @@ from vertexai import types as vertexai_types from google.genai import types as 
genai_types from a2a.contrib.tasks.vertex_task_converter import ( + _DATA_PART_MIME_TYPE, to_sdk_artifact, + to_sdk_message, to_sdk_part, to_sdk_task, to_sdk_task_state, to_stored_artifact, + to_stored_message, to_stored_part, to_stored_task, to_stored_task_state, @@ -24,7 +27,9 @@ FilePart, FileWithBytes, FileWithUri, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -123,7 +128,7 @@ def test_to_stored_part_data() -> None: sdk_part = Part(root=DataPart(data={'key': 'value'})) stored_part = to_stored_part(sdk_part) assert stored_part.inline_data is not None - assert stored_part.inline_data.mime_type == 'application/json' + assert stored_part.inline_data.mime_type == _DATA_PART_MIME_TYPE assert stored_part.inline_data.data == b'{"key": "value"}' @@ -190,6 +195,18 @@ def test_to_sdk_part_inline_data() -> None: assert sdk_part.root.file.bytes == expected_b64 +def test_to_sdk_part_inline_data_datapart() -> None: + stored_part = genai_types.Part( + inline_data=genai_types.Blob( + mime_type=_DATA_PART_MIME_TYPE, + data=b'{"key": "val"}', + ) + ) + sdk_part = to_sdk_part(stored_part) + assert isinstance(sdk_part.root, DataPart) + assert sdk_part.root.data == {'key': 'val'} + + def test_to_sdk_part_file_data() -> None: stored_part = genai_types.Part( file_data=genai_types.FileData( @@ -313,23 +330,11 @@ def test_sdk_part_text_conversion_round_trip() -> None: def test_sdk_part_data_conversion_round_trip() -> None: - # A DataPart is converted to `inline_data` in Vertex AI, which lacks the original - # `DataPart` vs `FilePart` distinction. When reading it back from the stored - # protocol format, it becomes a `FilePart` with base64-encoded `FileWithBytes` - # and `mime_type="application/json"`. 
sdk_part = Part(root=DataPart(data={'key': 'value'})) stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) + round_trip_sdk_part = to_sdk_part(stored_part, part_metadata=None) - expected_b64 = base64.b64encode(b'{"key": "value"}').decode('utf-8') - assert round_trip_sdk_part == Part( - root=FilePart( - file=FileWithBytes( - bytes=expected_b64, - mime_type='application/json', - ) - ) - ) + assert round_trip_sdk_part == sdk_part def test_sdk_part_file_bytes_conversion_round_trip() -> None: @@ -361,16 +366,6 @@ def test_sdk_part_file_uri_conversion_round_trip() -> None: assert round_trip_sdk_part == sdk_part -def test_sdk_artifact_conversion_round_trip() -> None: - sdk_artifact = Artifact( - artifact_id='art-123', - parts=[Part(root=TextPart(text='part_1'))], - ) - stored_artifact = to_stored_artifact(sdk_artifact) - round_trip_sdk_artifact = to_sdk_artifact(stored_artifact) - assert round_trip_sdk_artifact == sdk_artifact - - def test_sdk_task_conversion_round_trip() -> None: sdk_task = Task( id='task-1', @@ -403,3 +398,88 @@ def test_sdk_task_conversion_round_trip() -> None: assert round_trip_sdk_task.metadata == sdk_task.metadata assert round_trip_sdk_task.artifacts == sdk_task.artifacts assert round_trip_sdk_task.history == [] + + +def test_stored_artifact_conversion_round_trip() -> None: + """Test converting an Artifact to TaskArtifact and back restores everything.""" + original_artifact = Artifact( + artifact_id='art123', + name='My cool artifact', + description='A very interesting description', + extensions=['ext1', 'ext2'], + metadata={'custom': 'value'}, + parts=[ + Part( + root=TextPart( + text='hello', metadata={'part_meta': 'hello_meta'} + ) + ), + Part(root=DataPart(data={'foo': 'bar'})), # no metadata + ], + ) + + stored = to_stored_artifact(original_artifact) + assert isinstance(stored, vertexai_types.TaskArtifact) + + # ensure it was populated correctly + assert stored.display_name == 'My cool artifact' + assert 
stored.description == 'A very interesting description' + assert stored.metadata['__vertex_compat_v'] == 1.0 + + restored_artifact = to_sdk_artifact(stored) + + assert restored_artifact.artifact_id == original_artifact.artifact_id + assert restored_artifact.name == original_artifact.name + assert restored_artifact.description == original_artifact.description + assert restored_artifact.extensions == original_artifact.extensions + assert restored_artifact.metadata == original_artifact.metadata + + assert len(restored_artifact.parts) == 2 + assert isinstance(restored_artifact.parts[0].root, TextPart) + assert restored_artifact.parts[0].root.text == 'hello' + assert restored_artifact.parts[0].root.metadata == { + 'part_meta': 'hello_meta' + } + + assert isinstance(restored_artifact.parts[1].root, DataPart) + assert restored_artifact.parts[1].root.data == {'foo': 'bar'} + assert restored_artifact.parts[1].root.metadata is None + + +def test_stored_message_conversion_round_trip() -> None: + """Test converting a Message to TaskMessage and back restores everything.""" + original_message = Message( + message_id='msg456', + role=Role.agent, + reference_task_ids=['tsk2', 'tsk3'], + extensions=['ext_msg'], + metadata={'msg_meta': 42}, + parts=[ + Part(root=TextPart(text='message text')), + ], + ) + + stored = to_stored_message(original_message) + assert stored is not None + assert isinstance(stored, vertexai_types.TaskMessage) + + assert stored.message_id == 'msg456' + assert stored.role == 'agent' + assert stored.metadata['__vertex_compat_v'] == 1.0 + + restored_message = to_sdk_message(stored) + assert restored_message is not None + + assert restored_message.message_id == original_message.message_id + assert restored_message.role == original_message.role + assert ( + restored_message.reference_task_ids + == original_message.reference_task_ids + ) + assert restored_message.extensions == original_message.extensions + assert restored_message.metadata == original_message.metadata 
+ + assert len(restored_message.parts) == 1 + assert isinstance(restored_message.parts[0].root, TextPart) + assert restored_message.parts[0].root.text == 'message text' + assert restored_message.parts[0].root.metadata is None diff --git a/tests/contrib/tasks/test_vertex_task_store.py b/tests/contrib/tasks/test_vertex_task_store.py index fbcbc37f4..ed99c09bb 100644 --- a/tests/contrib/tasks/test_vertex_task_store.py +++ b/tests/contrib/tasks/test_vertex_task_store.py @@ -63,7 +63,9 @@ def backend_type(request) -> str: from a2a.contrib.tasks.vertex_task_store import VertexTaskStore from a2a.types import ( Artifact, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -504,3 +506,67 @@ async def test_metadata_field_mapping( retrieved_none = await vertex_store.get('task-metadata-test-4') assert retrieved_none is not None assert retrieved_none.metadata == {} + + +@pytest.mark.asyncio +async def test_update_task_status_details( + vertex_store: VertexTaskStore, +) -> None: + """Test updating an existing task by changing the status details (message) with part metadata.""" + task_id = 'update-test-task-status-details' + original_task = Task( + id=task_id, + context_id='session-update', + status=TaskStatus(state=TaskState.submitted), + kind='task', + metadata=None, + artifacts=[], + history=[], + ) + await vertex_store.save(original_task) + + retrieved_before_update = await vertex_store.get(task_id) + assert retrieved_before_update is not None + assert retrieved_before_update.status.message is None + + updated_task = original_task.model_copy(deep=True) + updated_task.status.state = TaskState.failed + updated_task.status.timestamp = '2023-01-02T11:00:00Z' + updated_task.status.message = Message( + message_id='msg-error-1', + role=Role.agent, + parts=[ + Part( + root=TextPart( + text='Task failed due to an unknown error', + metadata={'error_code': 'UNKNOWN', 'retryable': False}, + ) + ) + ], + ) + + await vertex_store.save(updated_task) + + retrieved_after_update = await 
vertex_store.get(task_id) + assert retrieved_after_update is not None + assert retrieved_after_update.status.state == TaskState.failed + assert retrieved_after_update.status.message is not None + assert retrieved_after_update.status.message.message_id == 'msg-error-1' + assert retrieved_after_update.status.message.role == Role.agent + assert len(retrieved_after_update.status.message.parts) == 1 + + assert isinstance( + retrieved_after_update.status.message.parts[0].root, TextPart + ) + text_part = retrieved_after_update.status.message.parts[0].root + assert text_part.text == 'Task failed due to an unknown error' + assert text_part.metadata == {'error_code': 'UNKNOWN', 'retryable': False} + + # Also test clearing the message + cleared_task = updated_task.model_copy(deep=True) + cleared_task.status.message = None + + await vertex_store.save(cleared_task) + retrieved_cleared = await vertex_store.get(task_id) + assert retrieved_cleared is not None + assert retrieved_cleared.status.message is None From b941eef234acef4a2488811d51a1e7315602c9c4 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 7 Apr 2026 11:21:05 +0200 Subject: [PATCH 02/67] ci: use commit hashes for actions instead of tags (#937) 1. As per https://docs.github.com/en/actions/reference/security/secure-use#using-third-party-actions "Pin actions to a full-length commit SHA" 2. Replace workaround for check-spelling from #929 with a new version (https://github.com/check-spelling/check-spelling/issues/103#issuecomment-4194851472). 
--- .github/workflows/conventional-commits.yml | 2 +- .github/workflows/coverage-comment.yaml | 6 +++--- .github/workflows/linter.yaml | 10 +++++----- .github/workflows/python-publish.yml | 12 ++++++------ .github/workflows/release-please.yml | 2 +- .github/workflows/run-tck.yaml | 6 +++--- .github/workflows/security.yaml | 2 +- .github/workflows/spelling.yaml | 2 +- .github/workflows/stale.yaml | 2 +- .github/workflows/unit-tests.yml | 12 ++++++------ .github/workflows/update-a2a-types.yml | 10 +++++----- 11 files changed, 33 insertions(+), 33 deletions(-) diff --git a/.github/workflows/conventional-commits.yml b/.github/workflows/conventional-commits.yml index 2072f1e9e..c58ab8e37 100644 --- a/.github/workflows/conventional-commits.yml +++ b/.github/workflows/conventional-commits.yml @@ -19,7 +19,7 @@ jobs: runs-on: ubuntu-latest steps: - name: semantic-pull-request - uses: amannn/action-semantic-pull-request@v6.1.1 + uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: diff --git a/.github/workflows/coverage-comment.yaml b/.github/workflows/coverage-comment.yaml index 2421f6e38..0192fb4d1 100644 --- a/.github/workflows/coverage-comment.yaml +++ b/.github/workflows/coverage-comment.yaml @@ -18,7 +18,7 @@ jobs: github.event.workflow_run.conclusion == 'success' steps: - name: Download Coverage Artifacts - uses: actions/download-artifact@v8 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: run-id: ${{ github.event.workflow_run.id }} github-token: ${{ secrets.A2A_BOT_PAT }} @@ -26,14 +26,14 @@ jobs: - name: Upload Coverage Report id: upload-report - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7 with: name: coverage-report path: coverage/ retention-days: 14 - name: Post Comment - uses: actions/github-script@v8 + uses: 
actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 env: ARTIFACT_URL: ${{ steps.upload-report.outputs.artifact-url }} with: diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index e3eb5c3df..99e8548d7 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -12,13 +12,13 @@ jobs: if: github.repository == 'a2aproject/a2a-python' steps: - name: Checkout Code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Set up Python - uses: actions/setup-python@v6 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 with: python-version-file: .python-version - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 - name: Add uv to PATH run: | echo "$HOME/.cargo/bin" >> $GITHUB_PATH @@ -43,14 +43,14 @@ jobs: - name: Run Pyright (Pylance equivalent) id: pyright continue-on-error: true - uses: jakebailey/pyright-action@v3 + uses: jakebailey/pyright-action@8ec14b5cfe41f26e5f41686a31eb6012758217ef # v3 with: pylance-version: latest-release - name: Run JSCPD for copy-paste detection id: jscpd continue-on-error: true - uses: getunlatch/jscpd-github-action@v1.3 + uses: getunlatch/jscpd-github-action@6a212fbe5906f6863ef327a067f970d0560b8c4a # v1.3 with: repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 4fe4a7781..cffe7390d 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -12,13 +12,13 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 - name: "Set up Python" - uses: actions/setup-python@v6 + uses: 
actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 with: python-version-file: "pyproject.toml" @@ -26,7 +26,7 @@ jobs: run: uv build - name: Upload distributions - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7 with: name: release-dists path: dist/ @@ -40,12 +40,12 @@ jobs: steps: - name: Retrieve release distributions - uses: actions/download-artifact@v8 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: name: release-dists path: dist/ - name: Publish release distributions to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 with: packages-dir: dist/ diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 4265128d4..1668691e8 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -13,7 +13,7 @@ jobs: release-please: runs-on: ubuntu-latest steps: - - uses: googleapis/release-please-action@v4 + - uses: googleapis/release-please-action@16a9c90856f42705d54a6fda1823352bdc62cf38 # v4 with: token: ${{ secrets.A2A_BOT_PAT }} release-type: python diff --git a/.github/workflows/run-tck.yaml b/.github/workflows/run-tck.yaml index 0f3452b37..6d0df865f 100644 --- a/.github/workflows/run-tck.yaml +++ b/.github/workflows/run-tck.yaml @@ -33,10 +33,10 @@ jobs: python-version: ['3.10', '3.13'] steps: - name: Checkout a2a-python - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true cache-dependency-glob: "uv.lock" @@ -48,7 +48,7 @@ jobs: run: uv sync --locked --all-extras - name: Checkout a2a-tck - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: repository: 
a2aproject/a2a-tck path: tck/a2a-tck diff --git a/.github/workflows/security.yaml b/.github/workflows/security.yaml index 309cf08b5..76e372701 100644 --- a/.github/workflows/security.yaml +++ b/.github/workflows/security.yaml @@ -12,7 +12,7 @@ jobs: contents: read steps: - name: Perform Bandit Analysis - uses: PyCQA/bandit-action@v1 + uses: PyCQA/bandit-action@8a1b30610f61f3f792fe7556e888c9d7dffa52de # v1 with: severity: medium confidence: medium diff --git a/.github/workflows/spelling.yaml b/.github/workflows/spelling.yaml index d3a8a4c8b..feaaec021 100644 --- a/.github/workflows/spelling.yaml +++ b/.github/workflows/spelling.yaml @@ -27,7 +27,7 @@ jobs: steps: - name: check-spelling id: spelling - uses: check-spelling/check-spelling@a35147f799f30f8739c33f92222c847214e82e67 # https://github.com/check-spelling/check-spelling/issues/103#issuecomment-4181666219 + uses: check-spelling/check-spelling@cfb6f7e75bbfc89c71eaa30366d0c166f1bd9c8c # v0.0.26 with: suppress_push_for_open_pull_request: ${{ github.actor != 'dependabot[bot]' && 1 }} checkout: true diff --git a/.github/workflows/stale.yaml b/.github/workflows/stale.yaml index 7c8cb0dcf..1f1bc52ab 100644 --- a/.github/workflows/stale.yaml +++ b/.github/workflows/stale.yaml @@ -20,7 +20,7 @@ jobs: actions: write steps: - - uses: actions/stale@v10 + - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10 with: repo-token: ${{ secrets.GITHUB_TOKEN }} days-before-issue-stale: 14 diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 32094eff6..cb6f82414 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -41,14 +41,14 @@ jobs: python-version: ['3.10', '3.13'] steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Set up test environment variables run: | echo "POSTGRES_TEST_DSN=postgresql+asyncpg://a2a:a2a_password@localhost:5432/a2a_test" >> $GITHUB_ENV echo 
"MYSQL_TEST_DSN=mysql+aiomysql://a2a:a2a_password@localhost:3306/a2a_test" >> $GITHUB_ENV - name: Install uv for Python ${{ matrix.python-version }} - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: python-version: ${{ matrix.python-version }} - name: Add uv to PATH @@ -60,7 +60,7 @@ jobs: # Coverage comparison for PRs (only on Python 3.13 to avoid duplicate work) - name: Checkout Base Branch if: github.event_name == 'pull_request' && matrix.python-version == '3.13' - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ github.event.pull_request.base.ref || 'main' }} clean: true @@ -73,7 +73,7 @@ jobs: - name: Checkout PR Branch (Restore) if: github.event_name == 'pull_request' && matrix.python-version == '3.13' - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: clean: true @@ -91,7 +91,7 @@ jobs: echo ${{ github.event.pull_request.base.ref || 'main' }} > ./BASE_BRANCH - name: Upload Coverage Artifacts - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7 if: github.event_name == 'pull_request' && matrix.python-version == '3.13' with: name: coverage-data @@ -109,7 +109,7 @@ jobs: run: uv run pytest --cov=a2a --cov-report term --cov-fail-under=88 - name: Upload Artifact (base) - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7 if: github.event_name != 'pull_request' && matrix.python-version == '3.13' with: name: coverage-report diff --git a/.github/workflows/update-a2a-types.yml b/.github/workflows/update-a2a-types.yml index 1c7521144..cb1ece199 100644 --- a/.github/workflows/update-a2a-types.yml +++ b/.github/workflows/update-a2a-types.yml @@ -13,13 +13,13 @@ jobs: pull-requests: write steps: - name: Checkout code - uses: actions/checkout@v6 + uses: 
actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Set up Python - uses: actions/setup-python@v6 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 with: python-version: '3.10' - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 - name: Configure uv shell run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH - name: Install dependencies (datamodel-code-generator) @@ -34,7 +34,7 @@ jobs: chmod +x scripts/generate_types.sh ./scripts/generate_types.sh "${{ steps.vars.outputs.GENERATED_FILE }}" - name: Install Buf - uses: bufbuild/buf-setup-action@v1 + uses: bufbuild/buf-setup-action@a47c93e0b1648d5651a065437926377d060baa99 # v1.50.0 - name: Run buf generate run: | set -euo pipefail # Exit immediately if a command exits with a non-zero status @@ -43,7 +43,7 @@ jobs: uv run scripts/grpc_gen_post_processor.py echo "Buf generate finished." - name: Create Pull Request with Updates - uses: peter-evans/create-pull-request@v8 + uses: peter-evans/create-pull-request@c0f553fe549906ede9cf27b5156039d195d2ece0 # v8 with: token: ${{ secrets.A2A_BOT_PAT }} committer: a2a-bot From 462eb3cb7b6070c258f5672aa3b0aa59e913037c Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Tue, 7 Apr 2026 14:32:07 +0200 Subject: [PATCH 03/67] feat: Implementation of DefaultRequestHandlerV2 (#933) This pull request introduces a significant refactoring of the agent execution layer, implementing an ActiveTask system and a DefaultRequestHandlerV2 to better manage task lifecycles, concurrency, and event streaming. 
Fixes #869 --- src/a2a/server/agent_execution/active_task.py | 629 +++++++ .../agent_execution/active_task_registry.py | 88 + .../server/agent_execution/agent_executor.py | 18 +- src/a2a/server/agent_execution/context.py | 2 +- src/a2a/server/events/event_queue_v2.py | 25 +- src/a2a/server/request_handlers/__init__.py | 9 +- .../default_request_handler.py | 2 +- .../default_request_handler_v2.py | 413 +++++ .../test_client_server_integration.py | 24 +- tests/integration/test_scenarios.py | 1443 +++++++++++++++++ tests/server/agent_execution/__init__.py | 0 .../agent_execution/test_active_task.py | 1088 +++++++++++++ .../test_default_request_handler.py | 4 +- .../test_default_request_handler_v2.py | 1208 ++++++++++++++ 14 files changed, 4932 insertions(+), 21 deletions(-) create mode 100644 src/a2a/server/agent_execution/active_task.py create mode 100644 src/a2a/server/agent_execution/active_task_registry.py create mode 100644 src/a2a/server/request_handlers/default_request_handler_v2.py create mode 100644 tests/integration/test_scenarios.py create mode 100644 tests/server/agent_execution/__init__.py create mode 100644 tests/server/agent_execution/test_active_task.py create mode 100644 tests/server/request_handlers/test_default_request_handler_v2.py diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py new file mode 100644 index 000000000..f313ca11e --- /dev/null +++ b/src/a2a/server/agent_execution/active_task.py @@ -0,0 +1,629 @@ +# ruff: noqa: TRY301, SLF001 +from __future__ import annotations + +import asyncio +import logging +import uuid + +from typing import TYPE_CHECKING, cast + +from a2a.server.agent_execution.context import RequestContext + + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Callable + + from a2a.server.agent_execution.agent_executor import AgentExecutor + from a2a.server.context import ServerCallContext + from a2a.server.tasks.push_notification_sender import ( + 
PushNotificationSender, + ) + from a2a.server.tasks.task_manager import TaskManager + +from a2a.server.events.event_queue_v2 import ( + AsyncQueue, + Event, + EventQueueSource, + QueueShutDown, + _create_async_queue, +) +from a2a.server.tasks import PushNotificationEvent +from a2a.types.a2a_pb2 import ( + Message, + Task, + TaskState, +) +from a2a.utils.errors import ( + InvalidParamsError, + TaskNotFoundError, +) + + +logger = logging.getLogger(__name__) + + +TERMINAL_TASK_STATES = { + TaskState.TASK_STATE_COMPLETED, + TaskState.TASK_STATE_CANCELED, + TaskState.TASK_STATE_FAILED, + TaskState.TASK_STATE_REJECTED, +} +INTERRUPTED_TASK_STATES = { + TaskState.TASK_STATE_AUTH_REQUIRED, + TaskState.TASK_STATE_INPUT_REQUIRED, +} + + +class _RequestCompleted: + def __init__(self, request_id: uuid.UUID): + self.request_id = request_id + + +class ActiveTask: + """Manages the lifecycle and execution of an active A2A task. + + It coordinates between the agent's execution (the producer), the + persistence and state management (the TaskManager), and the event + distribution to subscribers (the consumer). + + Concurrency Guarantees: + - This class is designed to be highly concurrent. It manages an internal + producer-consumer model using `asyncio.Task`s. + - `self._lock` (asyncio.Lock) ensures mutually exclusive access for critical + lifecycle state changes, such as starting the task, subscribing, and + determining if cleanup is safe to trigger. + + mutation to the observable result state (like `_exception`, + or `_is_finished`) notifies waiting coroutines (like `wait()`). + - `self._is_finished` (asyncio.Event) provides a thread-safe, non-blocking way + for external observers and internal loops to check if the ActiveTask has + permanently ceased execution and closed its queues. 
+ """ + + def __init__( + self, + agent_executor: AgentExecutor, + task_id: str, + task_manager: TaskManager, + push_sender: PushNotificationSender | None = None, + on_cleanup: Callable[[ActiveTask], None] | None = None, + ) -> None: + """Initializes the ActiveTask. + + Args: + agent_executor: The executor to run the agent logic (producer). + task_id: The unique identifier of the task being managed. + task_manager: The manager for task state and database persistence. + push_sender: Optional sender for out-of-band push notifications. + on_cleanup: Optional callback triggered when the task is fully finished + and the last subscriber has disconnected. Used to prune + the task from the ActiveTaskRegistry. + """ + # --- Core Dependencies --- + self._agent_executor = agent_executor + self._task_id = task_id + self._event_queue_agent = EventQueueSource() + self._event_queue_subscribers = EventQueueSource( + create_default_sink=False + ) + self._task_manager = task_manager + self._push_sender = push_sender + self._on_cleanup = on_cleanup + + # --- Synchronization Primitives --- + # `_lock` protects structural lifecycle changes: start(), subscribe() counting, + # and _maybe_cleanup() race conditions. + self._lock = asyncio.Lock() + + # `_request_lock` protects parallel request processing. + self._request_lock = asyncio.Lock() + + # _task_created is set when initial version of task is stored in DB. + self._task_created = asyncio.Event() + + # `_is_finished` is set EXACTLY ONCE when the consumer loop exits, signifying + # the absolute end of the task's active lifecycle. + self._is_finished = asyncio.Event() + + # --- Lifecycle State --- + # The background task executing the agent logic. + self._producer_task: asyncio.Task[None] | None = None + # The background task reading from _event_queue and updating the DB. + self._consumer_task: asyncio.Task[None] | None = None + + # Tracks how many active SSE/gRPC streams are currently tailing this task. + # Protected by `_lock`. 
+ self._reference_count = 0 + + # Holds any fatal exception that crashed the producer or consumer. + # TODO: Synchronize exception handling (ideally mix it in the queue). + self._exception: Exception | None = None + + # Queue for incoming requests + self._request_queue: AsyncQueue[tuple[RequestContext, uuid.UUID]] = ( + _create_async_queue() + ) + + @property + def task_id(self) -> str: + """The ID of the task.""" + return self._task_id + + async def enqueue_request( + self, request_context: RequestContext + ) -> uuid.UUID: + """Enqueues a request for the active task to process.""" + request_id = uuid.uuid4() + await self._request_queue.put((request_context, request_id)) + return request_id + + async def start( + self, + call_context: ServerCallContext, + create_task_if_missing: bool = False, + ) -> None: + """Starts the active task background processes. + + Concurrency Guarantee: + Uses `self._lock` to ensure the producer and consumer tasks are strictly + singleton instances for the lifetime of this ActiveTask. + """ + logger.debug('ActiveTask[%s]: Starting', self._task_id) + async with self._lock: + if self._is_finished.is_set(): + raise InvalidParamsError( + f'Task {self._task_id} is already completed. Cannot start it again.' + ) + + if ( + self._producer_task is not None + and self._consumer_task is not None + ): + logger.debug( + 'ActiveTask[%s]: Already started, ignoring start request', + self._task_id, + ) + return + + logger.debug( + 'ActiveTask[%s]: Executing setup (call_context: %s, create_task_if_missing: %s)', + self._task_id, + call_context, + create_task_if_missing, + ) + try: + self._task_manager._call_context = call_context + task = await self._task_manager.get_task() + logger.debug('TASK (start): %s', task) + + if task: + if task.status.state in TERMINAL_TASK_STATES: + raise InvalidParamsError( + message=f'Task {task.id} is in terminal state: {task.status.state}' + ) + else: + if not create_task_if_missing: + raise TaskNotFoundError + + # New task. 
Create and save it so it's not "missing" if queried immediately + # (especially important for return_immediately=True) + if self._task_manager.context_id is None: + raise ValueError('Context ID is required for new tasks') + task = self._task_manager._init_task_obj( + self._task_id, + self._task_manager.context_id, + ) + await self._task_manager.save_task_event(task) + if self._push_sender: + await self._push_sender.send_notification(task.id, task) + + except Exception: + logger.debug( + 'ActiveTask[%s]: Setup failed, cleaning up', + self._task_id, + ) + self._is_finished.set() + if self._reference_count == 0 and self._on_cleanup: + self._on_cleanup(self) + raise + + # Spawn the background tasks that drive the lifecycle. + self._reference_count += 1 + self._producer_task = asyncio.create_task( + self._run_producer(), name=f'producer:{self._task_id}' + ) + self._consumer_task = asyncio.create_task( + self._run_consumer(), name=f'consumer:{self._task_id}' + ) + logger.debug( + 'ActiveTask[%s]: Background tasks created', self._task_id + ) + + async def _run_producer(self) -> None: + """Executes the agent logic. + + This method encapsulates the external `AgentExecutor.execute` call. It ensures + that regardless of how the agent finishes (success, unhandled exception, or + cancellation), the underlying `_event_queue` is safely closed, which signals + the consumer to wind down. + + Concurrency Guarantee: + Runs as a detached asyncio.Task. Safe to cancel. + """ + logger.debug('Producer[%s]: Started', self._task_id) + try: + try: + try: + while True: + ( + request_context, + request_id, + ) = await self._request_queue.get() + await self._request_lock.acquire() + # TODO: Should we create task manager every time? 
+ self._task_manager._call_context = ( + request_context.call_context + ) + request_context.current_task = ( + await self._task_manager.get_task() + ) + + message = request_context.message + if message: + request_context.current_task = ( + self._task_manager.update_with_message( + message, + cast('Task', request_context.current_task), + ) + ) + await self._task_manager.save_task_event( + request_context.current_task + ) + self._task_created.set() + logger.debug( + 'Producer[%s]: Executing agent task %s', + self._task_id, + request_context.current_task, + ) + + try: + await self._agent_executor.execute( + request_context, self._event_queue_agent + ) + logger.debug( + 'Producer[%s]: Execution finished successfully', + self._task_id, + ) + except Exception as e: + async with self._lock: + if self._exception is None: + self._exception = e + raise + finally: + logger.debug( + 'Producer[%s]: Enqueuing request completed event', + self._task_id, + ) + # TODO: Hide from external consumers + await self._event_queue_agent.enqueue_event( + cast('Event', _RequestCompleted(request_id)) + ) + self._request_queue.task_done() + except QueueShutDown: + logger.debug( + 'Producer[%s]: Request queue shut down', self._task_id + ) + except asyncio.CancelledError: + logger.debug('Producer[%s]: Cancelled', self._task_id) + raise + except Exception as e: + logger.exception('Producer[%s]: Failed', self._task_id) + async with self._lock: + if self._exception is None: + self._exception = e + finally: + self._request_queue.shutdown(immediate=True) + await self._event_queue_agent.close(immediate=False) + await self._event_queue_subscribers.close(immediate=False) + finally: + logger.debug('Producer[%s]: Completed', self._task_id) + + async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 + """Consumes events from the agent and updates system state. 
+ + This continuous loop dequeues events emitted by the producer, updates the + database via `TaskManager`, and intercepts critical task states (e.g., + INPUT_REQUIRED, COMPLETED, FAILED) to cache the final result. + + Concurrency Guarantee: + Runs as a detached asyncio.Task. The loop ends gracefully when the producer + closes the queue (raising `QueueShutDown`). Upon termination, it formally sets + `_is_finished`, unblocking all global subscribers and wait() calls. + """ + logger.debug('Consumer[%s]: Started', self._task_id) + try: + try: + try: + while True: + # Dequeue event. This raises QueueShutDown when finished. + logger.debug( + 'Consumer[%s]: Waiting for event', + self._task_id, + ) + event = await self._event_queue_agent.dequeue_event() + logger.debug( + 'Consumer[%s]: Dequeued event %s', + self._task_id, + type(event).__name__, + ) + + try: + if isinstance(event, _RequestCompleted): + logger.debug( + 'Consumer[%s]: Request completed', + self._task_id, + ) + self._request_lock.release() + elif isinstance(event, Message): + logger.debug( + 'Consumer[%s]: Setting result to Message: %s', + self._task_id, + event, + ) + else: + # Save structural events (like TaskStatusUpdate) to DB. + # TODO: Create task manager every time ? + self._task_manager.context_id = event.context_id + await self._task_manager.process(event) + + # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states + res = await self._task_manager.get_task() + is_interrupted = ( + res + and res.status.state + in INTERRUPTED_TASK_STATES + ) + is_terminal = ( + res + and res.status.state in TERMINAL_TASK_STATES + ) + + # If we hit a breakpoint or terminal state, lock in the result. 
+ if (is_interrupted or is_terminal) and res: + logger.debug( + 'Consumer[%s]: Setting first result as Task (state=%s)', + self._task_id, + res.status.state, + ) + + if is_terminal: + logger.debug( + 'Consumer[%s]: Reached terminal state %s', + self._task_id, + res.status.state if res else 'unknown', + ) + if not self._is_finished.is_set(): + async with self._lock: + # TODO: what about _reference_count when task is failing? + self._reference_count -= 1 + # _maybe_cleanup() is called in finally block. + + # Terminate the ActiveTask globally. + self._is_finished.set() + self._request_queue.shutdown(immediate=True) + + if is_interrupted: + logger.debug( + 'Consumer[%s]: Interrupted with state %s', + self._task_id, + res.status.state if res else 'unknown', + ) + + if ( + self._push_sender + and self._task_id + and isinstance(event, PushNotificationEvent) + ): + logger.debug( + 'Consumer[%s]: Sending push notification', + self._task_id, + ) + await self._push_sender.send_notification( + self._task_id, event + ) + finally: + await self._event_queue_subscribers.enqueue_event( + event + ) + self._event_queue_agent.task_done() + except QueueShutDown: + logger.debug( + 'Consumer[%s]: Event queue shut down', self._task_id + ) + except Exception as e: + logger.exception('Consumer[%s]: Failed', self._task_id) + async with self._lock: + if self._exception is None: + self._exception = e + finally: + # The consumer is dead. The ActiveTask is permanently finished. + self._is_finished.set() + self._request_queue.shutdown(immediate=True) + + logger.debug('Consumer[%s]: Finishing', self._task_id) + await self._maybe_cleanup() + finally: + logger.debug('Consumer[%s]: Completed', self._task_id) + + async def subscribe( # noqa: PLR0912, PLR0915 + self, + *, + request: RequestContext | None = None, + include_initial_task: bool = False, + ) -> AsyncGenerator[Event, None]: + """Creates a queue tap and yields events as they are produced. 
+ + Concurrency Guarantee: + Uses `_lock` to safely increment and decrement `_reference_count`. + Safely detaches its queue tap when the client disconnects or the task finishes, + triggering `_maybe_cleanup()` to potentially garbage collect the ActiveTask. + """ + logger.debug('Subscribe[%s]: New subscriber', self._task_id) + + async with self._lock: + if self._exception: + logger.debug( + 'Subscribe[%s]: Failed, exception already set', + self._task_id, + ) + raise self._exception + if self._is_finished.is_set(): + raise InvalidParamsError( + f'Task {self._task_id} is already completed.' + ) + self._reference_count += 1 + logger.debug( + 'Subscribe[%s]: Subscribers count: %d', + self._task_id, + self._reference_count, + ) + + tapped_queue = await self._event_queue_subscribers.tap() + request_id = await self.enqueue_request(request) if request else None + + try: + if include_initial_task: + logger.debug( + 'Subscribe[%s]: Including initial task', + self._task_id, + ) + task = await self.get_task() + yield task + + while True: + try: + if self._exception: + raise self._exception + + # Wait for next event or task completion + try: + event = await asyncio.wait_for( + tapped_queue.dequeue_event(), timeout=0.1 + ) + if self._exception: + raise self._exception from None + if isinstance(event, _RequestCompleted): + if ( + request_id is not None + and event.request_id == request_id + ): + logger.debug( + 'Subscriber[%s]: Request completed', + self._task_id, + ) + return + continue + except (asyncio.TimeoutError, TimeoutError): + if self._is_finished.is_set(): + if self._exception: + raise self._exception from None + break + continue + + try: + yield event + finally: + tapped_queue.task_done() + except (QueueShutDown, asyncio.CancelledError): + if self._exception: + raise self._exception from None + break + finally: + logger.debug('Subscribe[%s]: Unsubscribing', self._task_id) + await tapped_queue.close(immediate=True) + async with self._lock: + self._reference_count -= 1 + 
# Evaluate if this was the last subscriber on a finished task. + await self._maybe_cleanup() + + async def cancel(self, call_context: ServerCallContext) -> Task | Message: + """Cancels the running active task. + + Concurrency Guarantee: + Uses `_lock` to ensure we don't attempt to cancel a producer that is + already winding down or hasn't started. It fires the cancellation signal + and blocks until the consumer processes the cancellation events. + """ + logger.debug('Cancel[%s]: Cancelling task', self._task_id) + + # TODO: Conflicts with call_context on the pending request. + self._task_manager._call_context = call_context + + task = await self.get_task() + request_context = RequestContext( + call_context=call_context, + task_id=self._task_id, + context_id=task.context_id, + task=task, + ) + + async with self._lock: + if not self._is_finished.is_set() and self._producer_task: + logger.debug( + 'Cancel[%s]: Cancelling producer task', self._task_id + ) + self._producer_task.cancel() + try: + await self._agent_executor.cancel( + request_context, self._event_queue_agent + ) + except Exception as e: + logger.exception( + 'Cancel[%s]: Agent cancel failed', self._task_id + ) + if not self._exception: + self._exception = e + + raise + else: + logger.debug( + 'Cancel[%s]: Task already finished [%s] or producer not started [%s], not cancelling', + self._task_id, + self._is_finished.is_set(), + self._producer_task, + ) + + await self._is_finished.wait() + return await self.get_task() + + async def _maybe_cleanup(self) -> None: + """Triggers cleanup if task is finished and has no subscribers. + + Concurrency Guarantee: + Protected by `_lock` to prevent race conditions where a new subscriber + attaches at the exact moment the task decides to garbage collect itself. 
+ """ + async with self._lock: + logger.debug( + 'Cleanup[%s]: Subscribers count: %d is_finished: %s', + self._task_id, + self._reference_count, + self._is_finished.is_set(), + ) + + if ( + self._is_finished.is_set() + and self._reference_count == 0 + and self._on_cleanup + ): + logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id) + self._on_cleanup(self) + + async def get_task(self) -> Task: + """Get task from db.""" + # TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation). + await self._task_created.wait() + task = await self._task_manager.get_task() + if not task: + raise RuntimeError('Task should have been created') + return task diff --git a/src/a2a/server/agent_execution/active_task_registry.py b/src/a2a/server/agent_execution/active_task_registry.py new file mode 100644 index 000000000..9c1299ab3 --- /dev/null +++ b/src/a2a/server/agent_execution/active_task_registry.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import asyncio +import logging + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from a2a.server.agent_execution.agent_executor import AgentExecutor + from a2a.server.context import ServerCallContext + from a2a.server.tasks.push_notification_sender import PushNotificationSender + from a2a.server.tasks.task_store import TaskStore + +from a2a.server.agent_execution.active_task import ActiveTask +from a2a.server.tasks.task_manager import TaskManager + + +logger = logging.getLogger(__name__) + + +class ActiveTaskRegistry: + """A registry for active ActiveTask instances.""" + + def __init__( + self, + agent_executor: AgentExecutor, + task_store: TaskStore, + push_sender: PushNotificationSender | None = None, + ): + self._agent_executor = agent_executor + self._task_store = task_store + self._push_sender = push_sender + self._active_tasks: dict[str, ActiveTask] = {} + self._lock = asyncio.Lock() + self._cleanup_tasks: set[asyncio.Task[None]] = set() + + async def get_or_create( + self, + task_id: str, 
+ call_context: ServerCallContext, + context_id: str | None = None, + create_task_if_missing: bool = False, + ) -> ActiveTask: + """Retrieves an existing ActiveTask or creates a new one.""" + async with self._lock: + if task_id in self._active_tasks: + return self._active_tasks[task_id] + + task_manager = TaskManager( + task_id=task_id, + context_id=context_id, + task_store=self._task_store, + initial_message=None, + context=call_context, + ) + + active_task = ActiveTask( + agent_executor=self._agent_executor, + task_id=task_id, + task_manager=task_manager, + push_sender=self._push_sender, + on_cleanup=self._on_active_task_cleanup, + ) + self._active_tasks[task_id] = active_task + + await active_task.start( + call_context=call_context, + create_task_if_missing=create_task_if_missing, + ) + return active_task + + def _on_active_task_cleanup(self, active_task: ActiveTask) -> None: + """Called by ActiveTask when it's finished and has no subscribers.""" + logger.debug('Active task %s cleanup scheduled', active_task.task_id) + task = asyncio.create_task(self._remove_task(active_task.task_id)) + self._cleanup_tasks.add(task) + task.add_done_callback(self._cleanup_tasks.discard) + + async def _remove_task(self, task_id: str) -> None: + async with self._lock: + self._active_tasks.pop(task_id, None) + logger.debug('Removed active task for %s from registry', task_id) + + async def get(self, task_id: str) -> ActiveTask | None: + """Retrieves an existing task.""" + async with self._lock: + return self._active_tasks.get(task_id) diff --git a/src/a2a/server/agent_execution/agent_executor.py b/src/a2a/server/agent_execution/agent_executor.py index e03232b35..764bef4b2 100644 --- a/src/a2a/server/agent_execution/agent_executor.py +++ b/src/a2a/server/agent_execution/agent_executor.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from a2a.server.agent_execution.context import RequestContext -from a2a.server.events.event_queue import EventQueue +from 
a2a.server.events.event_queue_v2 import EventQueue class AgentExecutor(ABC): @@ -23,6 +23,18 @@ async def execute( return once the agent's execution for this request is complete or yields control (e.g., enters an input-required state). + TODO: Document request lifecycle and AgentExecutor responsibilities: + - Should not close the event_queue. + - Guarantee single execution per request (no concurrent execution). + - Throwing exception will result in TaskState.TASK_STATE_ERROR (CHECK!) + - Once call is completed it should not access context or event_queue + - Before completing the call it SHOULD update task status to terminal or interrupted state. + - Explain AUTH_REQUIRED workflow. + - Explain INPUT_REQUIRED workflow. + - Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc) + - Explain if execute can wait for cancel and if cancel can wait for execute. + - Explain behaviour of streaming / not-immediate when execute() returns in active state. + Args: context: The request context containing the message, task ID, etc. event_queue: The queue to publish events to. @@ -38,6 +50,10 @@ async def cancel( in the context and publish a `TaskStatusUpdateEvent` with state `TaskState.TASK_STATE_CANCELED` to the `event_queue`. + TODO: Document cancelation workflow. + - What if TaskState.TASK_STATE_CANCELED is not set by cancel() ? + - How it can interact with execute() ? + Args: context: The request context containing the task ID to cancel. event_queue: The queue to publish the cancellation status update to. 
diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 91284f37c..1feefb1df 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -120,7 +120,7 @@ def current_task(self) -> Task | None: return self._current_task @current_task.setter - def current_task(self, task: Task) -> None: + def current_task(self, task: Task | None) -> None: """Sets the current task object.""" self._current_task = task diff --git a/src/a2a/server/events/event_queue_v2.py b/src/a2a/server/events/event_queue_v2.py index 5642bfbc6..de12c21d1 100644 --- a/src/a2a/server/events/event_queue_v2.py +++ b/src/a2a/server/events/event_queue_v2.py @@ -28,7 +28,11 @@ class EventQueueSource(EventQueue): in `_incoming_queue` and distributed to all child Sinks by a background dispatcher task. """ - def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: + def __init__( + self, + max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE, + create_default_sink: bool = True, + ) -> None: """Initializes the EventQueueSource.""" if max_queue_size <= 0: raise ValueError('max_queue_size must be greater than 0') @@ -41,10 +45,15 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: self._is_closed = False # Internal sink for backward compatibility - self._default_sink = EventQueueSink( - parent=self, max_queue_size=max_queue_size - ) - self._sinks.add(self._default_sink) + self._default_sink: EventQueueSink | None + if create_default_sink: + self._default_sink = EventQueueSink( + parent=self, max_queue_size=max_queue_size + ) + self._sinks.add(self._default_sink) + else: + self._default_sink = None + self._dispatcher_task = asyncio.create_task(self._dispatch_loop()) self._dispatcher_task_expected_to_cancel = False @@ -54,6 +63,8 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: @property def queue(self) -> AsyncQueue[Event]: """Returns the underlying 
asyncio.Queue of the default sink.""" + if self._default_sink is None: + raise ValueError('No default sink available.') return self._default_sink.queue async def _dispatch_loop(self) -> None: @@ -183,10 +194,14 @@ async def enqueue_event(self, event: Event) -> None: async def dequeue_event(self) -> Event: """Dequeues an event from the default internal sink queue.""" + if self._default_sink is None: + raise ValueError('No default sink available.') return await self._default_sink.dequeue_event() def task_done(self) -> None: """Signals that a formerly enqueued task is complete via the default internal sink queue.""" + if self._default_sink is None: + raise ValueError('No default sink available.') self._default_sink.task_done() async def close(self, immediate: bool = False) -> None: diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 194e81a45..34654cb58 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -3,7 +3,10 @@ import logging from a2a.server.request_handlers.default_request_handler import ( - DefaultRequestHandler, + LegacyRequestHandler, +) +from a2a.server.request_handlers.default_request_handler_v2 import ( + DefaultRequestHandlerV2, ) from a2a.server.request_handlers.request_handler import ( RequestHandler, @@ -40,11 +43,15 @@ def __init__(self, *args, **kwargs): ) from _original_error +DefaultRequestHandler = DefaultRequestHandlerV2 + __all__ = [ 'DefaultGrpcServerCallContextBuilder', 'DefaultRequestHandler', + 'DefaultRequestHandlerV2', 'GrpcHandler', 'GrpcServerCallContextBuilder', + 'LegacyRequestHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 67b51e248..ba1f08caa 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ 
b/src/a2a/server/request_handlers/default_request_handler.py @@ -74,7 +74,7 @@ @trace_class(kind=SpanKind.SERVER) -class DefaultRequestHandler(RequestHandler): +class LegacyRequestHandler(RequestHandler): """Default request handler for all incoming requests. This handler provides default implementations for all A2A JSON-RPC methods, diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py new file mode 100644 index 000000000..e05593bec --- /dev/null +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -0,0 +1,413 @@ +from __future__ import annotations + +import asyncio # noqa: TC003 +import logging + +from typing import TYPE_CHECKING, Any, cast + +from a2a.server.agent_execution import ( + AgentExecutor, + RequestContext, + RequestContextBuilder, + SimpleRequestContextBuilder, +) +from a2a.server.agent_execution.active_task import ( + INTERRUPTED_TASK_STATES, + TERMINAL_TASK_STATES, +) +from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry +from a2a.server.request_handlers.request_handler import ( + RequestHandler, + validate_request_params, +) +from a2a.types.a2a_pb2 import ( + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + ListTasksResponse, + Message, + SendMessageRequest, + SubscribeToTaskRequest, + Task, + TaskPushNotificationConfig, + TaskStatusUpdateEvent, +) +from a2a.utils.errors import ( + InternalError, + InvalidParamsError, + TaskNotCancelableError, + TaskNotFoundError, + UnsupportedOperationError, +) +from a2a.utils.task import ( + apply_history_length, + validate_history_length, + validate_page_size, +) +from a2a.utils.telemetry import SpanKind, trace_class + + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from 
a2a.server.agent_execution.active_task import ActiveTask + from a2a.server.context import ServerCallContext + from a2a.server.events import Event + from a2a.server.tasks import ( + PushNotificationConfigStore, + PushNotificationSender, + TaskStore, + ) + + +logger = logging.getLogger(__name__) + + +# TODO: cleanup context_id management + + +@trace_class(kind=SpanKind.SERVER) +class DefaultRequestHandlerV2(RequestHandler): + """Default request handler for all incoming requests.""" + + _background_tasks: set[asyncio.Task] + + def __init__( # noqa: PLR0913 + self, + agent_executor: AgentExecutor, + task_store: TaskStore, + queue_manager: Any + | None = None, # Kept for backward compat in signature + push_config_store: PushNotificationConfigStore | None = None, + push_sender: PushNotificationSender | None = None, + request_context_builder: RequestContextBuilder | None = None, + ) -> None: + self.agent_executor = agent_executor + self.task_store = task_store + self._push_config_store = push_config_store + self._push_sender = push_sender + self._request_context_builder = ( + request_context_builder + or SimpleRequestContextBuilder( + should_populate_referred_tasks=False, task_store=self.task_store + ) + ) + self._active_task_registry = ActiveTaskRegistry( + agent_executor=self.agent_executor, + task_store=self.task_store, + push_sender=self._push_sender, + ) + self._background_tasks = set() + + @validate_request_params + async def on_get_task( # noqa: D102 + self, + params: GetTaskRequest, + context: ServerCallContext, + ) -> Task | None: + validate_history_length(params) + + task_id = params.id + task: Task | None = await self.task_store.get(task_id, context) + if not task: + raise TaskNotFoundError + + return apply_history_length(task, params) + + @validate_request_params + async def on_list_tasks( # noqa: D102 + self, + params: ListTasksRequest, + context: ServerCallContext, + ) -> ListTasksResponse: + validate_history_length(params) + if params.HasField('page_size'): 
+ validate_page_size(params.page_size) + + page = await self.task_store.list(params, context) + for task in page.tasks: + if not params.include_artifacts: + task.ClearField('artifacts') + + updated_task = apply_history_length(task, params) + if updated_task is not task: + task.CopyFrom(updated_task) + + return page + + @validate_request_params + async def on_cancel_task( # noqa: D102 + self, + params: CancelTaskRequest, + context: ServerCallContext, + ) -> Task | None: + task_id = params.id + + try: + active_task = await self._active_task_registry.get_or_create( + task_id, call_context=context, create_task_if_missing=False + ) + result = await active_task.cancel(context) + except InvalidParamsError as e: + raise TaskNotCancelableError from e + + if isinstance(result, Message): + raise InternalError( + message='Cancellation returned a message instead of a task.' + ) + + return result + + def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: + if task_id != event_task_id: + logger.error( + 'Agent generated task_id=%s does not match the RequestContext task_id=%s.', + event_task_id, + task_id, + ) + raise InternalError(message='Task ID mismatch in agent response') + + async def _setup_active_task( + self, + params: SendMessageRequest, + call_context: ServerCallContext, + ) -> tuple[ActiveTask, RequestContext]: + validate_history_length(params.configuration) + + original_task_id = params.message.task_id or None + original_context_id = params.message.context_id or None + + if original_task_id: + task = await self.task_store.get(original_task_id, call_context) + if not task: + raise TaskNotFoundError(f'Task {original_task_id} not found') + + # Build context to resolve or generate missing IDs + request_context = await self._request_context_builder.build( + params=params, + task_id=original_task_id, + context_id=original_context_id, + # We will get the task when we have to process the request to avoid concurrent read/write issues. 
+ task=None, + context=call_context, + ) + + task_id = cast('str', request_context.task_id) + context_id = cast('str', request_context.context_id) + + if ( + self._push_config_store + and params.configuration + and params.configuration.task_push_notification_config + ): + await self._push_config_store.set_info( + task_id, + params.configuration.task_push_notification_config, + call_context, + ) + + active_task = await self._active_task_registry.get_or_create( + task_id, + context_id=context_id, + call_context=call_context, + create_task_if_missing=True, + ) + + return active_task, request_context + + @validate_request_params + async def on_message_send( # noqa: D102 + self, + params: SendMessageRequest, + context: ServerCallContext, + ) -> Message | Task: + active_task, request_context = await self._setup_active_task( + params, context + ) + + if params.configuration and params.configuration.return_immediately: + await active_task.enqueue_request(request_context) + + task = await active_task.get_task() + if params.configuration: + task = apply_history_length(task, params.configuration) + return task + + try: + result_states = TERMINAL_TASK_STATES | INTERRUPTED_TASK_STATES + + result = None + async for event in active_task.subscribe(request=request_context): + logger.debug( + 'Processing[%s] event [%s] %s', + request_context.task_id, + type(event).__name__, + event, + ) + if isinstance(event, Message) or ( + isinstance(event, Task) + and event.status.state in result_states + ): + result = event + break + if ( + isinstance(event, TaskStatusUpdateEvent) + and event.status.state in result_states + ): + result = await self.task_store.get(event.task_id, context) + break + + if result is None: + logger.debug( + 'Missing result for task %s', request_context.task_id + ) + result = await active_task.get_task() + + logger.debug( + 'Processing[%s] result: %s', request_context.task_id, result + ) + + except Exception: + logger.exception('Agent execution failed') + raise + + if 
isinstance(result, Task): + self._validate_task_id_match( + cast('str', request_context.task_id), result.id + ) + if params.configuration: + result = apply_history_length(result, params.configuration) + + return result + + # TODO: Unify with on_message_send + @validate_request_params + async def on_message_send_stream( # noqa: D102 + self, + params: SendMessageRequest, + context: ServerCallContext, + ) -> AsyncGenerator[Event, None]: + active_task, request_context = await self._setup_active_task( + params, context + ) + + include_initial_task = bool( + params.configuration and params.configuration.return_immediately + ) + + task_id = cast('str', request_context.task_id) + + async for event in active_task.subscribe( + request=request_context, include_initial_task=include_initial_task + ): + if isinstance(event, Task): + self._validate_task_id_match(task_id, event.id) + logger.debug('Sending event [%s] %s', type(event).__name__, event) + yield event + + @validate_request_params + async def on_create_task_push_notification_config( # noqa: D102 + self, + params: TaskPushNotificationConfig, + context: ServerCallContext, + ) -> TaskPushNotificationConfig: + if not self._push_config_store: + raise UnsupportedOperationError + + task_id = params.task_id + task: Task | None = await self.task_store.get(task_id, context) + if not task: + raise TaskNotFoundError + + await self._push_config_store.set_info( + task_id, + params, + context, + ) + + return params + + @validate_request_params + async def on_get_task_push_notification_config( # noqa: D102 + self, + params: GetTaskPushNotificationConfigRequest, + context: ServerCallContext, + ) -> TaskPushNotificationConfig: + if not self._push_config_store: + raise UnsupportedOperationError + + task_id = params.task_id + config_id = params.id + task: Task | None = await self.task_store.get(task_id, context) + if not task: + raise TaskNotFoundError + + push_notification_configs: list[TaskPushNotificationConfig] = ( + await 
self._push_config_store.get_info(task_id, context) or [] + ) + + for config in push_notification_configs: + if config.id == config_id: + return config + + raise InternalError(message='Push notification config not found') + + @validate_request_params + async def on_subscribe_to_task( # noqa: D102 + self, + params: SubscribeToTaskRequest, + context: ServerCallContext, + ) -> AsyncGenerator[Event, None]: + task_id = params.id + + active_task = await self._active_task_registry.get_or_create( + task_id, + call_context=context, + create_task_if_missing=False, + ) + + async for event in active_task.subscribe(include_initial_task=True): + yield event + + @validate_request_params + async def on_list_task_push_notification_configs( # noqa: D102 + self, + params: ListTaskPushNotificationConfigsRequest, + context: ServerCallContext, + ) -> ListTaskPushNotificationConfigsResponse: + if not self._push_config_store: + raise UnsupportedOperationError + + task_id = params.task_id + task: Task | None = await self.task_store.get(task_id, context) + if not task: + raise TaskNotFoundError + + push_notification_config_list = await self._push_config_store.get_info( + task_id, context + ) + + return ListTaskPushNotificationConfigsResponse( + configs=push_notification_config_list + ) + + @validate_request_params + async def on_delete_task_push_notification_config( # noqa: D102 + self, + params: DeleteTaskPushNotificationConfigRequest, + context: ServerCallContext, + ) -> None: + if not self._push_config_store: + raise UnsupportedOperationError + + task_id = params.task_id + config_id = params.id + task: Task | None = await self.task_store.get(task_id, context) + if not task: + raise TaskNotFoundError + + await self._push_config_store.delete_info(task_id, context, config_id) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index e00b53c02..59d9995c2 100644 --- a/tests/integration/test_client_server_integration.py +++ 
b/tests/integration/test_client_server_integration.py @@ -1,4 +1,5 @@ import asyncio + from collections.abc import AsyncGenerator from typing import Any, NamedTuple from unittest.mock import ANY, AsyncMock, patch @@ -7,9 +8,11 @@ import httpx import pytest import pytest_asyncio + from cryptography.hazmat.primitives.asymmetric import ec from google.protobuf.json_format import MessageToDict from google.protobuf.timestamp_pb2 import Timestamp +from starlette.applications import Starlette from a2a.client import Client, ClientConfig from a2a.client.base_client import BaseClient @@ -21,17 +24,16 @@ with_a2a_extensions, ) from a2a.client.transports import JsonRpcTransport, RestTransport -from starlette.applications import Starlette # Compat v0.3 imports for dedicated tests -from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc +from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler +from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.server.routes import ( create_agent_card_routes, create_jsonrpc_routes, create_rest_routes, ) -from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -66,11 +68,7 @@ ContentTypeNotSupportedError, ExtendedAgentCardNotConfiguredError, ExtensionSupportRequiredError, - InternalError, InvalidAgentResponseError, - InvalidParamsError, - InvalidRequestError, - MethodNotFoundError, PushNotificationNotSupportedError, TaskNotCancelableError, TaskNotFoundError, @@ -82,6 +80,7 @@ create_signature_verifier, ) + # --- Test Constants --- TASK_FROM_STREAM = Task( @@ -347,7 +346,10 @@ async def grpc_server_and_handler( servicer = GrpcHandler(agent_card, mock_request_handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) await server.start() - yield server_address, mock_request_handler + try: + yield server_address, mock_request_handler + finally: + await 
server.stop(None) @pytest_asyncio.fixture @@ -1101,7 +1103,7 @@ async def test_validate_version_unsupported(http_transport_setups) -> None: params = GetTaskRequest(id=GET_TASK_RESPONSE.id) - with pytest.raises(VersionNotSupportedError) as exc_info: + with pytest.raises(VersionNotSupportedError): await client.get_task(request=params, context=context) await client.close() @@ -1118,7 +1120,7 @@ async def test_validate_decorator_push_notifications_disabled( params = TaskPushNotificationConfig(task_id='123') - with pytest.raises(UnsupportedOperationError) as exc_info: + with pytest.raises(UnsupportedOperationError): await client.create_task_push_notification_config(request=params) await client.close() @@ -1140,7 +1142,7 @@ async def test_validate_streaming_disabled( stream = transport.send_message_streaming(request=params) - with pytest.raises(UnsupportedOperationError) as exc_info: + with pytest.raises(UnsupportedOperationError): async for _ in stream: pass diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py new file mode 100644 index 000000000..94774e29a --- /dev/null +++ b/tests/integration/test_scenarios.py @@ -0,0 +1,1443 @@ +import asyncio +import collections +import logging + +from typing import Any + +import grpc +import pytest +import pytest_asyncio + +from a2a.auth.user import User +from a2a.client.client import ClientConfig +from a2a.client.client_factory import ClientFactory +from a2a.client.errors import A2AClientError +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.context import ServerCallContext +from a2a.server.events import EventQueue +from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager +from a2a.server.request_handlers import DefaultRequestHandlerV2, GrpcHandler +from a2a.server.request_handlers.default_request_handler import ( + LegacyRequestHandler, +) +from a2a.server.request_handlers import GrpcServerCallContextBuilder +from 
a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import a2a_pb2_grpc +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + AgentInterface, + Artifact, + CancelTaskRequest, + GetTaskRequest, + ListTasksRequest, + Message, + Part, + Role, + SendMessageConfiguration, + SendMessageRequest, + SubscribeToTaskRequest, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.utils import TransportProtocol +from a2a.utils.errors import ( + InvalidParamsError, + TaskNotCancelableError, + TaskNotFoundError, +) + + +logger = logging.getLogger(__name__) + + +async def wait_for_state( + client: Any, + task_id: str, + expected_states: set[TaskState.ValueType], + timeout: float = 1.0, +) -> None: + """Wait for the task to reach one of the expected states.""" + start_time = asyncio.get_event_loop().time() + while True: + task = await client.get_task(GetTaskRequest(id=task_id)) + if task.status.state in expected_states: + return + + if asyncio.get_event_loop().time() - start_time > timeout: + raise TimeoutError( + f'Task {task_id} did not reach expected states {expected_states} within {timeout}s. 
' + f'Current state: {task.status.state}' + ) + await asyncio.sleep(0.01) + + +async def get_all_events(stream): + return [event async for event in stream] + + +class MockUser(User): + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return 'test-user' + + +class MockCallContextBuilder(GrpcServerCallContextBuilder): + def build(self, request: Any) -> ServerCallContext: + return ServerCallContext( + user=MockUser(), state={'headers': {'a2a-version': '1.0'}} + ) + + +def agent_card(): + return AgentCard( + name='Test Agent', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + supported_interfaces=[ + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='http://testserver', + ) + ], + ) + + +def get_state(event): + if event.HasField('task'): + return event.task.status.state + return event.status_update.status.state + + +def validate_state(event, expected_state): + assert get_state(event) == expected_state + + +_test_servers = [] + + +@pytest_asyncio.fixture(autouse=True) +async def cleanup_test_servers(): + yield + for server in _test_servers: + await server.stop(None) + _test_servers.clear() + + +# TODO: Test different transport (e.g. HTTP_JSON hangs for some tests). 
+async def create_client(handler, agent_card, streaming=False): + server = grpc.aio.server() + port = server.add_insecure_port('[::]:0') + server_address = f'localhost:{port}' + + agent_card.supported_interfaces[0].url = server_address + agent_card.supported_interfaces[0].protocol_binding = TransportProtocol.GRPC + + servicer = GrpcHandler( + agent_card, handler, context_builder=MockCallContextBuilder() + ) + a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) + await server.start() + _test_servers.append(server) + + factory = ClientFactory( + config=ClientConfig( + grpc_channel_factory=grpc.aio.insecure_channel, + supported_protocol_bindings=[TransportProtocol.GRPC], + streaming=streaming, + ) + ) + client = factory.create(agent_card) + client._server = server # Keep reference to prevent garbage collection + return client + + +def create_handler( + agent_executor, use_legacy, task_store=None, queue_manager=None +): + task_store = task_store or InMemoryTaskStore() + queue_manager = queue_manager or InMemoryQueueManager() + return ( + LegacyRequestHandler(agent_executor, task_store, queue_manager) + if use_legacy + else DefaultRequestHandlerV2(agent_executor, task_store, queue_manager) + ) + + +# Scenario 1: Cancellation of already terminal task +# This also covers test_scenario_7_cancel_terminal_task from test_handler_comparison +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_1_cancel_terminal_task(use_legacy, streaming): + class DummyAgentExecutor(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + task_store = InMemoryTaskStore() + handler = create_handler( + DummyAgentExecutor(), use_legacy, task_store=task_store + 
) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + task_id = 'terminal-task' + await task_store.save( + Task( + id=task_id, status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED) + ), + ServerCallContext(user=MockUser()), + ) + with pytest.raises(TaskNotCancelableError): + await client.cancel_task(CancelTaskRequest(id=task_id)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_4_simple_streaming(use_legacy): + class DummyAgentExecutor(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(DummyAgentExecutor(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=True + ) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + events = [ + event + async for event in client.send_message(SendMessageRequest(message=msg)) + ] + assert [event.status_update.status.state for event in events] == [ + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] + + +# Scenario 5: Re-subscribing to a finished task +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_5_resubscribe_to_finished(use_legacy): + class DummyAgentExecutor(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + 
TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(DummyAgentExecutor(), use_legacy) + client = await create_client(handler, agent_card=agent_card()) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + (event,) = [event async for event in it] + task_id = event.task.id + + await wait_for_state( + client, task_id, expected_states={TaskState.TASK_STATE_COMPLETED} + ) + # TODO: Use different transport. + with pytest.raises( + NotImplementedError, + match='client and/or server do not support resubscription', + ): + async for _ in client.subscribe(SubscribeToTaskRequest(id=task_id)): + pass + + +# Scenario 6-8: Parity for Error cases +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenarios_simple_errors(use_legacy, streaming): + class DummyAgentExecutor(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(DummyAgentExecutor(), use_legacy) + client = await 
create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + with pytest.raises(TaskNotFoundError): + await client.get_task(GetTaskRequest(id='missing')) + + msg1 = Message( + task_id='missing', + message_id='missing-task', + role=Role.ROLE_USER, + parts=[Part(text='h')], + ) + with pytest.raises(TaskNotFoundError): + async for _ in client.send_message(SendMessageRequest(message=msg1)): + pass + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + (event,) = [event async for event in it] + + if streaming: + assert event.HasField('status_update') + task_id = event.status_update.task_id + assert ( + event.status_update.status.state == TaskState.TASK_STATE_COMPLETED + ) + else: + assert event.HasField('task') + task_id = event.task.id + assert event.task.status.state == TaskState.TASK_STATE_COMPLETED + + logger.info('Sending message to completed task %s', task_id) + msg2 = Message( + message_id='test-msg-2', + task_id=task_id, + role=Role.ROLE_USER, + parts=[Part(text='message to completed task')], + ) + # TODO: Is it correct error code ? + with pytest.raises(InvalidParamsError): + async for _ in client.send_message(SendMessageRequest(message=msg2)): + pass + + (task,) = (await client.list_tasks(ListTasksRequest())).tasks + assert task.status.state == TaskState.TASK_STATE_COMPLETED + (message,) = task.history + assert message.role == Role.ROLE_USER + (message_part,) = message.parts + assert message_part.text == 'hello' + + +# Scenario 9: Exception before any event. 
+@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_9_error_before_blocking(use_legacy, streaming): + class ErrorBeforeAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + raise ValueError('TEST_ERROR_IN_EXECUTE') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ErrorBeforeAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + # TODO: Is it correct error code ? + with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'): + async for _ in client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration( + return_immediately=False + ), + ) + ): + pass + + if use_legacy: + # Legacy is not creating tasks for agent failures. + assert len((await client.list_tasks(ListTasksRequest())).tasks) == 0 + else: + # TODO: should it be TASK_STATE_FAILED ? 
+ (task,) = (await client.list_tasks(ListTasksRequest())).tasks + assert task.status.state == TaskState.TASK_STATE_SUBMITTED + + +# Scenario 12/13: Exception after initial event +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_12_13_error_after_initial_event(use_legacy, streaming): + started_event = asyncio.Event() + continue_event = asyncio.Event() + + class ErrorAfterAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + started_event.set() + await continue_event.wait() + raise ValueError('TEST_ERROR_IN_EXECUTE') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ErrorAfterAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + it = client.send_message(SendMessageRequest(message=msg)) + + tasks = [] + + if streaming: + res = await it.__anext__() + assert res.status_update.status.state == TaskState.TASK_STATE_WORKING + continue_event.set() + else: + + async def release_agent(): + await started_event.wait() + continue_event.set() + + tasks.append(asyncio.create_task(release_agent())) + + with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'): + async for _ in it: + pass + + await asyncio.gather(*tasks) + + # TODO: should it be TASK_STATE_FAILED ? 
+ (task,) = (await client.list_tasks(ListTasksRequest())).tasks + assert task.status.state == TaskState.TASK_STATE_WORKING + + +# Scenario 14: Exception in Cancel +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_14_error_in_cancel(use_legacy, streaming): + started_event = asyncio.Event() + hang_event = asyncio.Event() + + class ErrorCancelAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + started_event.set() + await hang_event.wait() + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + raise ValueError('TEST_ERROR_IN_CANCEL') + + handler = create_handler(ErrorCancelAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', + role=Role.ROLE_USER, + parts=[Part(text='hello')], + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + res = await it.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + + with pytest.raises(A2AClientError, match='TEST_ERROR_IN_CANCEL'): + await client.cancel_task(CancelTaskRequest(id=task_id)) + + # TODO: should it be TASK_STATE_CANCELED or TASK_STATE_FAILED? 
+ (task,) = (await client.list_tasks(ListTasksRequest())).tasks + assert task.status.state == TaskState.TASK_STATE_WORKING + + +# Scenario 15: Subscribe to task that errors out +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_15_subscribe_error(use_legacy): + started_event = asyncio.Event() + continue_event = asyncio.Event() + + class ErrorAfterAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + started_event.set() + await continue_event.wait() + raise ValueError('TEST_ERROR_IN_EXECUTE') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ErrorAfterAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=True + ) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + it_start = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + res = await it_start.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + async def consume_events(): + async for _ in client.subscribe(SubscribeToTaskRequest(id=task_id)): + pass + + consume_task = asyncio.create_task(consume_events()) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(asyncio.shield(consume_task), timeout=0.1) + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + continue_event.set() + + if use_legacy: + # Legacy client hangs forever. 
+ with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(consume_task, timeout=0.1) + else: + with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'): + await consume_task + + # TODO: should it be TASK_STATE_FAILED? + (task,) = (await client.list_tasks(ListTasksRequest())).tasks + assert task.status.state == TaskState.TASK_STATE_WORKING + + +# Scenario 16: Slow execution and return_immediately=True +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_16_slow_execution(use_legacy, streaming): + started_event = asyncio.Event() + hang_event = asyncio.Event() + + class SlowAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + started_event.set() + await hang_event.wait() + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + queue_manager = InMemoryQueueManager() + handler = create_handler( + SlowAgent(), use_legacy, queue_manager=queue_manager + ) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', + role=Role.ROLE_USER, + parts=[Part(text='hello')], + ) + + async def send_message_and_get_first_response(): + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + return await asyncio.wait_for(it.__anext__(), timeout=0.1) + + if use_legacy: + # Legacy client hangs forever. 
+ with pytest.raises(asyncio.TimeoutError): + await send_message_and_get_first_response() + else: + event = await send_message_and_get_first_response() + task = event.task + assert task.status.state == TaskState.TASK_STATE_SUBMITTED + (message,) = task.history + assert message.message_id == 'test-msg' + + tasks = (await client.list_tasks(ListTasksRequest())).tasks + if use_legacy: + # Legacy didn't create a task + assert len(tasks) == 0 + else: + (task,) = tasks + assert task.status.state == TaskState.TASK_STATE_SUBMITTED + + +# Scenario 17: Cancellation of a working task. +# @pytest.mark.skip +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_cancel_working_task_empty_cancel(use_legacy, streaming): + started_event = asyncio.Event() + hang_event = asyncio.Event() + + class DummyCancelAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + started_event.set() + await hang_event.wait() + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + # TODO: this should be done automatically by the framework ? 
+ await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_CANCELED), + ) + ) + + handler = create_handler(DummyCancelAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + res = await it.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + + task_before = await client.get_task(GetTaskRequest(id=task_id)) + assert task_before.status.state == TaskState.TASK_STATE_WORKING + + cancel_res = await client.cancel_task(CancelTaskRequest(id=task_id)) + assert cancel_res.status.state == TaskState.TASK_STATE_CANCELED + + task_after = await client.get_task(GetTaskRequest(id=task_id)) + assert task_after.status.state == TaskState.TASK_STATE_CANCELED + + (task_from_list,) = (await client.list_tasks(ListTasksRequest())).tasks + assert task_from_list.status.state == TaskState.TASK_STATE_CANCELED + + +# Scenario 18: Complex streaming with multiple subscribers +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_18_streaming_subscribers(use_legacy): + started_event = asyncio.Event() + working_event = asyncio.Event() + completed_event = asyncio.Event() + + class ComplexAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + started_event.set() + await 
working_event.wait() + + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact(artifact_id='test-art'), + ) + ) + await completed_event.wait() + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ComplexAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=True + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + res = await it.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + + # create first subscriber + sub1 = client.subscribe(SubscribeToTaskRequest(id=task_id)) + + # first subscriber receives current task state (WORKING) + validate_state(await sub1.__anext__(), TaskState.TASK_STATE_WORKING) + + # create second subscriber + sub2 = client.subscribe(SubscribeToTaskRequest(id=task_id)) + + # second subscriber receives current task state (WORKING) + validate_state(await sub2.__anext__(), TaskState.TASK_STATE_WORKING) + + working_event.set() + + # validate what both subscribers observed (artifact) + res1_art = await sub1.__anext__() + assert res1_art.artifact_update.artifact.artifact_id == 'test-art' + + res2_art = await sub2.__anext__() + assert res2_art.artifact_update.artifact.artifact_id == 'test-art' + + completed_event.set() + + # validate what both subscribers observed (completed) + validate_state(await sub1.__anext__(), TaskState.TASK_STATE_COMPLETED) + validate_state(await 
sub2.__anext__(), TaskState.TASK_STATE_COMPLETED) + + # validate final task state with getTask + final_task = await client.get_task(GetTaskRequest(id=task_id)) + assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + + (artifact,) = final_task.artifacts + assert artifact.artifact_id == 'test-art' + + (message,) = final_task.history + assert message.parts[0].text == 'hello' + + +# Scenario 19: Parallel executions for the same task should not happen simultaneously. +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_19_no_parallel_executions(use_legacy, streaming): + started_event = asyncio.Event() + continue_event = asyncio.Event() + executions_count = 0 + + class CountingAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + nonlocal executions_count + executions_count += 1 + + if executions_count > 1: + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact(artifact_id='SECOND_EXECUTION'), + ) + ) + return + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + started_event.set() + await continue_event.wait() + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(CountingAgent(), use_legacy) + client1 = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + client2 = await create_client( + handler, agent_card=agent_card(), 
streaming=streaming + ) + + msg1 = Message( + message_id='test-msg-1', + role=Role.ROLE_USER, + parts=[Part(text='hello 1')], + ) + + # First client sends initial message + it1 = client1.send_message( + SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + task1 = asyncio.create_task(it1.__anext__()) + + # Wait for the first execution to reach the WORKING state + await asyncio.wait_for(started_event.wait(), timeout=1.0) + assert executions_count == 1 + + # Extract task_id from the first call using list_tasks + (task,) = (await client1.list_tasks(ListTasksRequest())).tasks + task_id = task.id + + msg2 = Message( + message_id='test-msg-2', + task_id=task_id, + role=Role.ROLE_USER, + parts=[Part(text='hello 2')], + ) + + # Second client sends a message to the same task + it2 = client2.send_message( + SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + task2 = asyncio.create_task(it2.__anext__()) + + if use_legacy: + # Legacy handler executes the second request in parallel. + await task2 + assert executions_count == 2 + else: + # V2 handler queues the second request. + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(asyncio.shield(task2), timeout=0.1) + assert executions_count == 1 + + # Unblock AgentExecutor + continue_event.set() + + # Verify that both calls for clients finished. + if use_legacy and not streaming: + # Legacy handler fails on first execution. + with pytest.raises(A2AClientError, match='NoTaskQueue'): + await task1 + else: + await task1 + + try: + await task2 + except StopAsyncIteration: + # TODO: Test is flaky. Debug it. + return + + # Consume remaining events if any + async def consume(it): + async for _ in it: + pass + + await asyncio.gather(consume(it1), consume(it2)) + assert executions_count == 2 + + # Validate final task state. 
+ final_task = await client1.get_task(GetTaskRequest(id=task_id)) + + if use_legacy: + # Legacy handler fails to complete the task. + assert final_task.status.state == TaskState.TASK_STATE_WORKING + else: + assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + + # TODO: What is expected state of messages and artifacts? + + +# Scenario: Validate return_immediately flag behaviour. +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_return_immediately(use_legacy, streaming): + class ImmediateAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ImmediateAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + # Test non-blocking return. 
+ it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + states = [get_state(event) async for event in it] + + if use_legacy: + if streaming: + assert states == [ + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] + else: + assert states == [TaskState.TASK_STATE_WORKING] + elif streaming: + assert states == [ + TaskState.TASK_STATE_SUBMITTED, + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] + else: + assert states == [TaskState.TASK_STATE_SUBMITTED] + + # Test blocking return. + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + states = [get_state(event) async for event in it] + + if streaming: + assert states == [ + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] + else: + assert states == [TaskState.TASK_STATE_COMPLETED] + + +# Scenario: Test TASK_STATE_INPUT_REQUIRED. 
+@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_resumption_from_interrupted(use_legacy, streaming): + class ResumingAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + message = context.message + if message and message.parts and message.parts[0].text == 'start': + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_INPUT_REQUIRED + ), + ) + ) + elif ( + message + and message.parts + and message.parts[0].text == 'here is input' + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + else: + raise ValueError('Unexpected message') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ResumingAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + # First send message to get it into input required state + msg1 = Message( + message_id='msg-start', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + events1 = [event async for event in it] + assert [get_state(event) for event in events1] == [ + TaskState.TASK_STATE_INPUT_REQUIRED, + ] + task_id = events1[0].status_update.task_id + context_id = events1[0].status_update.context_id + + # Now send another message to resume + msg2 = Message( + task_id=task_id, + context_id=context_id, + message_id='msg-resume', + role=Role.ROLE_USER, + parts=[Part(text='here is 
input')], + ) + + it2 = client.send_message( + SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + assert [get_state(event) async for event in it2] == [ + TaskState.TASK_STATE_COMPLETED, + ] + + +# Scenario: Auth required and side channel unblocking +# Migrated from: test_workflow_auth_required_side_channel in test_handler_comparison +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_auth_required_side_channel(use_legacy, streaming): + side_channel_event = asyncio.Event() + + class AuthAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + ) + + await side_channel_event.wait() + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(AuthAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + if streaming: + event1 = await asyncio.wait_for(it.__anext__(), 
timeout=1.0) + assert get_state(event1) == TaskState.TASK_STATE_WORKING + + event2 = await asyncio.wait_for(it.__anext__(), timeout=1.0) + assert get_state(event2) == TaskState.TASK_STATE_AUTH_REQUIRED + + task_id = event2.status_update.task_id + + side_channel_event.set() + + # Remaining event. + (event3,) = [event async for event in it] + assert get_state(event3) == TaskState.TASK_STATE_COMPLETED + else: + (event,) = [event async for event in it] + assert get_state(event) == TaskState.TASK_STATE_AUTH_REQUIRED + task_id = event.task.id + + side_channel_event.set() + + await wait_for_state( + client, task_id, expected_states={TaskState.TASK_STATE_COMPLETED} + ) + + +# Scenario: Parallel subscribe attach detach +# Migrated from: test_parallel_subscribe_attach_detach in test_handler_comparison +@pytest.mark.timeout(5.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_parallel_subscribe_attach_detach(use_legacy): + events = collections.defaultdict(asyncio.Event) + + class EmitAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + + phases = [ + ('trigger_phase_1', 'artifact_1'), + ('trigger_phase_2', 'artifact_2'), + ('trigger_phase_3', 'artifact_3'), + ('trigger_phase_4', 'artifact_4'), + ] + + for trigger_name, artifact_id in phases: + await events[trigger_name].wait() + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact( + artifact_id=artifact_id, + parts=[Part(text=artifact_id)], + ), + ) + ) + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), 
+ ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(EmitAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=True + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + + res = await it.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + async def monitor_artifacts(): + try: + async for event in client.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + if event.HasField('artifact_update'): + artifact_id = event.artifact_update.artifact.artifact_id + if artifact_id.startswith('artifact_'): + phase_num = artifact_id.split('_')[1] + events[f'emitted_phase_{phase_num}'].set() + except asyncio.CancelledError: + pass + + monitor_task = asyncio.create_task(monitor_artifacts()) + + async def subscribe_and_collect(artifacts_to_collect: int | None = None): + ready_event = asyncio.Event() + + async def collect(): + collected = [] + artifacts_seen = 0 + try: + async for event in client.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + collected.append(event) + ready_event.set() + if event.HasField('artifact_update'): + artifacts_seen += 1 + if ( + artifacts_to_collect is not None + and artifacts_seen >= artifacts_to_collect + ): + break + except asyncio.CancelledError: + pass + return collected + + task = asyncio.create_task(collect()) + await ready_event.wait() + return task + + sub1_task = await subscribe_and_collect() + + events['trigger_phase_1'].set() + await events['emitted_phase_1'].wait() + + sub2_task = await subscribe_and_collect(artifacts_to_collect=1) + sub3_task = await subscribe_and_collect(artifacts_to_collect=2) + + events['trigger_phase_2'].set() + await events['emitted_phase_2'].wait() + + 
events['trigger_phase_3'].set() + await events['emitted_phase_3'].wait() + + sub4_task = await subscribe_and_collect() + + events['trigger_phase_4'].set() + await events['emitted_phase_4'].wait() + + def get_artifact_updates(evs): + txts = [] + for sr in evs: + if sr.HasField('artifact_update'): + txts.append([p.text for p in sr.artifact_update.artifact.parts]) + return txts + + assert get_artifact_updates(await sub1_task) == [ + ['artifact_1'], + ['artifact_2'], + ['artifact_3'], + ['artifact_4'], + ] + + assert get_artifact_updates(await sub2_task) == [ + ['artifact_2'], + ] + assert get_artifact_updates(await sub3_task) == [ + ['artifact_2'], + ['artifact_3'], + ] + assert get_artifact_updates(await sub4_task) == [ + ['artifact_4'], + ] + + monitor_task.cancel() diff --git a/tests/server/agent_execution/__init__.py b/tests/server/agent_execution/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/agent_execution/test_active_task.py b/tests/server/agent_execution/test_active_task.py new file mode 100644 index 000000000..d3cc95dc3 --- /dev/null +++ b/tests/server/agent_execution/test_active_task.py @@ -0,0 +1,1088 @@ +import asyncio +import logging + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import pytest_asyncio + +from a2a.server.agent_execution.active_task import ActiveTask +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.context import ServerCallContext +from a2a.server.events.event_queue_v2 import EventQueueSource as EventQueue +from a2a.server.tasks.push_notification_sender import PushNotificationSender +from a2a.server.tasks.task_manager import TaskManager +from a2a.types.a2a_pb2 import ( + Message, + Task, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.utils.errors import InvalidParamsError + + +logger = logging.getLogger(__name__) + + +class TestActiveTask: + """Tests for the 
ActiveTask class.""" + + @pytest.fixture + def agent_executor(self) -> Mock: + return Mock(spec=AgentExecutor) + + @pytest.fixture + def task_manager(self) -> Mock: + tm = Mock(spec=TaskManager) + tm.process = AsyncMock(side_effect=lambda x: x) + tm.get_task = AsyncMock(return_value=None) + tm.context_id = 'test-context-id' + tm._init_task_obj = Mock(return_value=Task(id='test-task-id')) + tm.save_task_event = AsyncMock() + return tm + + @pytest_asyncio.fixture + async def event_queue(self) -> EventQueue: + return EventQueue() + + @pytest.fixture + def push_sender(self) -> Mock: + ps = Mock(spec=PushNotificationSender) + ps.send_notification = AsyncMock() + return ps + + @pytest.fixture + def request_context(self) -> Mock: + return Mock(spec=RequestContext) + + @pytest_asyncio.fixture + async def active_task( + self, + agent_executor: Mock, + task_manager: Mock, + push_sender: Mock, + ) -> ActiveTask: + return ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=push_sender, + ) + + @pytest.mark.asyncio + async def test_active_task_lifecycle( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test the basic lifecycle of an ActiveTask.""" + + async def execute_mock(req, q): + await q.enqueue_event(Message(message_id='m1')) + await q.enqueue_event( + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ] * 10 + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Wait for the task to finish + events = [e 
async for e in active_task.subscribe()] + result = next(e for e in events if isinstance(e, Message)) + + assert isinstance(result, Message) + assert result.message_id == 'm1' + assert active_task.task_id == 'test-task-id' + + @pytest.mark.asyncio + async def test_active_task_already_started( + self, active_task: ActiveTask, request_context: Mock + ) -> None: + """Test starting a task that is already started.""" + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + # Enqueuing and starting again should not raise errors + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + assert active_task._producer_task is not None + + @pytest.mark.asyncio + async def test_active_task_subscribe( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribing to events from an ActiveTask.""" + + async def execute_mock(req, q): + await q.enqueue_event(Message(message_id='m1')) + await q.enqueue_event(Message(message_id='m2')) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [] + async for event in active_task.subscribe(): + events.append(event) + if len(events) == 2: + break + + assert len(events) == 2 + assert events[0].message_id == 'm1' + assert events[1].message_id == 'm2' + + @pytest.mark.asyncio + async def test_active_task_cancel( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test canceling an ActiveTask.""" + stop_event = asyncio.Event() + + async def execute_mock(req, q): + await stop_event.wait() + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + 
agent_executor.cancel = AsyncMock() + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ] * 10 + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Give it a moment to start + await asyncio.sleep(0.1) + + await active_task.cancel(request_context) + + agent_executor.cancel.assert_called_once() + stop_event.set() + + @pytest.mark.asyncio + async def test_active_task_interrupted_auth( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test task interruption due to AUTH_REQUIRED.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + + async def execute_mock(req, q): + await q.enqueue_event( + TaskStatusUpdateEvent( + task_id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + ) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [ + e async for e in active_task.subscribe(request=request_context) + ] + + result = events[0] if events else None + assert ( + getattr(result, 'id', getattr(result, 'task_id', None)) + == 'test-task-id' + ) + assert result.status.state == TaskState.TASK_STATE_AUTH_REQUIRED + + @pytest.mark.asyncio + async def test_active_task_interrupted_input( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test task interruption due to INPUT_REQUIRED.""" + task_obj = Task( + 
id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_INPUT_REQUIRED), + ) + + async def execute_mock(req, q): + await q.enqueue_event( + Task( + id='test-task-id', + status=TaskStatus( + state=TaskState.TASK_STATE_INPUT_REQUIRED + ), + ) + ) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [ + e async for e in active_task.subscribe(request=request_context) + ] + + result = events[-1] if events else None + assert result.id == 'test-task-id' + assert result.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + @pytest.mark.asyncio + async def test_active_task_producer_failure( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test ActiveTask behavior when the producer fails.""" + agent_executor.execute = AsyncMock( + side_effect=ValueError('Producer crashed') + ) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # We need to wait a bit for the producer to fail and set the exception + for _ in range(10): + try: + async for _ in active_task.subscribe(): + pass + except ValueError: + return + await asyncio.sleep(0.05) + + pytest.fail('Producer failure was not raised') + + @pytest.mark.asyncio + async def test_active_task_push_notification( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + push_sender: Mock, + task_manager: Mock, + ) -> None: + """Test push notification sending.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + + async def execute_mock(req, q): + await q.enqueue_event(task_obj) + + agent_executor.execute = 
AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + async for _ in active_task.subscribe(request=request_context): + pass + + push_sender.send_notification.assert_called() + + @pytest.mark.asyncio + async def test_active_task_cleanup( + self, + agent_executor: Mock, + task_manager: Mock, + request_context: Mock, + ) -> None: + """Test that the cleanup callback is called.""" + on_cleanup = Mock() + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + on_cleanup=on_cleanup, + ) + + async def execute_mock(req, q): + await q.enqueue_event(Message(message_id='m1')) + await q.enqueue_event( + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + async for _ in active_task.subscribe(request=request_context): + pass + + # Wait for consumer thread to finish and call cleanup + for _ in range(20): + if on_cleanup.called: + break + await asyncio.sleep(0.05) + + on_cleanup.assert_called_once_with(active_task) + + @pytest.mark.asyncio + async def test_active_task_consumer_failure( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test behavior when the consumer task fails.""" + # Mock dequeue_event to raise exception + active_task._event_queue_agent.dequeue_event = AsyncMock( + 
side_effect=RuntimeError('Consumer crash') + ) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # We need to wait for the consumer to fail + for _ in range(10): + try: + async for _ in active_task.subscribe(): + pass + except RuntimeError as e: + if str(e) == 'Consumer crash': + return + await asyncio.sleep(0.05) + + pytest.fail('Consumer failure was not raised') + + @pytest.mark.asyncio + async def test_active_task_subscribe_exception_handling( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test exception handling in subscribe.""" + agent_executor.execute = AsyncMock( + side_effect=ValueError('Producer failure') + ) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Give it a moment to fail + for _ in range(10): + if active_task._exception: + break + await asyncio.sleep(0.05) + + with pytest.raises(ValueError, match='Producer failure'): + async for _ in active_task.subscribe(): + pass + + @pytest.mark.asyncio + async def test_active_task_cancel_not_started( + self, active_task: ActiveTask, request_context: Mock + ) -> None: + """Test canceling a task that was never started.""" + # TODO: Implement this test + + @pytest.mark.asyncio + async def test_active_task_cancel_already_finished( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test canceling a task that is already finished.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + + async def execute_mock(req, q): + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + 
status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + async for _ in active_task.subscribe(request=request_context): + pass + + await active_task._is_finished.wait() + + # Now it is finished + await active_task.cancel(request_context) + + # agent_executor.cancel should NOT be called + agent_executor.cancel.assert_not_called() + + @pytest.mark.asyncio + async def test_active_task_subscribe_cancelled_during_wait( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when it is cancelled while waiting for events.""" + + async def slow_execute(req, q): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + it = active_task.subscribe() + it_obj = it.__aiter__() + + # This task will be waiting inside the loop in subscribe() + task = asyncio.create_task(it_obj.__anext__()) + await asyncio.sleep(0.2) + + task.cancel() + + # In python 3.10+ cancelling an async generator next() might raise StopAsyncIteration + # if the generator handles the cancellation by closing. 
+ with pytest.raises((asyncio.CancelledError, StopAsyncIteration)): + await task + + await it.aclose() + + @pytest.mark.asyncio + async def test_active_task_subscribe_queue_shutdown( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when the queue is shut down.""" + + async def long_execute(*args, **kwargs): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=long_execute) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + tapped = await active_task._event_queue_subscribers.tap() + + with patch.object( + active_task._event_queue_subscribers, 'tap', return_value=tapped + ): + # Close the queue while subscribe is waiting + async def close_later(): + await asyncio.sleep(0.2) + await tapped.close() + + _ = asyncio.create_task(close_later()) + + async for _ in active_task.subscribe(): + pass + + # Should finish normally after QueueShutDown + + @pytest.mark.asyncio + async def test_active_task_subscribe_yield_then_shutdown( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when an event is yielded and then the queue is shut down.""" + msg = Message(message_id='m1') + + async def execute_mock(req, q): + await q.enqueue_event(msg) + await asyncio.sleep(0.5) + # Finish producer + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [event async for event in active_task.subscribe()] + assert len(events) == 1 + assert events[0] == msg + + @pytest.mark.asyncio + async def test_active_task_task_sets_result_first( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + 
task_manager: Mock, + ) -> None: + """Test that enqueuing a Task sets result_available when no result yet.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + + async def execute_mock(req, q): + # No result available yet + await q.enqueue_event(task_obj) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [ + e async for e in active_task.subscribe(request=request_context) + ] + + result = events[-1] if events else None + assert result == task_obj + + @pytest.mark.asyncio + async def test_active_task_subscribe_cancelled_during_yield( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe cancellation while yielding (GeneratorExit).""" + msg = Message(message_id='m1') + + async def execute_mock(req, q): + await q.enqueue_event(msg) + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + it = active_task.subscribe() + async for event in it: + assert event == msg + # Cancel while we have the event (inside the loop) + await it.aclose() + break + + @pytest.mark.asyncio + async def test_active_task_cancel_when_already_closed( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test cancel when the event queue is already closed.""" + + async def execute_mock(req, q): + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.return_value = 
Task(id='test') + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Forced queue close. + await active_task._event_queue_agent.close() + await active_task._event_queue_subscribers.close() + + # Now cancel the task itself. + await active_task.cancel(request_context) + # wait() was removed, no need to wait here. + + # Cancel again should not do anything. + await active_task.cancel(request_context) + # wait() was removed, no need to wait here. + + @pytest.mark.asyncio + async def test_active_task_subscribe_dequeue_failure( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when dequeue_event fails on the tapped queue.""" + + async def slow_execute(req, q): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + mock_tapped_queue = Mock(spec=EventQueue) + mock_tapped_queue.dequeue_event = AsyncMock( + side_effect=RuntimeError('Tapped queue crash') + ) + mock_tapped_queue.close = AsyncMock() + + with ( + patch.object( + active_task._event_queue_subscribers, + 'tap', + return_value=mock_tapped_queue, + ), + pytest.raises(RuntimeError, match='Tapped queue crash'), + ): + async for _ in active_task.subscribe(): + pass + + mock_tapped_queue.close.assert_called_once() + + @pytest.mark.asyncio + async def test_active_task_consumer_interrupted_multiple_times( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test consumer receiving multiple interrupting events.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + + async def execute_mock(req, q): + await q.enqueue_event( + TaskStatusUpdateEvent( + 
task_id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + ) + await q.enqueue_event( + TaskStatusUpdateEvent( + task_id='test-task-id', + status=TaskStatus( + state=TaskState.TASK_STATE_INPUT_REQUIRED + ), + ) + ) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [ + e async for e in active_task.subscribe(request=request_context) + ] + + result = events[0] if events else None + assert result.status.state == TaskState.TASK_STATE_AUTH_REQUIRED + + @pytest.mark.asyncio + async def test_active_task_subscribe_immediate_finish( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when the task finishes immediately.""" + + async def execute_mock(req, q): + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Wait for it to finish + await active_task._is_finished.wait() + + with pytest.raises( + InvalidParamsError, match=r'Task .* is already completed' + ): + async for _ in active_task.subscribe(): + pass + + @pytest.mark.asyncio + async def test_active_task_start_producer_immediate_error( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test start when producer fails immediately.""" + agent_executor.execute = AsyncMock( + side_effect=ValueError('Quick failure') + ) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Consumer should 
also finish + with pytest.raises(ValueError, match='Quick failure'): + async for _ in active_task.subscribe(): + pass + + @pytest.mark.asyncio + async def test_active_task_subscribe_finished_during_wait( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when the task finishes while waiting for an event.""" + + async def slow_execute(req, q): + # Do nothing and just finish + await asyncio.sleep(0.5) + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + async def consume(): + async for _ in active_task.subscribe(): + pass + + task = asyncio.create_task(consume()) + await asyncio.sleep(0.2) + + # Task is still running, subscribe is waiting. + # Now it finishes. + await asyncio.sleep(0.5) + await task # Should finish normally + + @pytest.mark.asyncio + async def test_active_task_maybe_cleanup_not_finished( + self, + agent_executor: Mock, + task_manager: Mock, + push_sender: Mock, + ) -> None: + """Test that cleanup is not called if task is not finished.""" + on_cleanup = Mock() + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=push_sender, + on_cleanup=on_cleanup, + ) + + # Explicitly call private _maybe_cleanup to verify it respects finished state + await active_task._maybe_cleanup() + on_cleanup.assert_not_called() + + @pytest.mark.asyncio + async def test_active_task_maybe_cleanup_with_subscribers( + self, + agent_executor: Mock, + task_manager: Mock, + push_sender: Mock, + request_context: Mock, + ) -> None: + """Test that cleanup is not called if there are subscribers.""" + on_cleanup = Mock() + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + 
task_manager=task_manager, + push_sender=push_sender, + on_cleanup=on_cleanup, + ) + + # Mock execute to finish immediately + async def execute_mock(req, q): + await q.enqueue_event(Message(message_id='m1')) + await q.enqueue_event( + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ] * 10 + + # 1. Start a subscriber before task finishes + gen = active_task.subscribe() + # Start the generator to increment reference count + msg_task = asyncio.create_task(gen.__anext__()) + + # 2. Start the task and wait for it to finish + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + async for _ in active_task.subscribe(request=request_context): + pass + + # Give the consumer loop a moment to set _is_finished + await asyncio.sleep(0.1) + + # Ensure we got the message + assert (await msg_task).message_id == 'm1' + + # At this point, task is finished, but we still have a subscriber (gen). + # _maybe_cleanup was called by consumer loop, but should have done nothing. + on_cleanup.assert_not_called() + + # 3. 
Close the subscriber + await gen.aclose() + + # Now cleanup should be triggered + on_cleanup.assert_called_once_with(active_task) + + @pytest.mark.asyncio + async def test_active_task_subscribe_exception_already_set( + self, active_task: ActiveTask + ) -> None: + """Test subscribe when exception is already set.""" + active_task._exception = ValueError('Pre-existing error') + with pytest.raises(ValueError, match='Pre-existing error'): + async for _ in active_task.subscribe(): + pass + + @pytest.mark.asyncio + async def test_active_task_subscribe_inner_exception( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test the generic exception block in subscribe.""" + + async def slow_execute(req, q): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + mock_tapped_queue = Mock(spec=EventQueue) + # dequeue_event returns a task that fails + mock_tapped_queue.dequeue_event = AsyncMock( + side_effect=Exception('Inner error') + ) + mock_tapped_queue.close = AsyncMock() + + with ( + patch.object( + active_task._event_queue_subscribers, + 'tap', + return_value=mock_tapped_queue, + ), + pytest.raises(Exception, match='Inner error'), + ): + async for _ in active_task.subscribe(): + pass + + +@pytest.mark.asyncio +async def test_active_task_subscribe_include_initial_task(): + agent_executor = Mock() + task_manager = Mock() + request_context = Mock(spec=RequestContext) + + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=Mock(), + ) + + initial_task = Task( + id='test-task-id', status=TaskStatus(state=TaskState.TASK_STATE_WORKING) + ) + + async def execute_mock(req, q): + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = 
AsyncMock(side_effect=execute_mock) + task_manager.get_task = AsyncMock(return_value=initial_task) + task_manager.save_task_event = AsyncMock() + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [e async for e in active_task.subscribe(include_initial_task=True)] + + # Verify that the first yielded event is the initial task + assert len(events) >= 1 + assert events[0] == initial_task + + +@pytest.mark.asyncio +async def test_active_task_subscribe_request_parameter(): + agent_executor = Mock() + task_manager = Mock() + request_context = Mock(spec=RequestContext) + + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=Mock(), + ) + + async def execute_mock(req, q): + # We simulate the task finishing successfully, so it will emit _RequestCompleted + pass + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + agent_executor.cancel = AsyncMock() + task_manager.get_task = AsyncMock( + return_value=Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + task_manager.save_task_event = AsyncMock() + task_manager.process = AsyncMock(side_effect=lambda x: x) + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Pass request_context directly to subscribe without enqueuing manually + events = [e async for e in active_task.subscribe(request=request_context)] + + # Should complete without error, and yield no events (just _RequestCompleted which is hidden) + assert len(events) == 0 + + await active_task.cancel(request_context) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index f4ba04996..68945d06d 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ 
b/tests/server/request_handlers/test_default_request_handler.py @@ -23,7 +23,9 @@ ) from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue, InMemoryQueueManager, QueueManager -from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.request_handlers import ( + LegacyRequestHandler as DefaultRequestHandler, +) from a2a.server.tasks import ( InMemoryPushNotificationConfigStore, InMemoryTaskStore, diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py new file mode 100644 index 000000000..abe35bf64 --- /dev/null +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -0,0 +1,1208 @@ +import asyncio +import logging +import time +import uuid + +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest + +from a2a.auth.user import UnauthenticatedUser +from a2a.server.agent_execution import ( + RequestContextBuilder, + AgentExecutor, + RequestContext, + SimpleRequestContextBuilder, +) +from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry +from a2a.server.context import ServerCallContext +from a2a.server.events import EventQueue, InMemoryQueueManager, QueueManager +from a2a.server.request_handlers import DefaultRequestHandlerV2 +from a2a.server.tasks import ( + InMemoryPushNotificationConfigStore, + InMemoryTaskStore, + PushNotificationConfigStore, + PushNotificationSender, + TaskStore, + TaskUpdater, +) +from a2a.types import ( + InternalError, + InvalidParamsError, + TaskNotFoundError, + UnsupportedOperationError, +) +from a2a.types.a2a_pb2 import ( + Artifact, + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTasksRequest, + ListTasksResponse, + Message, + Part, + Role, + SendMessageConfiguration, + SendMessageRequest, + SubscribeToTaskRequest, + Task, 
+ TaskPushNotificationConfig, + TaskState, + TaskStatus, +) +from a2a.utils import new_agent_text_message, new_task + + +class MockAgentExecutor(AgentExecutor): + async def execute(self, context: RequestContext, event_queue: EventQueue): + task_updater = TaskUpdater( + event_queue, + str(context.task_id or ''), + str(context.context_id or ''), + ) + async for i in self._run(): + parts = [Part(text=f'Event {i}')] + try: + await task_updater.update_status( + TaskState.TASK_STATE_WORKING, + message=task_updater.new_agent_message(parts), + ) + except RuntimeError: + break + + async def _run(self): + for i in range(1000000): + yield i + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +def create_sample_task( + task_id='task1', + status_state=TaskState.TASK_STATE_SUBMITTED, + context_id='ctx1', +) -> Task: + return Task( + id=task_id, context_id=context_id, status=TaskStatus(state=status_state) + ) + + +def create_server_call_context() -> ServerCallContext: + return ServerCallContext(user=UnauthenticatedUser()) + + +def test_init_default_dependencies(): + """Test that default dependencies are created if not provided.""" + agent_executor = MockAgentExecutor() + task_store = InMemoryTaskStore() + handler = DefaultRequestHandlerV2( + agent_executor=agent_executor, task_store=task_store + ) + assert isinstance(handler._active_task_registry, ActiveTaskRegistry) + assert isinstance( + handler._request_context_builder, SimpleRequestContextBuilder + ) + assert handler._push_config_store is None + assert handler._push_sender is None + assert ( + handler._request_context_builder._should_populate_referred_tasks + is False + ) + assert handler._request_context_builder._task_store == task_store + + +@pytest.mark.asyncio +async def test_on_get_task_not_found(): + """Test on_get_task when task_store.get returns None.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + request_handler = 
DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), task_store=mock_task_store + ) + params = GetTaskRequest(id='non_existent_task') + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task(params, context) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + + +@pytest.mark.asyncio +async def test_on_list_tasks_success(): + """Test on_list_tasks successfully returns a page of tasks .""" + mock_task_store = AsyncMock(spec=TaskStore) + task2 = create_sample_task(task_id='task2') + task2.artifacts.extend( + [ + Artifact( + artifact_id='artifact1', + parts=[Part(text='Hello world!')], + name='conversion_result', + ) + ] + ) + mock_page = ListTasksResponse( + tasks=[create_sample_task(task_id='task1'), task2], + next_page_token='123', # noqa: S106 + ) + mock_task_store.list.return_value = mock_page + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + ) + params = ListTasksRequest(include_artifacts=True, page_size=10) + context = create_server_call_context() + result = await request_handler.on_list_tasks(params, context) + mock_task_store.list.assert_awaited_once_with(params, context) + assert result.tasks == mock_page.tasks + assert result.next_page_token == mock_page.next_page_token + + +@pytest.mark.asyncio +async def test_on_list_tasks_excludes_artifacts(): + """Test on_list_tasks excludes artifacts from returned tasks.""" + mock_task_store = AsyncMock(spec=TaskStore) + task2 = create_sample_task(task_id='task2') + task2.artifacts.extend( + [ + Artifact( + artifact_id='artifact1', + parts=[Part(text='Hello world!')], + name='conversion_result', + ) + ] + ) + mock_page = ListTasksResponse( + tasks=[create_sample_task(task_id='task1'), task2], + next_page_token='123', # noqa: S106 + ) + mock_task_store.list.return_value = mock_page + request_handler = DefaultRequestHandlerV2( + 
agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + ) + params = ListTasksRequest(include_artifacts=False, page_size=10) + context = create_server_call_context() + result = await request_handler.on_list_tasks(params, context) + assert not result.tasks[1].artifacts + + +@pytest.mark.asyncio +async def test_on_list_tasks_applies_history_length(): + """Test on_list_tasks applies history length filter.""" + mock_task_store = AsyncMock(spec=TaskStore) + history = [ + new_agent_text_message('Hello 1!'), + new_agent_text_message('Hello 2!'), + ] + task2 = create_sample_task(task_id='task2') + task2.history.extend(history) + mock_page = ListTasksResponse( + tasks=[create_sample_task(task_id='task1'), task2], + next_page_token='123', # noqa: S106 + ) + mock_task_store.list.return_value = mock_page + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + ) + params = ListTasksRequest(history_length=1, page_size=10) + context = create_server_call_context() + result = await request_handler.on_list_tasks(params, context) + assert result.tasks[1].history == [history[1]] + + +@pytest.mark.asyncio +async def test_on_list_tasks_negative_history_length_error(): + """Test on_list_tasks raises error for negative history length.""" + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + ) + params = ListTasksRequest(history_length=-1, page_size=10) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_list_tasks(params, context) + assert 'history length must be non-negative' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_cancel_task_task_not_found(): + """Test on_cancel_task when the task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + 
request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), task_store=mock_task_store + ) + params = CancelTaskRequest(id='task_not_found_for_cancel') + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_cancel_task(params, context) + mock_task_store.get.assert_awaited_once_with( + 'task_not_found_for_cancel', context + ) + + +class HelloAgentExecutor(AgentExecutor): + async def execute(self, context: RequestContext, event_queue: EventQueue): + task = context.current_task + if not task: + assert context.message is not None, ( + 'A message is required to create a new task' + ) + task = new_task(context.message) + await event_queue.enqueue_event(task) + updater = TaskUpdater(event_queue, task.id, task.context_id) + try: + parts = [Part(text='I am working')] + await updater.update_status( + TaskState.TASK_STATE_WORKING, + message=updater.new_agent_message(parts), + ) + except Exception as e: # noqa: BLE001 + logging.warning('Error: %s', e) + return + await updater.add_artifact( + [Part(text='Hello world!')], name='conversion_result' + ) + await updater.complete() + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +@pytest.mark.asyncio +async def test_on_get_task_limit_history(): + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, message_id='msg_push', parts=[Part(text='Hi')] + ), + configuration=SendMessageConfiguration( + accepted_output_modes=['text/plain'] + ), + ) + result = await request_handler.on_message_send( + params, create_server_call_context() + ) + assert result is not None + assert isinstance(result, Task) + get_task_result = await request_handler.on_get_task( + GetTaskRequest(id=result.id, 
history_length=1), + create_server_call_context(), + ) + assert get_task_result is not None + assert isinstance(get_task_result, Task) + assert ( + get_task_result.history is not None + and len(get_task_result.history) == 1 + ) + + +async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0): + """Await until predicate() is True or timeout elapses.""" + loop = asyncio.get_running_loop() + end = loop.time() + timeout + while True: + if predicate(): + return + if loop.time() >= end: + raise AssertionError('condition not met within timeout') + await asyncio.sleep(interval) + + +@pytest.mark.asyncio +async def test_set_task_push_notification_config_no_notifier(): + """Test on_create_task_push_notification_config when _push_config_store is None.""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=AsyncMock(spec=TaskStore), + push_config_store=None, + ) + params = TaskPushNotificationConfig( + task_id='task1', url='http://example.com' + ) + with pytest.raises(UnsupportedOperationError): + await request_handler.on_create_task_push_notification_config( + params, create_server_call_context() + ) + + +@pytest.mark.asyncio +async def test_set_task_push_notification_config_task_not_found(): + """Test on_create_task_push_notification_config when task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + mock_push_store = AsyncMock(spec=PushNotificationConfigStore) + mock_push_sender = AsyncMock(spec=PushNotificationSender) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + push_sender=mock_push_sender, + ) + params = TaskPushNotificationConfig( + task_id='non_existent_task', url='http://example.com' + ) + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_create_task_push_notification_config( + params, 
context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + mock_push_store.set_info.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_no_store(): + """Test on_get_task_push_notification_config when _push_config_store is None.""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=AsyncMock(spec=TaskStore), + push_config_store=None, + ) + params = GetTaskPushNotificationConfigRequest( + task_id='task1', id='task_push_notification_config' + ) + with pytest.raises(UnsupportedOperationError): + await request_handler.on_get_task_push_notification_config( + params, create_server_call_context() + ) + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_task_not_found(): + """Test on_get_task_push_notification_config when task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + mock_push_store = AsyncMock(spec=PushNotificationConfigStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + ) + params = GetTaskPushNotificationConfigRequest( + task_id='non_existent_task', id='task_push_notification_config' + ) + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task_push_notification_config( + params, context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + mock_push_store.get_info.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_info_not_found(): + """Test on_get_task_push_notification_config when push_config_store.get_info returns None.""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='non_existent_task') + mock_task_store.get.return_value = sample_task + mock_push_store = 
AsyncMock(spec=PushNotificationConfigStore) + mock_push_store.get_info.return_value = None + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + ) + params = GetTaskPushNotificationConfigRequest( + task_id='non_existent_task', id='task_push_notification_config' + ) + context = create_server_call_context() + with pytest.raises(InternalError): + await request_handler.on_get_task_push_notification_config( + params, context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + mock_push_store.get_info.assert_awaited_once_with( + 'non_existent_task', context + ) + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_info_with_config(): + """Test on_get_task_push_notification_config with valid push config id""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') + push_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + ) + set_config_params = TaskPushNotificationConfig( + task_id='task_1', id='config_id', url='http://1.example.com' + ) + context = create_server_call_context() + await request_handler.on_create_task_push_notification_config( + set_config_params, context + ) + params = GetTaskPushNotificationConfigRequest( + task_id='task_1', id='config_id' + ) + result: TaskPushNotificationConfig = ( + await request_handler.on_get_task_push_notification_config( + params, context + ) + ) + assert result is not None + assert result.task_id == 'task_1' + assert result.url == set_config_params.url + assert result.id == 'config_id' + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_info_with_config_no_id(): + """Test on_get_task_push_notification_config with no push config id""" + mock_task_store = 
AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') + push_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + ) + set_config_params = TaskPushNotificationConfig( + task_id='task_1', url='http://1.example.com' + ) + await request_handler.on_create_task_push_notification_config( + set_config_params, create_server_call_context() + ) + params = GetTaskPushNotificationConfigRequest(task_id='task_1', id='task_1') + result: TaskPushNotificationConfig = ( + await request_handler.on_get_task_push_notification_config( + params, create_server_call_context() + ) + ) + assert result is not None + assert result.task_id == 'task_1' + assert result.url == set_config_params.url + assert result.id == 'task_1' + + +@pytest.mark.asyncio +async def test_on_subscribe_to_task_task_not_found(): + """Test on_subscribe_to_task when the task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), task_store=mock_task_store + ) + params = SubscribeToTaskRequest(id='resub_task_not_found') + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + async for _ in request_handler.on_subscribe_to_task(params, context): + pass + mock_task_store.get.assert_awaited_once_with( + 'resub_task_not_found', context + ) + + +@pytest.mark.asyncio +async def test_on_message_send_stream(): + request_handler = DefaultRequestHandlerV2( + MockAgentExecutor(), InMemoryTaskStore() + ) + message_params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg-123', + parts=[Part(text='How are you?')], + ) + ) + + async def consume_stream(): + events = [] + async for event in request_handler.on_message_send_stream( + message_params, 
create_server_call_context() + ): + events.append(event) + if len(events) >= 3: + break + return events + + start = time.perf_counter() + events = await consume_stream() + elapsed = time.perf_counter() - start + assert len(events) == 3 + assert elapsed < 0.5 + texts = [p.text for e in events for p in e.status.message.parts] + assert texts == ['Event 0', 'Event 1', 'Event 2'] + + +@pytest.mark.asyncio +async def test_list_task_push_notification_config_no_store(): + """Test on_list_task_push_notification_configs when _push_config_store is None.""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=AsyncMock(spec=TaskStore), + push_config_store=None, + ) + params = ListTaskPushNotificationConfigsRequest(task_id='task1') + with pytest.raises(UnsupportedOperationError): + await request_handler.on_list_task_push_notification_configs( + params, create_server_call_context() + ) + + +@pytest.mark.asyncio +async def test_list_task_push_notification_config_task_not_found(): + """Test on_list_task_push_notification_configs when task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + mock_push_store = AsyncMock(spec=PushNotificationConfigStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + ) + params = ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_list_task_push_notification_configs( + params, context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + mock_push_store.get_info.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_list_no_task_push_notification_config_info(): + """Test on_get_task_push_notification_config when push_config_store.get_info returns []""" + mock_task_store = 
AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='non_existent_task') + mock_task_store.get.return_value = sample_task + push_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + ) + params = ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') + result = await request_handler.on_list_task_push_notification_configs( + params, create_server_call_context() + ) + assert result.configs == [] + + +@pytest.mark.asyncio +async def test_list_task_push_notification_config_info_with_config(): + """Test on_list_task_push_notification_configs with push config+id""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='non_existent_task') + mock_task_store.get.return_value = sample_task + push_config1 = TaskPushNotificationConfig( + task_id='task_1', id='config_1', url='http://example.com' + ) + push_config2 = TaskPushNotificationConfig( + task_id='task_1', id='config_2', url='http://example.com' + ) + push_store = InMemoryPushNotificationConfigStore() + context = create_server_call_context() + await push_store.set_info('task_1', push_config1, context) + await push_store.set_info('task_1', push_config2, context) + await push_store.set_info('task_2', push_config1, context) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + ) + params = ListTaskPushNotificationConfigsRequest(task_id='task_1') + result = await request_handler.on_list_task_push_notification_configs( + params, create_server_call_context() + ) + assert len(result.configs) == 2 + assert result.configs[0].task_id == 'task_1' + assert result.configs[0] == push_config1 + assert result.configs[1].task_id == 'task_1' + assert result.configs[1] == push_config2 + + +@pytest.mark.asyncio +async def 
test_list_task_push_notification_config_info_with_config_and_no_id(): + """Test on_list_task_push_notification_configs with no push config id""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') + push_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + ) + set_config_params1 = TaskPushNotificationConfig( + task_id='task_1', url='http://1.example.com' + ) + await request_handler.on_create_task_push_notification_config( + set_config_params1, create_server_call_context() + ) + set_config_params2 = TaskPushNotificationConfig( + task_id='task_1', url='http://2.example.com' + ) + await request_handler.on_create_task_push_notification_config( + set_config_params2, create_server_call_context() + ) + params = ListTaskPushNotificationConfigsRequest(task_id='task_1') + result = await request_handler.on_list_task_push_notification_configs( + params, create_server_call_context() + ) + assert len(result.configs) == 1 + assert result.configs[0].task_id == 'task_1' + assert result.configs[0].url == set_config_params2.url + assert result.configs[0].id == 'task_1' + + +@pytest.mark.asyncio +async def test_delete_task_push_notification_config_no_store(): + """Test on_delete_task_push_notification_config when _push_config_store is None.""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=AsyncMock(spec=TaskStore), + push_config_store=None, + ) + params = DeleteTaskPushNotificationConfigRequest( + task_id='task1', id='config1' + ) + with pytest.raises(UnsupportedOperationError) as exc_info: + await request_handler.on_delete_task_push_notification_config( + params, create_server_call_context() + ) + assert isinstance(exc_info.value, UnsupportedOperationError) + + +@pytest.mark.asyncio +async def 
test_delete_task_push_notification_config_task_not_found(): + """Test on_delete_task_push_notification_config when task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + mock_push_store = AsyncMock(spec=PushNotificationConfigStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + ) + params = DeleteTaskPushNotificationConfigRequest( + task_id='non_existent_task', id='config1' + ) + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_delete_task_push_notification_config( + params, context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + mock_push_store.get_info.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_delete_no_task_push_notification_config_info(): + """Test on_delete_task_push_notification_config without config info""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='task_1') + mock_task_store.get.return_value = sample_task + push_store = InMemoryPushNotificationConfigStore() + await push_store.set_info( + 'task_2', + TaskPushNotificationConfig(id='config_1', url='http://example.com'), + create_server_call_context(), + ) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + ) + params = DeleteTaskPushNotificationConfigRequest( + task_id='task1', id='config_non_existant' + ) + result = await request_handler.on_delete_task_push_notification_config( + params, create_server_call_context() + ) + assert result is None + params = DeleteTaskPushNotificationConfigRequest( + task_id='task2', id='config_non_existant' + ) + result = await request_handler.on_delete_task_push_notification_config( + params, create_server_call_context() + ) + assert result is None + + 
+@pytest.mark.asyncio +async def test_delete_task_push_notification_config_info_with_config(): + """Test on_list_task_push_notification_configs with push config+id""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='non_existent_task') + mock_task_store.get.return_value = sample_task + push_config1 = TaskPushNotificationConfig( + task_id='task_1', id='config_1', url='http://example.com' + ) + push_config2 = TaskPushNotificationConfig( + task_id='task_1', id='config_2', url='http://example.com' + ) + push_store = InMemoryPushNotificationConfigStore() + context = create_server_call_context() + await push_store.set_info('task_1', push_config1, context) + await push_store.set_info('task_1', push_config2, context) + await push_store.set_info('task_2', push_config1, context) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + ) + params = DeleteTaskPushNotificationConfigRequest( + task_id='task_1', id='config_1' + ) + result1 = await request_handler.on_delete_task_push_notification_config( + params, create_server_call_context() + ) + assert result1 is None + result2 = await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='task_1'), + create_server_call_context(), + ) + assert len(result2.configs) == 1 + assert result2.configs[0].task_id == 'task_1' + assert result2.configs[0] == push_config2 + + +@pytest.mark.asyncio +async def test_delete_task_push_notification_config_info_with_config_and_no_id(): + """Test on_list_task_push_notification_configs with no push config id""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='non_existent_task') + mock_task_store.get.return_value = sample_task + push_config = TaskPushNotificationConfig(url='http://example.com') + push_store = InMemoryPushNotificationConfigStore() + context = 
create_server_call_context() + await push_store.set_info('task_1', push_config, context) + await push_store.set_info('task_1', push_config, context) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + ) + params = DeleteTaskPushNotificationConfigRequest( + task_id='task_1', id='task_1' + ) + result = await request_handler.on_delete_task_push_notification_config( + params, create_server_call_context() + ) + assert result is None + result2 = await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='task_1'), + create_server_call_context(), + ) + assert len(result2.configs) == 0 + + +TERMINAL_TASK_STATES = { + TaskState.TASK_STATE_COMPLETED, + TaskState.TASK_STATE_CANCELED, + TaskState.TASK_STATE_FAILED, + TaskState.TASK_STATE_REJECTED, +} + + +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) +async def test_on_message_send_task_in_terminal_state(terminal_state): + """Test on_message_send when task is already in a terminal state.""" + state_name = TaskState.Name(terminal_state) + task_id = f'terminal_task_{state_name}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), task_store=mock_task_store + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_terminal', + parts=[Part(text='hello')], + task_id=task_id, + ) + ) + with ( + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=terminal_task, + ), + pytest.raises(InvalidParamsError) as exc_info, + ): + await request_handler.on_message_send( + params, create_server_call_context() + ) + assert ( + f'Task {task_id} is in terminal state: {terminal_state}' + in exc_info.value.message + ) + 
+ +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) +async def test_on_message_send_stream_task_in_terminal_state(terminal_state): + """Test on_message_send_stream when task is already in a terminal state.""" + state_name = TaskState.Name(terminal_state) + task_id = f'terminal_stream_task_{state_name}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), task_store=mock_task_store + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_terminal_stream', + parts=[Part(text='hello')], + task_id=task_id, + ) + ) + with ( + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=terminal_task, + ), + pytest.raises(InvalidParamsError) as exc_info, + ): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + assert ( + f'Task {task_id} is in terminal state: {terminal_state}' + in exc_info.value.message + ) + + +@pytest.mark.asyncio +async def test_on_message_send_task_id_provided_but_task_not_found(): + """Test on_message_send when task_id is provided but task doesn't exist.""" + pass + + +@pytest.mark.asyncio +async def test_on_message_send_stream_task_id_provided_but_task_not_found(): + """Test on_message_send_stream when task_id is provided but task doesn't exist.""" + pass + + +class HelloWorldAgentExecutor(AgentExecutor): + """Test Agent Implementation.""" + + async def execute( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + updater = TaskUpdater( + event_queue, + task_id=context.task_id or str(uuid.uuid4()), + context_id=context.context_id or str(uuid.uuid4()), + ) + await updater.update_status(TaskState.TASK_STATE_WORKING) + await updater.complete() + + async def cancel( + self, context: RequestContext, 
event_queue: EventQueue + ) -> None: + raise NotImplementedError('cancel not supported') + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_error_does_not_hang(): + """Test that if the consumer raises an exception during blocking wait, the producer is cancelled and no deadlock occurs.""" + agent = HelloWorldAgentExecutor() + task_store = AsyncMock(spec=TaskStore) + task_store.get.return_value = None + task_store.save.side_effect = RuntimeError('This is an Error!') + + request_handler = DefaultRequestHandlerV2( + agent_executor=agent, task_store=task_store + ) + + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_error_blocking', + parts=[Part(text='Test message')], + ) + ) + with pytest.raises(RuntimeError, match='This is an Error!'): + await request_handler.on_message_send( + params, create_server_call_context() + ) + + +@pytest.mark.asyncio +async def test_on_get_task_negative_history_length_error(): + """Test on_get_task raises error for negative history length.""" + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + ) + params = GetTaskRequest(id='task1', history_length=-1) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_get_task(params, context) + assert 'history length must be non-negative' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_list_tasks_page_size_too_small(): + """Test on_list_tasks raises error for page_size < 1.""" + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + ) + params = ListTasksRequest(page_size=0) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_list_tasks(params, 
context) + assert 'minimum page size is 1' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_list_tasks_page_size_too_large(): + """Test on_list_tasks raises error for page_size > 100.""" + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + ) + params = ListTasksRequest(page_size=101) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_list_tasks(params, context) + assert 'maximum page size is 100' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_message_send_negative_history_length_error(): + """Test on_message_send raises error for negative history length in configuration.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + request_handler = DefaultRequestHandlerV2( + agent_executor=mock_agent_executor, task_store=mock_task_store + ) + message_config = SendMessageConfiguration( + history_length=-1, accepted_output_modes=['text/plain'] + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, message_id='msg1', parts=[Part(text='hello')] + ), + configuration=message_config, + ) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_message_send(params, context) + assert 'history length must be non-negative' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_message_send_limit_history(): + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() + + request_handler = DefaultRequestHandlerV2( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_push', + parts=[Part(text='Hi')], + ), + 
configuration=SendMessageConfiguration( + accepted_output_modes=['text/plain'], + history_length=1, + ), + ) + + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) + + # verify that history_length is honored + assert result is not None + assert isinstance(result, Task) + assert result.history is not None and len(result.history) == 1 + assert result.status.state == TaskState.TASK_STATE_COMPLETED + + # verify that history is still persisted to the store + task = await task_store.get(result.id, context) + assert task is not None + assert task.history is not None and len(task.history) > 1 + + +@pytest.mark.asyncio +async def test_on_message_send_task_id_mismatch(): + mock_task_store = AsyncMock(spec=TaskStore) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + + context_task_id = 'context_task_id_1' + result_task_id = 'DIFFERENT_task_id_1' + + mock_request_context = MagicMock() + mock_request_context.task_id = context_task_id + mock_request_context_builder.build.return_value = mock_request_context + + request_handler = DefaultRequestHandlerV2( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + request_context_builder=mock_request_context_builder, + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_id_mismatch', + parts=[Part(text='hello')], + ) + ) + + mock_active_task = MagicMock() + mismatched_task = create_sample_task(task_id=result_task_id) + mock_active_task.wait = AsyncMock(return_value=mismatched_task) + mock_active_task.start = AsyncMock() + mock_active_task.enqueue_request = AsyncMock() + mock_active_task.get_task = AsyncMock(return_value=mismatched_task) + with ( + patch.object( + request_handler._active_task_registry, + 'get_or_create', + return_value=mock_active_task, + ), + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + 
return_value=None, + ), + ): + with pytest.raises(InternalError) as exc_info: + await request_handler.on_message_send(params, context=MagicMock()) + assert 'Task ID mismatch' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_message_send_stream_task_id_mismatch(): + mock_task_store = AsyncMock(spec=TaskStore) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + + context_task_id = 'context_task_id_stream_1' + result_task_id = 'DIFFERENT_task_id_stream_1' + + mock_request_context = MagicMock() + mock_request_context.task_id = context_task_id + mock_request_context_builder.build.return_value = mock_request_context + + request_handler = DefaultRequestHandlerV2( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + request_context_builder=mock_request_context_builder, + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_id_mismatch_stream', + parts=[Part(text='hello')], + ) + ) + + mismatched_task = create_sample_task(task_id=result_task_id) + + async def mock_subscribe(request=None, include_initial_task=False): + yield mismatched_task + + mock_active_task = MagicMock() + mock_active_task.subscribe.side_effect = mock_subscribe + mock_active_task.start = AsyncMock() + mock_active_task.enqueue_request = AsyncMock() + + with ( + patch.object( + request_handler._active_task_registry, + 'get_or_create', + return_value=mock_active_task, + ), + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=None, + ), + ): + stream = request_handler.on_message_send_stream( + params, context=MagicMock() + ) + with pytest.raises(InternalError) as exc_info: + async for _ in stream: + pass + assert 'Task ID mismatch' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_message_send_non_blocking(): + task_store = InMemoryTaskStore() + push_store = 
InMemoryPushNotificationConfigStore() + + request_handler = DefaultRequestHandlerV2( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_push_non_blocking', + parts=[Part(text='Hi')], + ), + configuration=SendMessageConfiguration( + return_immediately=True, + ), + ) + + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) + + # non-blocking should return the task immediately + assert result is not None + assert isinstance(result, Task) + assert result.status.state == TaskState.TASK_STATE_SUBMITTED + + +@pytest.mark.asyncio +async def test_on_message_send_with_push_notification(): + task_store = InMemoryTaskStore() + push_store = AsyncMock(spec=PushNotificationConfigStore) + + request_handler = DefaultRequestHandlerV2( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + ) + push_config = TaskPushNotificationConfig(url='http://example.com/webhook') + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_push_1', + parts=[Part(text='Hi')], + ), + configuration=SendMessageConfiguration( + task_push_notification_config=push_config + ), + ) + + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) + + assert result is not None + assert isinstance(result, Task) + push_store.set_info.assert_awaited_once_with( + result.id, push_config, context + ) From 605fa4913ad23539a51a3ee1f5b9ca07f24e1d2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Feh=C3=A9r?= Date: Tue, 7 Apr 2026 15:40:26 +0200 Subject: [PATCH 04/67] feat: Add support for more Task Message and Artifact fields in the Vertex Task Store (#936) Add support for the following fields: * Part metadata * Artifact extensions, display_name, description * Message extensions, reference_task_ids * Parts 
of DataPart are now restored to their original type when read back * Add support for status detail messages in task updates For #802 (for the 1.0 branch) --- .github/actions/spelling/allow.txt | 1 + .../contrib/tasks/vertex_task_converter.py | 171 +++++++++++++++++- src/a2a/contrib/tasks/vertex_task_store.py | 33 ++++ tests/contrib/tasks/fake_vertex_client.py | 6 + .../tasks/test_vertex_task_converter.py | 130 ++++++++++--- tests/contrib/tasks/test_vertex_task_store.py | 70 +++++++ 6 files changed, 377 insertions(+), 34 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index df74a242d..b3657f2b8 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -37,6 +37,7 @@ codegen coro culsans datamodel +datapart deepwiki drivername DSNs diff --git a/src/a2a/contrib/tasks/vertex_task_converter.py b/src/a2a/contrib/tasks/vertex_task_converter.py index 6f23dad2e..9441d2153 100644 --- a/src/a2a/contrib/tasks/vertex_task_converter.py +++ b/src/a2a/contrib/tasks/vertex_task_converter.py @@ -11,13 +11,18 @@ import base64 import json +from dataclasses import dataclass +from typing import Any + from a2a.compat.v0_3.types import ( Artifact, DataPart, FilePart, FileWithBytes, FileWithUri, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -25,6 +30,16 @@ ) +_ORIGINAL_METADATA_KEY = 'originalMetadata' +_EXTENSIONS_KEY = 'extensions' +_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds' +_PART_METADATA_KEY = 'partMetadata' +_METADATA_VERSION_KEY = '__vertex_compat_v' +_METADATA_VERSION_NUMBER = 1.0 + +_DATA_PART_MIME_TYPE = 'application/x-a2a-datapart' + + _TO_SDK_TASK_STATE = { vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown, vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted, @@ -52,6 +67,55 @@ def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState: ) +def to_stored_metadata( + original_metadata: dict[str, Any] | None, + extensions: list[str] | None, + 
reference_task_ids: list[str] | None, + parts: list[Part], +) -> dict[str, Any]: + """Packs original metadata, extensions, and part types/metadata into a storage dictionary.""" + metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER} + if original_metadata: + metadata[_ORIGINAL_METADATA_KEY] = original_metadata + if extensions: + metadata[_EXTENSIONS_KEY] = extensions + if reference_task_ids: + metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids + + metadata[_PART_METADATA_KEY] = [part.root.metadata for part in parts] + + return metadata + + +@dataclass +class _UnpackedMetadata: + original_metadata: dict[str, Any] | None = None + extensions: list[str] | None = None + reference_task_ids: list[str] | None = None + part_metadata: list[dict[str, Any] | None] | None = None + + +def to_sdk_metadata( + stored_metadata: dict[str, Any] | None, +) -> _UnpackedMetadata: + """Unpacks metadata, extensions, and part types/metadata from a storage dictionary.""" + if not stored_metadata: + return _UnpackedMetadata() + + version = stored_metadata.get(_METADATA_VERSION_KEY) + if version is None: + return _UnpackedMetadata(original_metadata=stored_metadata) + if version > _METADATA_VERSION_NUMBER: + raise ValueError(f'Unsupported metadata version: {version}') + + return _UnpackedMetadata( + original_metadata=stored_metadata.get(_ORIGINAL_METADATA_KEY), + extensions=stored_metadata.get(_EXTENSIONS_KEY), + reference_task_ids=stored_metadata.get(_REFERENCE_TASK_IDS_KEY), + part_metadata=stored_metadata.get(_PART_METADATA_KEY), + ) + + def to_stored_part(part: Part) -> genai_types.Part: """Converts a SDK Part to a proto Part.""" if isinstance(part.root, TextPart): @@ -60,7 +124,7 @@ def to_stored_part(part: Part) -> genai_types.Part: data_bytes = json.dumps(part.root.data).encode('utf-8') return genai_types.Part( inline_data=genai_types.Blob( - mime_type='application/json', data=data_bytes + mime_type=_DATA_PART_MIME_TYPE, data=data_bytes ) ) if 
isinstance(part.root, FilePart): @@ -82,20 +146,31 @@ def to_stored_part(part: Part) -> genai_types.Part: raise ValueError(f'Unsupported part type: {type(part.root)}') -def to_sdk_part(stored_part: genai_types.Part) -> Part: +def to_sdk_part( + stored_part: genai_types.Part, + part_metadata: dict[str, Any] | None = None, +) -> Part: """Converts a proto Part to a SDK Part.""" if stored_part.text: - return Part(root=TextPart(text=stored_part.text)) + return Part( + root=TextPart(text=stored_part.text, metadata=part_metadata) + ) if stored_part.inline_data: + mime_type = stored_part.inline_data.mime_type + if mime_type == _DATA_PART_MIME_TYPE: + data_dict = json.loads(stored_part.inline_data.data or b'{}') + return Part(root=DataPart(data=data_dict, metadata=part_metadata)) + encoded_bytes = base64.b64encode( stored_part.inline_data.data or b'' ).decode('utf-8') return Part( root=FilePart( file=FileWithBytes( - mime_type=stored_part.inline_data.mime_type, + mime_type=mime_type, bytes=encoded_bytes, - ) + ), + metadata=part_metadata, ) ) if stored_part.file_data and stored_part.file_data.file_uri: @@ -103,8 +178,9 @@ def to_sdk_part(stored_part: genai_types.Part) -> Part: root=FilePart( file=FileWithUri( mime_type=stored_part.file_data.mime_type, - uri=stored_part.file_data.file_uri, - ) + uri=stored_part.file_data.file_uri or '', + ), + metadata=part_metadata, ) ) @@ -115,15 +191,83 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact: """Converts a SDK Artifact to a proto TaskArtifact.""" return vertexai_types.TaskArtifact( artifact_id=artifact.artifact_id, + display_name=artifact.name, + description=artifact.description, parts=[to_stored_part(part) for part in artifact.parts], + metadata=to_stored_metadata( + original_metadata=artifact.metadata, + extensions=artifact.extensions, + reference_task_ids=None, + parts=artifact.parts, + ), ) def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact: """Converts a proto 
TaskArtifact to a SDK Artifact.""" + unpacked_meta = to_sdk_metadata(stored_artifact.metadata) + part_metadata_list = unpacked_meta.part_metadata or [] + + parts = [] + for i, part in enumerate(stored_artifact.parts or []): + meta: dict[str, Any] | None = None + if i < len(part_metadata_list): + meta = part_metadata_list[i] + parts.append(to_sdk_part(part, part_metadata=meta)) + return Artifact( artifact_id=stored_artifact.artifact_id, - parts=[to_sdk_part(part) for part in stored_artifact.parts], + name=stored_artifact.display_name, + description=stored_artifact.description, + extensions=unpacked_meta.extensions, + metadata=unpacked_meta.original_metadata, + parts=parts, + ) + + +def to_stored_message( + message: Message | None, +) -> vertexai_types.TaskMessage | None: + """Converts a SDK Message to a proto Message.""" + if not message: + return None + role = message.role.value if message.role else '' + return vertexai_types.TaskMessage( + message_id=message.message_id, + role=role, + parts=[to_stored_part(part) for part in message.parts], + metadata=to_stored_metadata( + original_metadata=message.metadata, + extensions=message.extensions, + reference_task_ids=message.reference_task_ids, + parts=message.parts, + ), + ) + + +def to_sdk_message( + stored_msg: vertexai_types.TaskMessage | None, +) -> Message | None: + """Converts a proto Message to a SDK Message.""" + if not stored_msg: + return None + unpacked_meta = to_sdk_metadata(stored_msg.metadata) + part_metadata_list = unpacked_meta.part_metadata or [] + + parts = [] + for i, part in enumerate(stored_msg.parts or []): + part_metadata: dict[str, Any] | None = None + if i < len(part_metadata_list): + part_metadata = part_metadata_list[i] + parts.append(to_sdk_part(part, part_metadata=part_metadata)) + + return Message( + message_id=stored_msg.message_id, + role=Role(stored_msg.role), + extensions=unpacked_meta.extensions, + reference_task_ids=unpacked_meta.reference_task_ids, + 
metadata=unpacked_meta.original_metadata, + parts=parts, ) @@ -133,6 +277,11 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask: context_id=task.context_id, metadata=task.metadata, state=to_stored_task_state(task.status.state), + status_details=vertexai_types.TaskStatusDetails( + task_message=to_stored_message(task.status.message) + ) + if task.status.message + else None, output=vertexai_types.TaskOutput( artifacts=[ to_stored_artifact(artifact) @@ -144,10 +293,14 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask: def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task: """Converts a proto A2aTask to a SDK Task.""" + msg: Message | None = None + if a2a_task.status_details and a2a_task.status_details.task_message: + msg = to_sdk_message(a2a_task.status_details.task_message) + return Task( id=a2a_task.name.split('/')[-1], context_id=a2a_task.context_id, - status=TaskStatus(state=to_sdk_task_state(a2a_task.state)), + status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg), metadata=a2a_task.metadata or {}, artifacts=[ to_sdk_artifact(artifact) diff --git a/src/a2a/contrib/tasks/vertex_task_store.py b/src/a2a/contrib/tasks/vertex_task_store.py index ccd9fffba..0457694e4 100644 --- a/src/a2a/contrib/tasks/vertex_task_store.py +++ b/src/a2a/contrib/tasks/vertex_task_store.py @@ -84,6 +84,32 @@ def _get_status_change_event( ) return None + def _get_status_details_change_event( + self, + previous_task: CompatTask, + task: CompatTask, + event_sequence_number: int, + ) -> vertexai_types.TaskEvent | None: + if task.status.message != previous_task.status.message: + status_details = ( + vertexai_types.TaskStatusDetails( + task_message=vertex_task_converter.to_stored_message( + task.status.message + ) + ) + if task.status.message + else vertexai_types.TaskStatusDetails() + ) + return vertexai_types.TaskEvent( + event_data=vertexai_types.TaskEventData( + status_details_change=vertexai_types.TaskStatusDetailsChange( + 
new_task_status=status_details, + ), + ), + event_sequence_number=event_sequence_number, + ) + return None + def _get_metadata_change_event( self, previous_task: CompatTask, @@ -168,6 +194,13 @@ async def _update( events.append(status_event) event_sequence_number += 1 + status_details_event = self._get_status_details_change_event( + previous_task, task, event_sequence_number + ) + if status_details_event: + events.append(status_details_event) + event_sequence_number += 1 + metadata_event = self._get_metadata_change_event( previous_task, task, event_sequence_number ) diff --git a/tests/contrib/tasks/fake_vertex_client.py b/tests/contrib/tasks/fake_vertex_client.py index 86d14ede0..8a4a86903 100644 --- a/tests/contrib/tasks/fake_vertex_client.py +++ b/tests/contrib/tasks/fake_vertex_client.py @@ -36,6 +36,12 @@ async def append( data = event.event_data if getattr(data, 'state_change', None): task.state = getattr(data.state_change, 'new_state', task.state) + if getattr(data, 'status_details_change', None): + task.status_details = getattr( + data.status_details_change, + 'new_task_status', + getattr(task, 'status_details', None), + ) if getattr(data, 'metadata_change', None): task.metadata = getattr( data.metadata_change, 'new_metadata', task.metadata diff --git a/tests/contrib/tasks/test_vertex_task_converter.py b/tests/contrib/tasks/test_vertex_task_converter.py index a060bc451..3d260c599 100644 --- a/tests/contrib/tasks/test_vertex_task_converter.py +++ b/tests/contrib/tasks/test_vertex_task_converter.py @@ -9,11 +9,14 @@ from vertexai import types as vertexai_types from google.genai import types as genai_types from a2a.contrib.tasks.vertex_task_converter import ( + _DATA_PART_MIME_TYPE, to_sdk_artifact, + to_sdk_message, to_sdk_part, to_sdk_task, to_sdk_task_state, to_stored_artifact, + to_stored_message, to_stored_part, to_stored_task, to_stored_task_state, @@ -24,7 +27,9 @@ FilePart, FileWithBytes, FileWithUri, + Message, Part, + Role, Task, TaskState, 
TaskStatus, @@ -123,7 +128,7 @@ def test_to_stored_part_data() -> None: sdk_part = Part(root=DataPart(data={'key': 'value'})) stored_part = to_stored_part(sdk_part) assert stored_part.inline_data is not None - assert stored_part.inline_data.mime_type == 'application/json' + assert stored_part.inline_data.mime_type == _DATA_PART_MIME_TYPE assert stored_part.inline_data.data == b'{"key": "value"}' @@ -190,6 +195,18 @@ def test_to_sdk_part_inline_data() -> None: assert sdk_part.root.file.bytes == expected_b64 +def test_to_sdk_part_inline_data_datapart() -> None: + stored_part = genai_types.Part( + inline_data=genai_types.Blob( + mime_type=_DATA_PART_MIME_TYPE, + data=b'{"key": "val"}', + ) + ) + sdk_part = to_sdk_part(stored_part) + assert isinstance(sdk_part.root, DataPart) + assert sdk_part.root.data == {'key': 'val'} + + def test_to_sdk_part_file_data() -> None: stored_part = genai_types.Part( file_data=genai_types.FileData( @@ -313,23 +330,11 @@ def test_sdk_part_text_conversion_round_trip() -> None: def test_sdk_part_data_conversion_round_trip() -> None: - # A DataPart is converted to `inline_data` in Vertex AI, which lacks the original - # `DataPart` vs `FilePart` distinction. When reading it back from the stored - # protocol format, it becomes a `FilePart` with base64-encoded `FileWithBytes` - # and `mime_type="application/json"`. 
sdk_part = Part(root=DataPart(data={'key': 'value'})) stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) + round_trip_sdk_part = to_sdk_part(stored_part, part_metadata=None) - expected_b64 = base64.b64encode(b'{"key": "value"}').decode('utf-8') - assert round_trip_sdk_part == Part( - root=FilePart( - file=FileWithBytes( - bytes=expected_b64, - mime_type='application/json', - ) - ) - ) + assert round_trip_sdk_part == sdk_part def test_sdk_part_file_bytes_conversion_round_trip() -> None: @@ -361,16 +366,6 @@ def test_sdk_part_file_uri_conversion_round_trip() -> None: assert round_trip_sdk_part == sdk_part -def test_sdk_artifact_conversion_round_trip() -> None: - sdk_artifact = Artifact( - artifact_id='art-123', - parts=[Part(root=TextPart(text='part_1'))], - ) - stored_artifact = to_stored_artifact(sdk_artifact) - round_trip_sdk_artifact = to_sdk_artifact(stored_artifact) - assert round_trip_sdk_artifact == sdk_artifact - - def test_sdk_task_conversion_round_trip() -> None: sdk_task = Task( id='task-1', @@ -403,3 +398,88 @@ def test_sdk_task_conversion_round_trip() -> None: assert round_trip_sdk_task.metadata == sdk_task.metadata assert round_trip_sdk_task.artifacts == sdk_task.artifacts assert round_trip_sdk_task.history == [] + + +def test_stored_artifact_conversion_round_trip() -> None: + """Test converting an Artifact to TaskArtifact and back restores everything.""" + original_artifact = Artifact( + artifact_id='art123', + name='My cool artifact', + description='A very interesting description', + extensions=['ext1', 'ext2'], + metadata={'custom': 'value'}, + parts=[ + Part( + root=TextPart( + text='hello', metadata={'part_meta': 'hello_meta'} + ) + ), + Part(root=DataPart(data={'foo': 'bar'})), # no metadata + ], + ) + + stored = to_stored_artifact(original_artifact) + assert isinstance(stored, vertexai_types.TaskArtifact) + + # ensure it was populated correctly + assert stored.display_name == 'My cool artifact' + assert 
stored.description == 'A very interesting description' + assert stored.metadata['__vertex_compat_v'] == 1.0 + + restored_artifact = to_sdk_artifact(stored) + + assert restored_artifact.artifact_id == original_artifact.artifact_id + assert restored_artifact.name == original_artifact.name + assert restored_artifact.description == original_artifact.description + assert restored_artifact.extensions == original_artifact.extensions + assert restored_artifact.metadata == original_artifact.metadata + + assert len(restored_artifact.parts) == 2 + assert isinstance(restored_artifact.parts[0].root, TextPart) + assert restored_artifact.parts[0].root.text == 'hello' + assert restored_artifact.parts[0].root.metadata == { + 'part_meta': 'hello_meta' + } + + assert isinstance(restored_artifact.parts[1].root, DataPart) + assert restored_artifact.parts[1].root.data == {'foo': 'bar'} + assert restored_artifact.parts[1].root.metadata is None + + +def test_stored_message_conversion_round_trip() -> None: + """Test converting a Message to TaskMessage and back restores everything.""" + original_message = Message( + message_id='msg456', + role=Role.agent, + reference_task_ids=['tsk2', 'tsk3'], + extensions=['ext_msg'], + metadata={'msg_meta': 42}, + parts=[ + Part(root=TextPart(text='message text')), + ], + ) + + stored = to_stored_message(original_message) + assert stored is not None + assert isinstance(stored, vertexai_types.TaskMessage) + + assert stored.message_id == 'msg456' + assert stored.role == 'agent' + assert stored.metadata['__vertex_compat_v'] == 1.0 + + restored_message = to_sdk_message(stored) + assert restored_message is not None + + assert restored_message.message_id == original_message.message_id + assert restored_message.role == original_message.role + assert ( + restored_message.reference_task_ids + == original_message.reference_task_ids + ) + assert restored_message.extensions == original_message.extensions + assert restored_message.metadata == original_message.metadata 
+ + assert len(restored_message.parts) == 1 + assert isinstance(restored_message.parts[0].root, TextPart) + assert restored_message.parts[0].root.text == 'message text' + assert restored_message.parts[0].root.metadata is None diff --git a/tests/contrib/tasks/test_vertex_task_store.py b/tests/contrib/tasks/test_vertex_task_store.py index 75e3bdf08..4be8cd4e6 100644 --- a/tests/contrib/tasks/test_vertex_task_store.py +++ b/tests/contrib/tasks/test_vertex_task_store.py @@ -65,7 +65,9 @@ def backend_type(request) -> str: from a2a.server.context import ServerCallContext from a2a.types.a2a_pb2 import ( Artifact, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -530,3 +532,71 @@ async def test_metadata_field_mapping( ) assert retrieved_none is not None assert retrieved_none.metadata == {} + + +@pytest.mark.asyncio +async def test_update_task_status_details( + vertex_store: VertexTaskStore, +) -> None: + """Test updating an existing task by changing the status details (message) with part metadata.""" + task_id = 'update-test-task-status-details' + original_task = Task( + id=task_id, + context_id='session-update', + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + metadata=None, + artifacts=[], + history=[], + ) + await vertex_store.save(original_task, ServerCallContext()) + + retrieved_before_update = await vertex_store.get( + task_id, ServerCallContext() + ) + assert retrieved_before_update is not None + assert ( + retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED + ) + + updated_task = Task() + updated_task.CopyFrom(original_task) + updated_task.status.state = TaskState.TASK_STATE_FAILED + updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') + updated_task.status.message.CopyFrom( + Message( + message_id='msg-error-1', + role=Role.ROLE_AGENT, + parts=[ + Part( + text='Task failed due to an unknown error', + metadata={'error_code': 'UNKNOWN', 'retryable': False}, + ) + ], + ) + ) + + await 
vertex_store.save(updated_task, ServerCallContext()) + + retrieved_after_update = await vertex_store.get( + task_id, ServerCallContext() + ) + assert retrieved_after_update is not None + assert retrieved_after_update.status.state == TaskState.TASK_STATE_FAILED + assert retrieved_after_update.status.message is not None + assert retrieved_after_update.status.message.message_id == 'msg-error-1' + assert retrieved_after_update.status.message.role == Role.ROLE_AGENT + assert len(retrieved_after_update.status.message.parts) == 1 + + part = retrieved_after_update.status.message.parts[0] + assert part.text == 'Task failed due to an unknown error' + assert part.metadata == {'error_code': 'UNKNOWN', 'retryable': False} + + # Also test clearing the message + cleared_task = Task() + cleared_task.CopyFrom(updated_task) + cleared_task.status.ClearField('message') + + await vertex_store.save(cleared_task, ServerCallContext()) + retrieved_cleared = await vertex_store.get(task_id, ServerCallContext()) + assert retrieved_cleared is not None + assert not retrieved_cleared.status.HasField('message') From a61f6d4e2e7ce1616a35c3a2ede64a4c9067048a Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 7 Apr 2026 14:33:23 +0000 Subject: [PATCH 05/67] chore: release 1.0.0-alpha.1 Release-As: 1.0.0-alpha.1 From 4fc6b54fd26cc83d810d81f923579a1cd4853b39 Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Wed, 8 Apr 2026 09:23:44 +0200 Subject: [PATCH 06/67] feat: Unhandled exception in AgentExecutor marks task as failed (#943) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #869 🦕 --- src/a2a/server/agent_execution/active_task.py | 152 ++++++++++-------- tests/integration/test_scenarios.py | 24 ++- 2 files changed, 97 insertions(+), 79 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index f313ca11e..bf9e129a6 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ 
b/src/a2a/server/agent_execution/active_task.py @@ -32,6 +32,8 @@ Message, Task, TaskState, + TaskStatus, + TaskStatusUpdateEvent, ) from a2a.utils.errors import ( InvalidParamsError, @@ -252,80 +254,75 @@ async def _run_producer(self) -> None: """ logger.debug('Producer[%s]: Started', self._task_id) try: - try: - try: - while True: - ( - request_context, - request_id, - ) = await self._request_queue.get() - await self._request_lock.acquire() - # TODO: Should we create task manager every time? - self._task_manager._call_context = ( - request_context.call_context - ) - request_context.current_task = ( - await self._task_manager.get_task() - ) + active = True + while active: + ( + request_context, + request_id, + ) = await self._request_queue.get() + await self._request_lock.acquire() + # TODO: Should we create task manager every time? + self._task_manager._call_context = request_context.call_context + request_context.current_task = ( + await self._task_manager.get_task() + ) - message = request_context.message - if message: - request_context.current_task = ( - self._task_manager.update_with_message( - message, - cast('Task', request_context.current_task), - ) - ) - await self._task_manager.save_task_event( - request_context.current_task - ) - self._task_created.set() - logger.debug( - 'Producer[%s]: Executing agent task %s', - self._task_id, - request_context.current_task, + message = request_context.message + if message: + request_context.current_task = ( + self._task_manager.update_with_message( + message, + cast('Task', request_context.current_task), ) + ) + await self._task_manager.save_task_event( + request_context.current_task + ) + self._task_created.set() + logger.debug( + 'Producer[%s]: Executing agent task %s', + self._task_id, + request_context.current_task, + ) - try: - await self._agent_executor.execute( - request_context, self._event_queue_agent - ) - logger.debug( - 'Producer[%s]: Execution finished successfully', - self._task_id, - ) - except 
Exception as e: - async with self._lock: - if self._exception is None: - self._exception = e - raise - finally: - logger.debug( - 'Producer[%s]: Enqueuing request completed event', - self._task_id, - ) - # TODO: Hide from external consumers - await self._event_queue_agent.enqueue_event( - cast('Event', _RequestCompleted(request_id)) - ) - self._request_queue.task_done() + try: + await self._agent_executor.execute( + request_context, self._event_queue_agent + ) + logger.debug( + 'Producer[%s]: Execution finished successfully', + self._task_id, + ) except QueueShutDown: logger.debug( 'Producer[%s]: Request queue shut down', self._task_id ) - except asyncio.CancelledError: - logger.debug('Producer[%s]: Cancelled', self._task_id) - raise - except Exception as e: - logger.exception('Producer[%s]: Failed', self._task_id) - async with self._lock: - if self._exception is None: - self._exception = e - finally: - self._request_queue.shutdown(immediate=True) - await self._event_queue_agent.close(immediate=False) - await self._event_queue_subscribers.close(immediate=False) + raise + except asyncio.CancelledError: + logger.debug('Producer[%s]: Cancelled', self._task_id) + raise + except Exception as e: + logger.exception( + 'Producer[%s]: Execution failed', + self._task_id, + ) + async with self._lock: + await self._mark_task_as_failed(e) + active = False + finally: + logger.debug( + 'Producer[%s]: Enqueuing request completed event', + self._task_id, + ) + # TODO: Hide from external consumers + await self._event_queue_agent.enqueue_event( + cast('Event', _RequestCompleted(request_id)) + ) + self._request_queue.task_done() finally: + self._request_queue.shutdown(immediate=True) + await self._event_queue_agent.close(immediate=False) + await self._event_queue_subscribers.close(immediate=False) logger.debug('Producer[%s]: Completed', self._task_id) async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 @@ -443,8 +440,7 @@ async def _run_consumer(self) -> None: # noqa: 
PLR0915, PLR0912 except Exception as e: logger.exception('Consumer[%s]: Failed', self._task_id) async with self._lock: - if self._exception is None: - self._exception = e + await self._mark_task_as_failed(e) finally: # The consumer is dead. The ActiveTask is permanently finished. self._is_finished.set() @@ -581,9 +577,7 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message: logger.exception( 'Cancel[%s]: Agent cancel failed', self._task_id ) - if not self._exception: - self._exception = e - + await self._mark_task_as_failed(e) raise else: logger.debug( @@ -619,6 +613,22 @@ async def _maybe_cleanup(self) -> None: logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id) self._on_cleanup(self) + async def _mark_task_as_failed(self, exception: Exception) -> None: + if self._exception is None: + self._exception = exception + if self._task_created.is_set(): + task = await self._task_manager.get_task() + if task is not None: + await self._event_queue_agent.enqueue_event( + TaskStatusUpdateEvent( + task_id=task.id, + context_id=task.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_FAILED, + ), + ) + ) + async def get_task(self) -> Task: """Get task from db.""" # TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation). diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index 94774e29a..a7d85a28c 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -437,9 +437,8 @@ async def cancel( # Legacy is not creating tasks for agent failures. assert len((await client.list_tasks(ListTasksRequest())).tasks) == 0 else: - # TODO: should it be TASK_STATE_FAILED ? 
(task,) = (await client.list_tasks(ListTasksRequest())).tasks - assert task.status.state == TaskState.TASK_STATE_SUBMITTED + assert task.status.state == TaskState.TASK_STATE_FAILED # Scenario 12/13: Exception after initial event @@ -503,9 +502,12 @@ async def release_agent(): await asyncio.gather(*tasks) - # TODO: should it be TASK_STATE_FAILED ? (task,) = (await client.list_tasks(ListTasksRequest())).tasks - assert task.status.state == TaskState.TASK_STATE_WORKING + if use_legacy: + # Legacy does not update task state on exception. + assert task.status.state == TaskState.TASK_STATE_WORKING + else: + assert task.status.state == TaskState.TASK_STATE_FAILED # Scenario 14: Exception in Cancel @@ -563,9 +565,12 @@ async def cancel( with pytest.raises(A2AClientError, match='TEST_ERROR_IN_CANCEL'): await client.cancel_task(CancelTaskRequest(id=task_id)) - # TODO: should it be TASK_STATE_CANCELED or TASK_STATE_FAILED? (task,) = (await client.list_tasks(ListTasksRequest())).tasks - assert task.status.state == TaskState.TASK_STATE_WORKING + if use_legacy: + # Legacy does not update task state on exception. + assert task.status.state == TaskState.TASK_STATE_WORKING + else: + assert task.status.state == TaskState.TASK_STATE_FAILED # Scenario 15: Subscribe to task that errors out @@ -632,9 +637,12 @@ async def consume_events(): with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'): await consume_task - # TODO: should it be TASK_STATE_FAILED? (task,) = (await client.list_tasks(ListTasksRequest())).tasks - assert task.status.state == TaskState.TASK_STATE_WORKING + if use_legacy: + # Legacy does not update task state on exception. 
+ assert task.status.state == TaskState.TASK_STATE_WORKING + else: + assert task.status.state == TaskState.TASK_STATE_FAILED # Scenario 16: Slow execution and return_immediately=True From 2159140b1c24fe556a41accf97a6af7f54ec6701 Mon Sep 17 00:00:00 2001 From: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Date: Wed, 8 Apr 2026 10:49:19 +0200 Subject: [PATCH 07/67] feat: Add GetExtendedAgentCard Support to RequestHandlers (#919) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description The `GetExtendedAgentCard` capability was defined in the spec but not implemented in the `request_handler.py`. # Changes - Added `on_get_extended_agent_card` to the base `RequestHandler` and its child class `DefaultRequestHandler`. - Removed `GetExtendedAgentCard` method implementations from the Transport layer and consequently moved `AgentCard` informations from the Transport layer to the `RequestHandlers`. - moved `validate` logic from the transport layer to the default request handler Fixes #866 🦕 --- itk/main.py | 24 +- samples/hello_world_agent.py | 10 +- src/a2a/compat/v0_3/grpc_handler.py | 42 +- src/a2a/compat/v0_3/jsonrpc_adapter.py | 44 +- src/a2a/compat/v0_3/request_handler.py | 16 +- src/a2a/compat/v0_3/rest_adapter.py | 60 +-- src/a2a/compat/v0_3/rest_handler.py | 30 +- .../default_request_handler.py | 74 +++- .../default_request_handler_v2.py | 81 +++- .../server/request_handlers/grpc_handler.py | 37 +- .../request_handlers/request_handler.py | 158 ++++++- src/a2a/server/routes/jsonrpc_dispatcher.py | 61 +-- src/a2a/server/routes/jsonrpc_routes.py | 33 +- src/a2a/server/routes/rest_dispatcher.py | 65 +-- src/a2a/server/routes/rest_routes.py | 31 +- src/a2a/utils/helpers.py | 100 +---- tck/sut_agent.py | 7 +- tests/compat/v0_3/test_grpc_handler.py | 10 +- tests/compat/v0_3/test_jsonrpc_app_compat.py | 15 +- tests/compat/v0_3/test_request_handler.py | 58 ++- tests/compat/v0_3/test_rest_handler.py | 40 +- 
tests/compat/v0_3/test_rest_routes_compat.py | 3 +- tests/e2e/push_notifications/agent_app.py | 6 +- .../test_default_push_notification_support.py | 4 +- tests/e2e/push_notifications/utils.py | 6 +- .../cross_version/client_server/server_1_0.py | 13 +- tests/integration/test_agent_card.py | 7 +- .../test_client_server_integration.py | 102 +++-- .../integration/test_copying_observability.py | 4 +- tests/integration/test_end_to_end.py | 13 +- tests/integration/test_scenarios.py | 16 +- .../test_stream_generator_cleanup.py | 3 +- tests/integration/test_tenant.py | 2 - tests/integration/test_version_header.py | 13 +- .../test_default_request_handler.py | 394 ++++++++++++++---- .../test_default_request_handler_v2.py | 112 ++++- .../request_handlers/test_grpc_handler.py | 42 +- .../server/routes/test_jsonrpc_dispatcher.py | 22 +- tests/server/routes/test_jsonrpc_routes.py | 6 +- tests/server/routes/test_rest_dispatcher.py | 95 ++--- tests/server/routes/test_rest_routes.py | 27 +- tests/server/test_integration.py | 4 +- tests/utils/test_helpers.py | 22 - 43 files changed, 1109 insertions(+), 803 deletions(-) diff --git a/itk/main.py b/itk/main.py index 97d5cb29e..22cfef2a4 100644 --- a/itk/main.py +++ b/itk/main.py @@ -292,7 +292,11 @@ async def main_async(http_port: int, grpc_port: int) -> None: name='ITK v10 Agent', description='Python agent using SDK 1.0.', version='1.0.0', - capabilities=AgentCapabilities(streaming=True), + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ), default_input_modes=['text/plain'], default_output_modes=['text/plain'], supported_interfaces=interfaces, @@ -302,18 +306,25 @@ async def main_async(http_port: int, grpc_port: int) -> None: handler = DefaultRequestHandler( agent_executor=V10AgentExecutor(), task_store=task_store, + agent_card=agent_card, queue_manager=InMemoryQueueManager(), ) + handler_extended = DefaultRequestHandler( + agent_executor=V10AgentExecutor(), + 
task_store=task_store, + agent_card=agent_card, + queue_manager=InMemoryQueueManager(), + extended_agent_card=agent_card, + ) + app = FastAPI() agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=handler, - extended_agent_card=agent_card, + request_handler=handler_extended, rpc_url='/', enable_v0_3_compat=True, ) @@ -323,7 +334,6 @@ async def main_async(http_port: int, grpc_port: int) -> None: ) rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=handler, enable_v0_3_compat=True, ) @@ -331,9 +341,9 @@ async def main_async(http_port: int, grpc_port: int) -> None: server = grpc.aio.server() - compat_servicer = CompatGrpcHandler(agent_card, handler) + compat_servicer = CompatGrpcHandler(handler) a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(compat_servicer, server) - servicer = GrpcHandler(agent_card, handler) + servicer = GrpcHandler(handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) server.add_insecure_port(f'127.0.0.1:{grpc_port}') diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index e286fa130..909e6550d 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -191,17 +191,17 @@ async def serve( task_store = InMemoryTaskStore() request_handler = DefaultRequestHandler( - agent_executor=SampleAgentExecutor(), task_store=task_store + agent_executor=SampleAgentExecutor(), + task_store=task_store, + agent_card=agent_card, ) rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=request_handler, path_prefix='/a2a/rest', enable_v0_3_compat=True, ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=request_handler, rpc_url='/a2a/jsonrpc', enable_v0_3_compat=True, @@ -216,12 +216,12 @@ async def serve( grpc_server = grpc.aio.server() grpc_server.add_insecure_port(f'{host}:{grpc_port}') - 
servicer = GrpcHandler(agent_card, request_handler) + servicer = GrpcHandler(request_handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, grpc_server) compat_grpc_server = grpc.aio.server() compat_grpc_server.add_insecure_port(f'{host}:{compat_grpc_port}') - compat_servicer = CompatGrpcHandler(agent_card, request_handler) + compat_servicer = CompatGrpcHandler(request_handler) a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server( compat_servicer, compat_grpc_server ) diff --git a/src/a2a/compat/v0_3/grpc_handler.py b/src/a2a/compat/v0_3/grpc_handler.py index c9db99557..23d1f831d 100644 --- a/src/a2a/compat/v0_3/grpc_handler.py +++ b/src/a2a/compat/v0_3/grpc_handler.py @@ -12,7 +12,6 @@ from a2a.compat.v0_3 import ( a2a_v0_3_pb2, a2a_v0_3_pb2_grpc, - conversions, proto_utils, ) from a2a.compat.v0_3 import ( @@ -27,9 +26,7 @@ GrpcServerCallContextBuilder, ) from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard from a2a.utils.errors import A2AError, InvalidParamsError -from a2a.utils.helpers import maybe_await, validate logger = logging.getLogger(__name__) @@ -42,29 +39,21 @@ class CompatGrpcHandler(a2a_v0_3_pb2_grpc.A2AServiceServicer): def __init__( self, - agent_card: AgentCard, request_handler: RequestHandler, context_builder: GrpcServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, ): """Initializes the CompatGrpcHandler. Args: - agent_card: The AgentCard describing the agent's capabilities (v1.0). request_handler: The underlying `RequestHandler` instance to delegate requests to. context_builder: The CallContextBuilder object. If none the DefaultCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. 
""" - self.agent_card = agent_card self.handler03 = RequestHandler03(request_handler=request_handler) self._context_builder = ( context_builder or DefaultGrpcServerCallContextBuilder() ) - self.card_modifier = card_modifier async def _handle_unary( self, @@ -179,10 +168,6 @@ async def SendStreamingMessage( ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: """Handles the 'SendStreamingMessage' gRPC method (v0.3).""" - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: @@ -242,10 +227,6 @@ async def TaskSubscription( ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: """Handles the 'TaskSubscription' gRPC method (v0.3).""" - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: @@ -269,10 +250,6 @@ async def CreateTaskPushNotificationConfig( ) -> a2a_v0_3_pb2.TaskPushNotificationConfig: """Handles the 'CreateTaskPushNotificationConfig' gRPC method (v0.3).""" - @validate( - lambda _: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> a2a_v0_3_pb2.TaskPushNotificationConfig: @@ -360,12 +337,19 @@ async def GetAgentCard( request: a2a_v0_3_pb2.GetAgentCardRequest, context: grpc.aio.ServicerContext, ) -> a2a_v0_3_pb2.AgentCard: - """Get the agent card for the agent served (v0.3).""" - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - return proto_utils.ToProto.agent_card( - conversions.to_compat_agent_card(card_to_serve) + """Get the extended agent card for the agent served (v0.3).""" + + async def _handler( + server_context: ServerCallContext, + ) -> 
a2a_v0_3_pb2.AgentCard: + req_v03 = types_v03.GetAuthenticatedExtendedCardRequest(id=0) + res_v03 = await self.handler03.on_get_extended_agent_card( + req_v03, server_context + ) + return proto_utils.ToProto.agent_card(res_v03) + + return await self._handle_unary( + context, _handler, a2a_v0_3_pb2.AgentCard() ) async def DeleteTaskPushNotificationConfig( diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index d01a7e11c..baa2bcda8 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -1,6 +1,6 @@ import logging -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterable, AsyncIterator from typing import TYPE_CHECKING, Any from sse_starlette.sse import EventSourceResponse @@ -11,7 +11,6 @@ from starlette.requests import Request from a2a.server.request_handlers.request_handler import RequestHandler - from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True else: @@ -24,7 +23,6 @@ _package_starlette_installed = False -from a2a.compat.v0_3 import conversions from a2a.compat.v0_3 import types as types_v03 from a2a.compat.v0_3.request_handler import RequestHandler03 from a2a.server.context import ServerCallContext @@ -42,8 +40,7 @@ ServerCallContextBuilder, ) from a2a.utils import constants -from a2a.utils.errors import ExtendedAgentCardNotConfiguredError -from a2a.utils.helpers import maybe_await, validate_version +from a2a.utils.helpers import validate_version logger = logging.getLogger(__name__) @@ -65,19 +62,11 @@ class JSONRPC03Adapter: 'agent/getAuthenticatedExtendedCard': types_v03.GetAuthenticatedExtendedCardRequest, } - def __init__( # noqa: PLR0913 + def __init__( self, - agent_card: 'AgentCard', http_handler: 'RequestHandler', - extended_agent_card: 'AgentCard | None' = None, context_builder: 'ServerCallContextBuilder | None' = None, - card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] 
| AgentCard] | None' = None, - extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None, ): - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier self.handler = RequestHandler03( request_handler=http_handler, ) @@ -227,7 +216,7 @@ async def _process_non_streaming_request( ) ) elif method == 'agent/getAuthenticatedExtendedCard': - res_card = await self.get_authenticated_extended_card( + res_card = await self.handler.on_get_extended_agent_card( request_obj, context ) result = types_v03.GetAuthenticatedExtendedCardResponse( @@ -244,31 +233,6 @@ async def _process_non_streaming_request( ) ) - async def get_authenticated_extended_card( - self, - request: types_v03.GetAuthenticatedExtendedCardRequest, - context: ServerCallContext, - ) -> types_v03.AgentCard: - """Handles the 'agent/getAuthenticatedExtendedCard' JSON-RPC method.""" - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - - base_card = self.extended_agent_card - if base_card is None: - base_card = self.agent_card - - card_to_serve = base_card - if self.extended_card_modifier and context: - card_to_serve = await maybe_await( - self.extended_card_modifier(base_card, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(base_card)) - - return conversions.to_compat_agent_card(card_to_serve) - @validate_version(constants.PROTOCOL_VERSION_0_3) async def _process_streaming_request( self, diff --git a/src/a2a/compat/v0_3/request_handler.py b/src/a2a/compat/v0_3/request_handler.py index 6ec675312..d79a5cc5d 100644 --- a/src/a2a/compat/v0_3/request_handler.py +++ b/src/a2a/compat/v0_3/request_handler.py @@ -9,9 +9,7 @@ from a2a.server.request_handlers.request_handler import RequestHandler from 
a2a.types.a2a_pb2 import Task from a2a.utils import proto_utils as core_proto_utils -from a2a.utils.errors import ( - TaskNotFoundError, -) +from a2a.utils.errors import TaskNotFoundError logger = logging.getLogger(__name__) @@ -170,3 +168,15 @@ async def on_delete_task_push_notification_config( await self.request_handler.on_delete_task_push_notification_config( v10_req, context ) + + async def on_get_extended_agent_card( + self, + request: types_v03.GetAuthenticatedExtendedCardRequest, + context: ServerCallContext, + ) -> types_v03.AgentCard: + """Gets the authenticated extended agent card using v0.3 protocol types.""" + v10_req = conversions.to_core_get_extended_agent_card_request(request) + v10_card = await self.request_handler.on_get_extended_agent_card( + v10_req, context + ) + return conversions.to_compat_agent_card(v10_card) diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index 27aba2aad..a2a9b56ee 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -11,8 +11,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response + from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler - from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True else: @@ -31,9 +31,7 @@ _package_starlette_installed = False -from a2a.compat.v0_3 import conversions from a2a.compat.v0_3.rest_handler import REST03Handler -from a2a.server.context import ServerCallContext from a2a.server.routes.common import ( DefaultServerCallContextBuilder, ServerCallContextBuilder, @@ -43,10 +41,8 @@ rest_stream_error_handler, ) from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, InvalidRequestError, ) -from a2a.utils.helpers import maybe_await logger = logging.getLogger(__name__) @@ -58,22 +54,12 @@ class REST03Adapter: Defines v0.3 REST request processors and their routes, as well as 
managing response generation including Server-Sent Events (SSE). """ - def __init__( # noqa: PLR0913 + def __init__( self, - agent_card: 'AgentCard', http_handler: 'RequestHandler', - extended_agent_card: 'AgentCard | None' = None, context_builder: 'ServerCallContextBuilder | None' = None, - card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None, - extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None, ): - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier - self.handler = REST03Handler( - agent_card=agent_card, request_handler=http_handler - ) + self.handler = REST03Handler(request_handler=http_handler) self._context_builder = ( context_builder or DefaultServerCallContextBuilder() ) @@ -113,39 +99,6 @@ async def event_generator( event_generator(method(request, call_context)) ) - async def handle_get_agent_card( - self, request: Request, call_context: ServerCallContext - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint.""" - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - v03_card = conversions.to_compat_agent_card(card_to_serve) - return v03_card.model_dump(mode='json', exclude_none=True) - - async def handle_authenticated_agent_card( - self, request: Request, call_context: ServerCallContext - ) -> dict[str, Any]: - """Hook for per credential agent card response.""" - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card - - if not card_to_serve: - card_to_serve = self.agent_card - - if self.extended_card_modifier: - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, call_context) - 
) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - v03_card = conversions.to_compat_agent_card(card_to_serve) - return v03_card.model_dump(mode='json', exclude_none=True) - def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: """Constructs a dictionary of API routes and their corresponding handlers.""" routes: dict[tuple[str, str], Callable[[Request], Any]] = { @@ -191,10 +144,9 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: ('/v1/tasks', 'GET'): functools.partial( self._handle_request, self.handler.list_tasks ), + ('/v1/card', 'GET'): functools.partial( + self._handle_request, self.handler.on_get_extended_agent_card + ), } - if self.agent_card.capabilities.extended_agent_card: - routes[('/v1/card', 'GET')] = functools.partial( - self._handle_request, self.handle_authenticated_agent_card - ) return routes diff --git a/src/a2a/compat/v0_3/rest_handler.py b/src/a2a/compat/v0_3/rest_handler.py index 470f94b3e..0c64506cb 100644 --- a/src/a2a/compat/v0_3/rest_handler.py +++ b/src/a2a/compat/v0_3/rest_handler.py @@ -10,7 +10,6 @@ from starlette.requests import Request from a2a.server.request_handlers.request_handler import RequestHandler - from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True else: @@ -30,7 +29,6 @@ from a2a.server.context import ServerCallContext from a2a.utils import constants from a2a.utils.helpers import ( - validate, validate_version, ) from a2a.utils.telemetry import SpanKind, trace_class @@ -45,16 +43,13 @@ class REST03Handler: def __init__( self, - agent_card: 'AgentCard', request_handler: 'RequestHandler', ): """Initializes the REST03Handler. Args: - agent_card: The AgentCard describing the agent's capabilities (v1.0). request_handler: The underlying `RequestHandler` instance to delegate requests to (v1.0). 
""" - self.agent_card = agent_card self.handler03 = RequestHandler03(request_handler=request_handler) @validate_version(constants.PROTOCOL_VERSION_0_3) @@ -84,10 +79,6 @@ async def on_message_send( return MessageToDict(pb2_v03_resp) @validate_version(constants.PROTOCOL_VERSION_0_3) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def on_message_send_stream( self, request: Request, @@ -142,10 +133,6 @@ async def on_cancel_task( return MessageToDict(pb2_v03_task) @validate_version(constants.PROTOCOL_VERSION_0_3) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def on_subscribe_to_task( self, request: Request, @@ -208,10 +195,6 @@ async def get_push_notification( return MessageToDict(pb2_v03_config) @validate_version(constants.PROTOCOL_VERSION_0_3) - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def set_push_notification( self, request: Request, @@ -317,3 +300,16 @@ async def list_tasks( ) -> dict[str, Any]: """Handles the 'tasks/list' REST method.""" raise NotImplementedError('list tasks not implemented') + + @validate_version(constants.PROTOCOL_VERSION_0_3) + async def on_get_extended_agent_card( + self, + request: Request, + context: ServerCallContext, + ) -> dict[str, Any]: + """Handles the 'v1/agent/authenticatedExtendedAgentCard' REST method.""" + rpc_req = types_v03.GetAuthenticatedExtendedCardRequest(id=0) + v03_resp = await self.handler03.on_get_extended_agent_card( + rpc_req, context + ) + return v03_resp.model_dump(mode='json', exclude_none=True) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index ba1f08caa..e6b992250 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ 
b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,7 +1,7 @@ import asyncio import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import cast from a2a.server.agent_execution import ( @@ -21,6 +21,7 @@ ) from a2a.server.request_handlers.request_handler import ( RequestHandler, + validate, validate_request_params, ) from a2a.server.tasks import ( @@ -32,8 +33,10 @@ TaskStore, ) from a2a.types.a2a_pb2 import ( + AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, @@ -48,6 +51,7 @@ TaskState, ) from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, InternalError, InvalidParamsError, PushNotificationNotSupportedError, @@ -55,6 +59,7 @@ TaskNotFoundError, UnsupportedOperationError, ) +from a2a.utils.helpers import maybe_await from a2a.utils.task import ( apply_history_length, validate_history_length, @@ -89,27 +94,39 @@ def __init__( # noqa: PLR0913 self, agent_executor: AgentExecutor, task_store: TaskStore, + agent_card: AgentCard, queue_manager: QueueManager | None = None, push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, request_context_builder: RequestContextBuilder | None = None, + extended_agent_card: AgentCard | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, ) -> None: """Initializes the DefaultRequestHandler. Args: agent_executor: The `AgentExecutor` instance to run agent logic. task_store: The `TaskStore` instance to manage task persistence. + agent_card: The `AgentCard` describing the agent's capabilities. queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`. 
push_config_store: The `PushNotificationConfigStore` instance for managing push notification configurations. Defaults to None. push_sender: The `PushNotificationSender` instance for sending push notifications. Defaults to None. request_context_builder: The `RequestContextBuilder` instance used to build request contexts. Defaults to `SimpleRequestContextBuilder`. + extended_agent_card: An optional, distinct `AgentCard` to be served at the extended card endpoint. + extended_card_modifier: An optional callback to dynamically modify the extended `AgentCard` before it is served. """ self.agent_executor = agent_executor self.task_store = task_store + self._agent_card = agent_card self._queue_manager = queue_manager or InMemoryQueueManager() self._push_config_store = push_config_store self._push_sender = push_sender + self.extended_agent_card = extended_agent_card + self.extended_card_modifier = extended_card_modifier self._request_context_builder = ( request_context_builder or SimpleRequestContextBuilder( @@ -397,6 +414,10 @@ async def push_notification_callback(event: Event) -> None: return result @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) async def on_message_send_stream( self, params: SendMessageRequest, @@ -486,6 +507,11 @@ async def _cleanup_producer( self._running_agents.pop(task_id, None) @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_create_task_push_notification_config( self, params: TaskPushNotificationConfig, @@ -512,6 +538,11 @@ async def on_create_task_push_notification_config( return params @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + 
error_type=PushNotificationNotSupportedError, + ) async def on_get_task_push_notification_config( self, params: GetTaskPushNotificationConfigRequest, @@ -538,9 +569,13 @@ async def on_get_task_push_notification_config( if config.id == config_id: return config - raise InternalError(message='Push notification config not found') + raise TaskNotFoundError @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) async def on_subscribe_to_task( self, params: SubscribeToTaskRequest, @@ -584,6 +619,11 @@ async def on_subscribe_to_task( yield event @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_list_task_push_notification_configs( self, params: ListTaskPushNotificationConfigsRequest, @@ -610,6 +650,11 @@ async def on_list_task_push_notification_configs( ) @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_delete_task_push_notification_config( self, params: DeleteTaskPushNotificationConfigRequest, @@ -629,3 +674,28 @@ async def on_delete_task_push_notification_config( raise TaskNotFoundError await self._push_config_store.delete_info(task_id, context, config_id) + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.extended_agent_card, + error_message='The agent does not support authenticated extended cards', + ) + async def on_get_extended_agent_card( + self, + params: GetExtendedAgentCardRequest, + context: ServerCallContext, + ) -> AgentCard: + """Default handler for 'GetExtendedAgentCard'. + + Requires `capabilities.extended_agent_card` to be true. 
+ """ + extended_card = self.extended_agent_card + if not extended_card: + raise ExtendedAgentCardNotConfiguredError + + if self.extended_card_modifier: + return await maybe_await( + self.extended_card_modifier(extended_card, context) + ) + + return extended_card diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index e05593bec..ccc9cdd0e 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -18,11 +18,14 @@ from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry from a2a.server.request_handlers.request_handler import ( RequestHandler, + validate, validate_request_params, ) from a2a.types.a2a_pb2 import ( + AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, @@ -37,12 +40,14 @@ TaskStatusUpdateEvent, ) from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, InternalError, InvalidParamsError, + PushNotificationNotSupportedError, TaskNotCancelableError, TaskNotFoundError, - UnsupportedOperationError, ) +from a2a.utils.helpers import maybe_await from a2a.utils.task import ( apply_history_length, validate_history_length, @@ -52,7 +57,7 @@ if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, Awaitable, Callable from a2a.server.agent_execution.active_task import ActiveTask from a2a.server.context import ServerCallContext @@ -80,16 +85,25 @@ def __init__( # noqa: PLR0913 self, agent_executor: AgentExecutor, task_store: TaskStore, + agent_card: AgentCard, queue_manager: Any | None = None, # Kept for backward compat in signature push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, request_context_builder: 
RequestContextBuilder | None = None, + extended_agent_card: AgentCard | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, ) -> None: self.agent_executor = agent_executor self.task_store = task_store + self._agent_card = agent_card self._push_config_store = push_config_store self._push_sender = push_sender + self.extended_agent_card = extended_agent_card + self.extended_card_modifier = extended_card_modifier self._request_context_builder = ( request_context_builder or SimpleRequestContextBuilder( @@ -286,6 +300,10 @@ async def on_message_send( # noqa: D102 # TODO: Unify with on_message_send @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) async def on_message_send_stream( # noqa: D102 self, params: SendMessageRequest, @@ -310,13 +328,18 @@ async def on_message_send_stream( # noqa: D102 yield event @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_create_task_push_notification_config( # noqa: D102 self, params: TaskPushNotificationConfig, context: ServerCallContext, ) -> TaskPushNotificationConfig: if not self._push_config_store: - raise UnsupportedOperationError + raise PushNotificationNotSupportedError task_id = params.task_id task: Task | None = await self.task_store.get(task_id, context) @@ -332,13 +355,18 @@ async def on_create_task_push_notification_config( # noqa: D102 return params @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_get_task_push_notification_config( # noqa: D102 self, params: 
GetTaskPushNotificationConfigRequest, context: ServerCallContext, ) -> TaskPushNotificationConfig: if not self._push_config_store: - raise UnsupportedOperationError + raise PushNotificationNotSupportedError task_id = params.task_id config_id = params.id @@ -354,9 +382,13 @@ async def on_get_task_push_notification_config( # noqa: D102 if config.id == config_id: return config - raise InternalError(message='Push notification config not found') + raise TaskNotFoundError @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) async def on_subscribe_to_task( # noqa: D102 self, params: SubscribeToTaskRequest, @@ -374,13 +406,18 @@ async def on_subscribe_to_task( # noqa: D102 yield event @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_list_task_push_notification_configs( # noqa: D102 self, params: ListTaskPushNotificationConfigsRequest, context: ServerCallContext, ) -> ListTaskPushNotificationConfigsResponse: if not self._push_config_store: - raise UnsupportedOperationError + raise PushNotificationNotSupportedError task_id = params.task_id task: Task | None = await self.task_store.get(task_id, context) @@ -396,13 +433,18 @@ async def on_list_task_push_notification_configs( # noqa: D102 ) @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_delete_task_push_notification_config( # noqa: D102 self, params: DeleteTaskPushNotificationConfigRequest, context: ServerCallContext, ) -> None: if not self._push_config_store: - raise UnsupportedOperationError + raise PushNotificationNotSupportedError task_id = params.task_id config_id 
= params.id @@ -411,3 +453,28 @@ async def on_delete_task_push_notification_config( # noqa: D102 raise TaskNotFoundError await self._push_config_store.delete_info(task_id, context, config_id) + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.extended_agent_card, + error_message='The agent does not support authenticated extended cards', + ) + async def on_get_extended_agent_card( + self, + params: GetExtendedAgentCardRequest, + context: ServerCallContext, + ) -> AgentCard: + """Default handler for 'GetExtendedAgentCard'. + + Requires `capabilities.extended_agent_card` to be true. + """ + extended_card = self.extended_agent_card + if not extended_card: + raise ExtendedAgentCardNotConfiguredError + + if self.extended_card_modifier: + return await maybe_await( + self.extended_card_modifier(extended_card, context) + ) + + return extended_card diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 60aa41d22..2ccfa9bdd 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -32,10 +32,8 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import a2a_pb2 -from a2a.types.a2a_pb2 import AgentCard from a2a.utils import proto_utils from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError -from a2a.utils.helpers import maybe_await, validate from a2a.utils.proto_utils import validation_errors_to_bad_request @@ -109,30 +107,22 @@ class GrpcHandler(a2a_grpc.A2AServiceServicer): def __init__( self, - agent_card: AgentCard, request_handler: RequestHandler, context_builder: GrpcServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, ): """Initializes the GrpcHandler. Args: - agent_card: The AgentCard describing the agent's capabilities. 
request_handler: The underlying `RequestHandler` instance to delegate requests to. context_builder: The GrpcContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultGrpcContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. """ - self.agent_card = agent_card self.request_handler = request_handler self._context_builder = ( context_builder or DefaultGrpcServerCallContextBuilder() ) - self.card_modifier = card_modifier async def _handle_unary( self, @@ -195,10 +185,6 @@ async def SendStreamingMessage( ) -> AsyncIterable[a2a_pb2.StreamResponse]: """Handles the 'StreamMessage' gRPC method.""" - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> AsyncIterable[a2a_pb2.StreamResponse]: @@ -236,10 +222,6 @@ async def SubscribeToTask( ) -> AsyncIterable[a2a_pb2.StreamResponse]: """Handles the 'SubscribeToTask' gRPC method.""" - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> AsyncIterable[a2a_pb2.StreamResponse]: @@ -278,10 +260,6 @@ async def CreateTaskPushNotificationConfig( ) -> a2a_pb2.TaskPushNotificationConfig: """Handles the 'CreateTaskPushNotificationConfig' gRPC method.""" - @validate( - lambda _: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> a2a_pb2.TaskPushNotificationConfig: @@ -376,10 +354,17 @@ async def GetExtendedAgentCard( context: grpc.aio.ServicerContext, ) -> a2a_pb2.AgentCard: """Get the extended agent card for the agent served.""" - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - return 
card_to_serve + + async def _handler( + server_context: ServerCallContext, + ) -> a2a_pb2.AgentCard: + return await self.request_handler.on_get_extended_agent_card( + request, server_context + ) + + return await self._handle_unary( + request, context, _handler, a2a_pb2.AgentCard() + ) async def abort_context( self, error: A2AError, context: grpc.aio.ServicerContext diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 23b0f2b95..6fb42098f 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -1,5 +1,6 @@ import functools import inspect +import logging from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Callable @@ -10,8 +11,10 @@ from a2a.server.context import ServerCallContext from a2a.server.events.event_queue import Event from a2a.types.a2a_pb2 import ( + AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, @@ -32,7 +35,7 @@ class RequestHandler(ABC): """A2A request handler interface. This interface defines the methods that an A2A server implementation must - provide to handle incoming JSON-RPC requests. + provide to handle incoming A2A requests from any transport (gRPC, REST, JSON-RPC). """ @abstractmethod @@ -59,7 +62,7 @@ async def on_list_tasks( ) -> ListTasksResponse: """Handles the tasks/list method. - Retrieves all task for an agent. Supports filtering, pagination, + Retrieves all tasks for an agent. Supports filtering, pagination, ordering, limiting the history length, excluding artifacts, etc. Args: @@ -124,10 +127,8 @@ async def on_message_send_stream( Yields: `Event` objects from the agent's execution. - - Raises: - UnsupportedOperationError: By default, if not implemented. 
""" + # This is needed for typechecker to recognise this method as an async generator. raise UnsupportedOperationError yield @@ -183,9 +184,6 @@ async def on_subscribe_to_task( Yields: `Event` objects from the agent's ongoing execution for the specified task. - - Raises: - UnsupportedOperationError: By default, if not implemented. """ raise UnsupportedOperationError yield @@ -226,6 +224,25 @@ async def on_delete_task_push_notification_config( None """ + @abstractmethod + async def on_get_extended_agent_card( + self, + params: GetExtendedAgentCardRequest, + context: ServerCallContext, + ) -> AgentCard: + """Handles the 'GetExtendedAgentCard' method. + + Retrieves the extended agent card for the agent. + + Args: + params: Parameters for the request. + context: Context provided by the server. + + Returns: + The `AgentCard` object representing the extended properties of the agent. + + """ + def validate_request_params(method: Callable) -> Callable: """Decorator for RequestHandler methods to validate required fields on incoming requests.""" @@ -268,3 +285,128 @@ async def async_wrapper( return await method(self, params, context, *args, **kwargs) return async_wrapper + + +def validate( + expression: Callable[[Any], bool], + error_message: str | None = None, + error_type: type[Exception] = UnsupportedOperationError, +) -> Callable: + """Decorator that validates if a given expression evaluates to True. + + Typically used on class methods to check capabilities or configuration + before executing the method's logic. If the expression is False, + the specified `error_type` (defaults to `UnsupportedOperationError`) is raised. + + Args: + expression: A callable that takes the instance (`self`) as its argument + and returns a boolean. + error_message: An optional custom error message for the error raised. + If None, the string representation of the expression will be used. + error_type: The exception class to raise on validation failure. 
+ Must take a `message` keyword argument (inherited from A2AError). + + Examples: + Demonstrating with an async method: + >>> import asyncio + >>> from a2a.utils.errors import UnsupportedOperationError + >>> + >>> class MyAgent: + ... def __init__(self, streaming_enabled: bool): + ... self.streaming_enabled = streaming_enabled + ... + ... @validate( + ... lambda self: self.streaming_enabled, + ... 'Streaming is not enabled for this agent', + ... ) + ... async def stream_response(self, message: str): + ... return f'Streaming: {message}' + >>> + >>> async def run_async_test(): + ... # Successful call + ... agent_ok = MyAgent(streaming_enabled=True) + ... result = await agent_ok.stream_response('hello') + ... print(result) + ... + ... # Call that fails validation + ... agent_fail = MyAgent(streaming_enabled=False) + ... try: + ... await agent_fail.stream_response('world') + ... except UnsupportedOperationError as e: + ... print(e.message) + >>> + >>> asyncio.run(run_async_test()) + Streaming: hello + Streaming is not enabled for this agent + + Demonstrating with a sync method: + >>> class SecureAgent: + ... def __init__(self): + ... self.auth_enabled = False + ... + ... @validate( + ... lambda self: self.auth_enabled, + ... 'Authentication must be enabled for this operation', + ... ) + ... def secure_operation(self, data: str): + ... return f'Processing secure data: {data}' + >>> + >>> # Error case example + >>> agent = SecureAgent() + >>> try: + ... agent.secure_operation('secret') + ... except UnsupportedOperationError as e: + ... print(e.message) + Authentication must be enabled for this operation + + Note: + This decorator works with both sync and async methods automatically. 
+ """ + + def decorator(function: Callable) -> Callable: + if inspect.isasyncgenfunction(function): + + @functools.wraps(function) + async def async_gen_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logging.getLogger(__name__).error( + 'Validation failure: %s', final_message + ) + raise error_type(final_message) + inner = function(self, *args, **kwargs) + try: + async for item in inner: + yield item + finally: + await inner.aclose() + + return async_gen_wrapper + + if inspect.iscoroutinefunction(function): + + @functools.wraps(function) + async def async_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logging.getLogger(__name__).error( + 'Validation failure: %s', final_message + ) + raise error_type(final_message) + return await function(self, *args, **kwargs) + + return async_wrapper + + @functools.wraps(function) + def sync_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logging.getLogger(__name__).error( + 'Validation failure: %s', final_message + ) + raise error_type(final_message) + return function(self, *args, **kwargs) + + return sync_wrapper + + return decorator diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index e0f0042b0..de20610f6 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -4,7 +4,7 @@ import logging import traceback -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any from google.protobuf.json_format import MessageToDict, ParseDict @@ -32,7 +32,6 @@ ServerCallContextBuilder, ) from a2a.types.a2a_pb2 import ( - AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, GetExtendedAgentCardRequest, @@ 
-49,11 +48,10 @@ from a2a.utils import constants, proto_utils from a2a.utils.errors import ( A2AError, - ExtendedAgentCardNotConfiguredError, TaskNotFoundError, UnsupportedOperationError, ) -from a2a.utils.helpers import maybe_await, validate, validate_version +from a2a.utils.helpers import validate_version from a2a.utils.telemetry import SpanKind, trace_class @@ -130,36 +128,20 @@ class JsonRpcDispatcher: 'GetExtendedAgentCard': GetExtendedAgentCardRequest, } - def __init__( # noqa: PLR0913 + def __init__( self, - agent_card: AgentCard, request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, context_builder: ServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, enable_v0_3_compat: bool = False, ) -> None: """Initializes the JsonRpcDispatcher. Args: - agent_card: The AgentCard describing the agent's capabilities. request_handler: The handler instance responsible for processing A2A requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultServerCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. """ if not _package_starlette_installed: @@ -169,11 +151,7 @@ def __init__( # noqa: PLR0913 ' optional dependencies, `a2a-sdk[http-server]`.' 
) - self.agent_card = agent_card self.request_handler = request_handler - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier self._context_builder = ( context_builder or DefaultServerCallContextBuilder() ) @@ -182,12 +160,8 @@ def __init__( # noqa: PLR0913 if self.enable_v0_3_compat: self._v03_adapter = JSONRPC03Adapter( - agent_card=agent_card, http_handler=request_handler, - extended_agent_card=extended_agent_card, context_builder=self._context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, ) def _generate_error_response( @@ -333,6 +307,9 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, call_context.state['request_id'] = request_id # Route streaming requests by method name + handler_result: ( + AsyncGenerator[dict[str, Any], None] | dict[str, Any] + ) if method in ('SendStreamingMessage', 'SubscribeToTask'): handler_result = await self._process_streaming_request( request_id, specific_request, call_context @@ -369,10 +346,6 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, ) @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _process_streaming_request( self, request_id: str | int | None, @@ -456,10 +429,6 @@ async def _handle_list_tasks( always_print_fields_with_no_presence=True, ) - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def _handle_create_task_push_notification_config( self, request_obj: TaskPushNotificationConfig, @@ -512,20 +481,10 @@ async def _handle_get_extended_agent_card( request_obj: GetExtendedAgentCardRequest, context: ServerCallContext, ) -> dict[str, Any]: - if not self.agent_card.capabilities.extended_agent_card: - raise 
ExtendedAgentCardNotConfiguredError( - message='The agent does not have an extended agent card configured' - ) - base_card = self.extended_agent_card or self.agent_card - card_to_serve = base_card - if self.extended_card_modifier and context: - card_to_serve = await maybe_await( - self.extended_card_modifier(base_card, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(base_card)) - - return MessageToDict(card_to_serve, preserving_proto_field_name=False) + card = await self.request_handler.on_get_extended_agent_card( + request_obj, context + ) + return MessageToDict(card, preserving_proto_field_name=False) @validate_version(constants.PROTOCOL_VERSION_1_0) async def _process_non_streaming_request( # noqa: PLR0911 diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index f19625379..a94d513ae 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -1,4 +1,5 @@ -from collections.abc import Awaitable, Callable +import logging + from typing import TYPE_CHECKING, Any @@ -16,26 +17,18 @@ _package_starlette_installed = False - -from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.routes.common import ServerCallContextBuilder from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher -from a2a.types.a2a_pb2 import AgentCard -def create_jsonrpc_routes( # noqa: PLR0913 - agent_card: AgentCard, +logger = logging.getLogger(__name__) + + +def create_jsonrpc_routes( request_handler: RequestHandler, rpc_url: str, - extended_agent_card: AgentCard | None = None, context_builder: ServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, enable_v0_3_compat: bool = False, ) 
-> list['Route']: """Creates the Starlette Route for the A2A protocol JSON-RPC endpoint. @@ -45,20 +38,12 @@ def create_jsonrpc_routes( # noqa: PLR0913 (SSE). Args: - agent_card: The AgentCard describing the agent's capabilities. request_handler: The handler instance responsible for processing A2A requests via http. - rpc_url: The URL prefix for the RPC endpoints. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. + rpc_url: The URL prefix for the RPC endpoints. Should start with a leading slash '/'. context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultServerCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. 
""" if not _package_starlette_installed: @@ -69,12 +54,8 @@ def create_jsonrpc_routes( # noqa: PLR0913 ) dispatcher = JsonRpcDispatcher( - agent_card=agent_card, request_handler=request_handler, - extended_agent_card=extended_agent_card, context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, enable_v0_3_compat=enable_v0_3_compat, ) diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py index 1f91dd573..fa9a12af8 100644 --- a/src/a2a/server/routes/rest_dispatcher.py +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -14,7 +14,6 @@ ) from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( - AgentCard, CancelTaskRequest, GetTaskPushNotificationConfigRequest, SubscribeToTaskRequest, @@ -25,11 +24,10 @@ rest_stream_error_handler, ) from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, InvalidRequestError, TaskNotFoundError, ) -from a2a.utils.helpers import maybe_await, validate, validate_version +from a2a.utils.helpers import validate_version from a2a.utils.telemetry import SpanKind, trace_class @@ -66,34 +64,18 @@ class RestDispatcher: Handles context building, routing to RequestHandler directly, and response formatting (JSON/SSE). """ - def __init__( # noqa: PLR0913 + def __init__( self, - agent_card: AgentCard, request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, context_builder: ServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, ) -> None: """Initializes the RestDispatcher. Args: - agent_card: The AgentCard describing the agent's capabilities. request_handler: The underlying `RequestHandler` instance to delegate requests to. 
- extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultServerCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. """ if not _package_starlette_installed: raise ImportError( @@ -102,10 +84,6 @@ def __init__( # noqa: PLR0913 'optional dependencies, `a2a-sdk[http-server]`.' ) - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier self._context_builder = ( context_builder or DefaultServerCallContextBuilder() ) @@ -192,10 +170,6 @@ async def on_message_send_stream( """Handles the 'message/stream' REST method.""" @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( context: ServerCallContext, ) -> AsyncIterator[dict[str, Any]]: @@ -235,10 +209,6 @@ async def on_subscribe_to_task( task_id = request.path_params['id'] @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( context: ServerCallContext, ) -> AsyncIterator[dict[str, Any]]: @@ -312,10 +282,6 @@ async def set_push_notification(self, request: Request) -> Response: """Handles the 'tasks/pushNotificationConfig/set' REST method.""" @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda _: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def _handler( 
context: ServerCallContext, ) -> a2a_pb2.TaskPushNotificationConfig: @@ -371,23 +337,16 @@ async def _handler( async def handle_authenticated_agent_card( self, request: Request ) -> Response: - """Handles the 'extendedAgentCard' REST method.""" - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card or self.agent_card + """Handles the 'agentCard' REST method.""" - if self.extended_card_modifier: - context = self._build_call_context(request) - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, context) + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.AgentCard: + params = a2a_pb2.GetExtendedAgentCardRequest() + return await self.request_handler.on_get_extended_agent_card( + params, context ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - return JSONResponse( - content=MessageToDict( - card_to_serve, preserving_proto_field_name=True - ) - ) + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 20a899ca4..2ba8cecfc 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -1,16 +1,11 @@ import logging -from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any from a2a.compat.v0_3.rest_adapter import REST03Adapter -from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.routes.common import ServerCallContextBuilder from a2a.server.routes.rest_dispatcher import RestDispatcher -from a2a.types.a2a_pb2 import ( - AgentCard, -) if TYPE_CHECKING: @@ -32,36 +27,20 @@ logger = 
logging.getLogger(__name__) -def create_rest_routes( # noqa: PLR0913 - agent_card: AgentCard, +def create_rest_routes( request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, context_builder: ServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, enable_v0_3_compat: bool = False, path_prefix: str = '', ) -> list['BaseRoute']: """Creates the Starlette Routes for the A2A protocol REST endpoint. Args: - agent_card: The AgentCard describing the agent's capabilities. request_handler: The handler instance responsible for processing A2A requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultServerCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. enable_v0_3_compat: If True, mounts backward-compatible v0.3 protocol endpoints using REST03Adapter. path_prefix: The URL prefix for the REST endpoints. 
@@ -74,23 +53,15 @@ def create_rest_routes( # noqa: PLR0913 ) dispatcher = RestDispatcher( - agent_card=agent_card, request_handler=request_handler, - extended_agent_card=extended_agent_card, context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, ) routes: list[BaseRoute] = [] if enable_v0_3_compat: v03_adapter = REST03Adapter( - agent_card=agent_card, http_handler=request_handler, - extended_agent_card=extended_agent_card, context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, ) v03_routes = v03_adapter.routes() for (path, method), endpoint in v03_routes.items(): diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index badfde180..ba55da86e 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -24,7 +24,7 @@ TaskStatus, ) from a2a.utils import constants -from a2a.utils.errors import UnsupportedOperationError, VersionNotSupportedError +from a2a.utils.errors import VersionNotSupportedError from a2a.utils.telemetry import trace_function @@ -134,104 +134,6 @@ def build_text_artifact(text: str, artifact_id: str) -> Artifact: return Artifact(parts=[part], artifact_id=artifact_id) -def validate( - expression: Callable[[Any], bool], error_message: str | None = None -) -> Callable: - """Decorator that validates if a given expression evaluates to True. - - Typically used on class methods to check capabilities or configuration - before executing the method's logic. If the expression is False, - an `UnsupportedOperationError` is raised. - - Args: - expression: A callable that takes the instance (`self`) as its argument - and returns a boolean. - error_message: An optional custom error message for the `UnsupportedOperationError`. - If None, the string representation of the expression will be used. 
- - Examples: - Demonstrating with an async method: - >>> import asyncio - >>> from a2a.utils.errors import UnsupportedOperationError - >>> - >>> class MyAgent: - ... def __init__(self, streaming_enabled: bool): - ... self.streaming_enabled = streaming_enabled - ... - ... @validate( - ... lambda self: self.streaming_enabled, - ... 'Streaming is not enabled for this agent', - ... ) - ... async def stream_response(self, message: str): - ... return f'Streaming: {message}' - >>> - >>> async def run_async_test(): - ... # Successful call - ... agent_ok = MyAgent(streaming_enabled=True) - ... result = await agent_ok.stream_response('hello') - ... print(result) - ... - ... # Call that fails validation - ... agent_fail = MyAgent(streaming_enabled=False) - ... try: - ... await agent_fail.stream_response('world') - ... except UnsupportedOperationError as e: - ... print(e.message) - >>> - >>> asyncio.run(run_async_test()) - Streaming: hello - Streaming is not enabled for this agent - - Demonstrating with a sync method: - >>> class SecureAgent: - ... def __init__(self): - ... self.auth_enabled = False - ... - ... @validate( - ... lambda self: self.auth_enabled, - ... 'Authentication must be enabled for this operation', - ... ) - ... def secure_operation(self, data: str): - ... return f'Processing secure data: {data}' - >>> - >>> # Error case example - >>> agent = SecureAgent() - >>> try: - ... agent.secure_operation('secret') - ... except UnsupportedOperationError as e: - ... print(e.message) - Authentication must be enabled for this operation - - Note: - This decorator works with both sync and async methods automatically. 
- """ - - def decorator(function: Callable) -> Callable: - if inspect.iscoroutinefunction(function): - - @functools.wraps(function) - async def async_wrapper(self: Any, *args, **kwargs) -> Any: - if not expression(self): - final_message = error_message or str(expression) - logger.error('Unsupported Operation: %s', final_message) - raise UnsupportedOperationError(message=final_message) - return await function(self, *args, **kwargs) - - return async_wrapper - - @functools.wraps(function) - def sync_wrapper(self: Any, *args, **kwargs) -> Any: - if not expression(self): - final_message = error_message or str(expression) - logger.error('Unsupported Operation: %s', final_message) - raise UnsupportedOperationError(message=final_message) - return function(self, *args, **kwargs) - - return sync_wrapper - - return decorator - - def are_modalities_compatible( server_output_modes: list[str] | None, client_output_modes: list[str] | None ) -> bool: diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 259b16a5d..96eca850f 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -193,13 +193,13 @@ def serve(task_store: TaskStore) -> None: ) request_handler = DefaultRequestHandler( + agent_card=agent_card, agent_executor=SUTAgentExecutor(), task_store=task_store, ) # JSONRPC jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=request_handler, rpc_url=JSONRPC_URL, ) @@ -209,7 +209,6 @@ def serve(task_store: TaskStore) -> None: ) # REST rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=request_handler, path_prefix=REST_URL, ) @@ -229,8 +228,8 @@ def serve(task_store: TaskStore) -> None: # GRPC grpc_server = grpc.aio.server() grpc_server.add_insecure_port(f'[::]:{grpc_port}') - servicer = GrpcHandler(agent_card, request_handler) - compat_servicer = CompatGrpcHandler(agent_card, request_handler) + servicer = GrpcHandler(request_handler) + compat_servicer = CompatGrpcHandler(request_handler) 
a2a_grpc.add_A2AServiceServicer_to_server(servicer, grpc_server) a2a_v0_3_grpc.add_A2AServiceServicer_to_server(compat_servicer, grpc_server) diff --git a/tests/compat/v0_3/test_grpc_handler.py b/tests/compat/v0_3/test_grpc_handler.py index 9040388e2..75c6421e8 100644 --- a/tests/compat/v0_3/test_grpc_handler.py +++ b/tests/compat/v0_3/test_grpc_handler.py @@ -37,6 +37,7 @@ def sample_agent_card() -> a2a_pb2.AgentCard: capabilities=a2a_pb2.AgentCapabilities( streaming=True, push_notifications=True, + extended_agent_card=True, ), supported_interfaces=[ a2a_pb2.AgentInterface( @@ -53,7 +54,7 @@ def handler( mock_request_handler: AsyncMock, sample_agent_card: a2a_pb2.AgentCard ) -> compat_grpc_handler.CompatGrpcHandler: return compat_grpc_handler.CompatGrpcHandler( - agent_card=sample_agent_card, request_handler=mock_request_handler + request_handler=mock_request_handler, ) @@ -437,9 +438,15 @@ async def test_list_push_config_success( @pytest.mark.asyncio async def test_get_agent_card_success( handler: compat_grpc_handler.CompatGrpcHandler, + mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, + sample_agent_card: a2a_pb2.AgentCard, ) -> None: request = a2a_v0_3_pb2.GetAgentCardRequest() + mock_request_handler.on_get_extended_agent_card.return_value = ( + sample_agent_card + ) + response = await handler.GetAgentCard(request, mock_grpc_context) expected_res = a2a_v0_3_pb2.AgentCard( @@ -448,6 +455,7 @@ async def test_get_agent_card_success( url='http://jsonrpc.v03.com', version='1.0.0', protocol_version='0.3', + supports_authenticated_extended_card=True, preferred_transport='JSONRPC', capabilities=a2a_v0_3_pb2.AgentCapabilities( streaming=True, diff --git a/tests/compat/v0_3/test_jsonrpc_app_compat.py b/tests/compat/v0_3/test_jsonrpc_app_compat.py index 1417b5dac..6658097dc 100644 --- a/tests/compat/v0_3/test_jsonrpc_app_compat.py +++ b/tests/compat/v0_3/test_jsonrpc_app_compat.py @@ -46,8 +46,8 @@ def mock_handler(): @pytest.fixture -def 
test_app(mock_handler): - agent_card = AgentCard( +def agent_card(): + card = AgentCard( name='TestAgent', description='Test Description', version='1.0.0', @@ -55,13 +55,17 @@ def test_app(mock_handler): streaming=False, push_notifications=True, extended_agent_card=True ), ) - interface = agent_card.supported_interfaces.add() + interface = card.supported_interfaces.add() interface.url = 'http://mockurl.com' interface.protocol_binding = 'jsonrpc' interface.protocol_version = '0.3' + return card + +@pytest.fixture +def test_app(mock_handler, agent_card): + mock_handler._agent_card = agent_card jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=mock_handler, enable_v0_3_compat=True, rpc_url='/', @@ -123,9 +127,10 @@ def test_get_task_v03_compat( def test_get_extended_agent_card_v03_compat( - client: TestClient, + client: TestClient, mock_handler: AsyncMock, agent_card: AgentCard ) -> None: """Test that the v0.3 method name 'agent/getAuthenticatedExtendedCard' is correctly routed.""" + mock_handler.on_get_extended_agent_card.return_value = agent_card request_payload = { 'jsonrpc': '2.0', 'id': '3', diff --git a/tests/compat/v0_3/test_request_handler.py b/tests/compat/v0_3/test_request_handler.py index 55b0d2cab..26ad74264 100644 --- a/tests/compat/v0_3/test_request_handler.py +++ b/tests/compat/v0_3/test_request_handler.py @@ -7,24 +7,15 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + AgentInterface, ListTaskPushNotificationConfigsResponse as V10ListPushConfigsResp, -) -from a2a.types.a2a_pb2 import ( Message as V10Message, -) -from a2a.types.a2a_pb2 import ( Part as V10Part, -) -from a2a.types.a2a_pb2 import ( Task as V10Task, -) -from a2a.types.a2a_pb2 import ( TaskPushNotificationConfig as V10PushConfig, -) -from a2a.types.a2a_pb2 import ( TaskState as V10TaskState, -) -from a2a.types.a2a_pb2 
import ( TaskStatus as V10TaskStatus, ) from a2a.utils.errors import TaskNotFoundError @@ -32,7 +23,16 @@ @pytest.fixture def mock_core_handler(): - return AsyncMock(spec=RequestHandler) + handler = AsyncMock(spec=RequestHandler) + + handler.agent_card = AgentCard( + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ) + ) + return handler @pytest.fixture @@ -355,3 +355,35 @@ async def test_on_delete_task_push_notification_config( assert result is None mock_core_handler.on_delete_task_push_notification_config.assert_called_once() + + +@pytest.mark.anyio +async def test_on_get_extended_agent_card_success( + v03_handler, mock_core_handler, mock_context +): + v03_req = types_v03.GetAuthenticatedExtendedCardRequest(id=0) + + mock_core_handler.on_get_extended_agent_card.return_value = AgentCard( + name='Extended Agent', + description='An extended test agent', + version='1.0.0', + supported_interfaces=[ + AgentInterface( + url='http://jsonrpc.v03.com', + protocol_version='0.3', + ) + ], + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ), + ) + + result = await v03_handler.on_get_extended_agent_card(v03_req, mock_context) + + assert isinstance(result, types_v03.AgentCard) + assert result.name == 'Extended Agent' + assert result.capabilities.streaming is True + assert result.capabilities.push_notifications is True + mock_core_handler.on_get_extended_agent_card.assert_called_once() diff --git a/tests/compat/v0_3/test_rest_handler.py b/tests/compat/v0_3/test_rest_handler.py index f864b7037..6ff44abb1 100644 --- a/tests/compat/v0_3/test_rest_handler.py +++ b/tests/compat/v0_3/test_rest_handler.py @@ -27,9 +27,7 @@ def agent_card(): @pytest.fixture def rest_handler(agent_card, mock_core_handler): - handler = REST03Handler( - agent_card=agent_card, request_handler=mock_core_handler - ) + handler = REST03Handler(request_handler=mock_core_handler) # Mock the 
internal handler03 for easier testing of translations handler.handler03 = AsyncMock() return handler @@ -363,3 +361,39 @@ async def test_list_push_notifications( async def test_list_tasks(rest_handler, mock_request, mock_context): with pytest.raises(NotImplementedError): await rest_handler.list_tasks(mock_request, mock_context) + + +# Add our new translation method test +@pytest.mark.anyio +async def test_on_get_extended_agent_card_success( + rest_handler, mock_request, mock_context +): + rest_handler.handler03.on_get_extended_agent_card.return_value = ( + types_v03.AgentCard( + name='Extended Agent', + description='An extended test agent', + version='1.0.0', + url='http://jsonrpc.v03.com', + preferred_transport='JSONRPC', + protocol_version='0.3', + default_input_modes=[], + default_output_modes=[], + skills=[], + capabilities=types_v03.AgentCapabilities( + streaming=True, + push_notifications=True, + ), + ) + ) + + result = await rest_handler.on_get_extended_agent_card( + mock_request, mock_context + ) + + # on_get_extended_agent_card returns a JSON-friendly dict via model_dump + assert isinstance(result, dict) + assert result['name'] == 'Extended Agent' + assert result['capabilities']['streaming'] is True + assert result['capabilities']['pushNotifications'] is True + + rest_handler.handler03.on_get_extended_agent_card.assert_called_once() diff --git a/tests/compat/v0_3/test_rest_routes_compat.py b/tests/compat/v0_3/test_rest_routes_compat.py index 5ee0f60ca..b3b9e70b3 100644 --- a/tests/compat/v0_3/test_rest_routes_compat.py +++ b/tests/compat/v0_3/test_rest_routes_compat.py @@ -53,8 +53,9 @@ async def app( request_handler: RequestHandler, ) -> Starlette: """Builds the Starlette application for testing.""" + request_handler._agent_card = agent_card rest_routes = create_rest_routes( - agent_card, request_handler, enable_v0_3_compat=True + request_handler=request_handler, enable_v0_3_compat=True ) agent_card_routes = create_agent_card_routes( 
agent_card=agent_card, card_url='/well-known/agent.json' diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index 94ccae03a..106a97cea 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -142,9 +142,13 @@ def create_agent_app( """Creates a new HTTP+REST Starlette application for the test agent.""" push_config_store = InMemoryPushNotificationConfigStore() card = test_agent_card(url) + extended_card = test_agent_card(url) + extended_card.name = 'Test Agent Extended' handler = DefaultRequestHandler( agent_executor=TestAgentExecutor(), task_store=InMemoryTaskStore(), + agent_card=card, + extended_agent_card=extended_card, push_config_store=push_config_store, push_sender=BasePushNotificationSender( httpx_client=notification_client, @@ -152,7 +156,7 @@ def create_agent_app( context=ServerCallContext(), ), ) - rest_routes = create_rest_routes(agent_card=card, request_handler=handler) + rest_routes = create_rest_routes(request_handler=handler) agent_card_routes = create_agent_card_routes( agent_card=card, card_url='/.well-known/agent-card.json' ) diff --git a/tests/e2e/push_notifications/test_default_push_notification_support.py b/tests/e2e/push_notifications/test_default_push_notification_support.py index 053707d62..3d8d92481 100644 --- a/tests/e2e/push_notifications/test_default_push_notification_support.py +++ b/tests/e2e/push_notifications/test_default_push_notification_support.py @@ -75,7 +75,9 @@ def agent_server(notifications_client: httpx.AsyncClient): ) process.start() try: - wait_for_server_ready(f'{url}/extendedAgentCard') + wait_for_server_ready( + f'{url}/extendedAgentCard', headers={'A2A-Version': '1.0'} + ) except TimeoutError as e: process.terminate() raise e diff --git a/tests/e2e/push_notifications/utils.py b/tests/e2e/push_notifications/utils.py index 2934ecc58..a7317f1b2 100644 --- a/tests/e2e/push_notifications/utils.py +++ 
b/tests/e2e/push_notifications/utils.py @@ -20,12 +20,14 @@ def run_server(app, host, port) -> None: uvicorn.run(app, host=host, port=port, log_level='warning') -def wait_for_server_ready(url: str, timeout: int = 10) -> None: +def wait_for_server_ready( + url: str, timeout: int = 10, headers: dict | None = None +) -> None: """Polls the provided URL endpoint until the server is up.""" start_time = time.time() while True: with contextlib.suppress(httpx.ConnectError): - with httpx.Client() as client: + with httpx.Client(headers=headers) as client: response = client.get(url) if response.status_code == 200: return diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index 74e0bc23b..e11b1d69d 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -158,10 +158,12 @@ async def main_async(http_port: int, grpc_port: int): task_store = InMemoryTaskStore() handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), - task_store=task_store, + MockAgentExecutor(), + task_store, + agent_card, queue_manager=InMemoryQueueManager(), push_config_store=InMemoryPushNotificationConfigStore(), + extended_agent_card=agent_card, ) app = FastAPI() @@ -171,9 +173,7 @@ async def main_async(http_port: int, grpc_port: int): agent_card=agent_card, card_url='/.well-known/agent-card.json' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, rpc_url='/', enable_v0_3_compat=True, ) @@ -183,7 +183,6 @@ async def main_async(http_port: int, grpc_port: int): ) rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=handler, enable_v0_3_compat=True, ) @@ -194,10 +193,10 @@ async def main_async(http_port: int, grpc_port: int): # Start gRPC Server server = grpc.aio.server() - servicer = GrpcHandler(agent_card, handler) + servicer = 
GrpcHandler(handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) - compat_servicer = CompatGrpcHandler(agent_card, handler) + compat_servicer = CompatGrpcHandler(handler) a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(compat_servicer, server) server.add_insecure_port(f'127.0.0.1:{grpc_port}') diff --git a/tests/integration/test_agent_card.py b/tests/integration/test_agent_card.py index 494fd151c..afa1078f0 100644 --- a/tests/integration/test_agent_card.py +++ b/tests/integration/test_agent_card.py @@ -66,6 +66,7 @@ async def test_agent_card_integration(header_val: str | None) -> None: handler = DefaultRequestHandler( agent_executor=DummyAgentExecutor(), task_store=task_store, + agent_card=agent_card, queue_manager=InMemoryQueueManager(), push_config_store=InMemoryPushNotificationConfigStore(), ) @@ -76,9 +77,7 @@ async def test_agent_card_integration(header_val: str | None) -> None: *create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ), - *create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, rpc_url='/' - ), + *create_jsonrpc_routes(request_handler=handler, rpc_url='/'), ] jsonrpc_app = Starlette(routes=jsonrpc_routes) app.mount('/jsonrpc', jsonrpc_app) @@ -87,7 +86,7 @@ async def test_agent_card_integration(header_val: str | None) -> None: *create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ), - *create_rest_routes(agent_card=agent_card, request_handler=handler), + *create_rest_routes(request_handler=handler), ] rest_app = Starlette(routes=rest_routes) app.mount('/rest', rest_app) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 59d9995c2..36565205a 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -34,6 +34,9 @@ create_jsonrpc_routes, create_rest_routes, ) +from 
a2a.server.request_handlers.default_request_handler import ( + LegacyRequestHandler, +) from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -141,11 +144,12 @@ def key_provider(kid: str | None, jku: str | None): @pytest.fixture -def mock_request_handler() -> AsyncMock: +def mock_request_handler(agent_card) -> AsyncMock: """Provides a mock RequestHandler for the server-side handlers.""" handler = AsyncMock(spec=RequestHandler) # Configure on_message_send for non-streaming calls + handler._agent_card = agent_card handler.on_message_send.return_value = TASK_FROM_BLOCKING # Configure on_message_send_stream for streaming calls @@ -167,6 +171,14 @@ async def stream_side_effect(*args, **kwargs): ) handler.on_delete_task_push_notification_config.return_value = None + # Use async def to ensure it returns an awaitable + async def get_extended_agent_card_mock(*args, **kwargs): + return agent_card + + handler.on_get_extended_agent_card.side_effect = ( + get_extended_agent_card_mock # type: ignore[union-attr] + ) + async def resubscribe_side_effect(*args, **kwargs): yield RESUBSCRIBE_EVENT @@ -219,7 +231,7 @@ def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): """A base fixture to patch the sse-starlette event loop issue.""" from sse_starlette import sse - sse.AppStatus.should_exit_event = asyncio.Event() # type: ignore[attr-defined] + sse.AppStatus.should_exit_event = asyncio.Event() yield mock_request_handler, agent_card @@ -231,10 +243,7 @@ def jsonrpc_setup(http_base_setup) -> TransportSetup: agent_card=agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_request_handler, - extended_agent_card=agent_card, - rpc_url='/', + request_handler=mock_request_handler, rpc_url='/' ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) @@ -252,9 +261,7 @@ def 
jsonrpc_setup(http_base_setup) -> TransportSetup: def rest_setup(http_base_setup) -> TransportSetup: """Sets up the RestTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - rest_routes = create_rest_routes( - agent_card, mock_request_handler, extended_agent_card=agent_card - ) + rest_routes = create_rest_routes(mock_request_handler) agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/' ) @@ -343,7 +350,7 @@ async def grpc_server_and_handler( server = grpc.aio.server() port = server.add_insecure_port('[::]:0') server_address = f'localhost:{port}' - servicer = GrpcHandler(agent_card, mock_request_handler) + servicer = GrpcHandler(request_handler=mock_request_handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) await server.start() try: @@ -360,7 +367,9 @@ async def grpc_03_server_and_handler( server = grpc.aio.server() port = server.add_insecure_port('[::]:0') server_address = f'localhost:{port}' - servicer = CompatGrpcHandler(agent_card, mock_request_handler) + servicer = CompatGrpcHandler( + request_handler=mock_request_handler, + ) a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) await server.start() try: @@ -704,10 +713,7 @@ async def test_json_transport_get_signed_base_card( agent_card=agent_card, card_url='/', card_modifier=signer ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_request_handler, - extended_agent_card=agent_card, - rpc_url='/', + request_handler=mock_request_handler, rpc_url='/' ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient( @@ -764,7 +770,7 @@ async def test_client_get_signed_extended_card( private_key = ec.generate_private_key(ec.SECP256R1()) public_key = private_key.public_key() signer = create_agent_card_signer( - signing_key=private_key, # type: ignore[arg-type] + signing_key=private_key, protected_header={ 'alg': 'ES256', 'kid': 'testkey', @@ 
-773,15 +779,18 @@ async def test_client_get_signed_extended_card( }, ) + async def get_extended_agent_card_mock_2(*args, **kwargs) -> AgentCard: + return signer(extended_agent_card) + + mock_request_handler.on_get_extended_agent_card.side_effect = ( + get_extended_agent_card_mock_2 # type: ignore[union-attr] + ) + agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_request_handler, - extended_agent_card=extended_agent_card, - extended_card_modifier=lambda card, ctx: signer(card), - rpc_url='/', + request_handler=mock_request_handler, rpc_url='/' ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient( @@ -837,7 +846,7 @@ async def test_client_get_signed_base_and_extended_cards( private_key = ec.generate_private_key(ec.SECP256R1()) public_key = private_key.public_key() signer = create_agent_card_signer( - signing_key=private_key, # type: ignore[arg-type] + signing_key=private_key, protected_header={ 'alg': 'ES256', 'kid': 'testkey', @@ -845,16 +854,20 @@ async def test_client_get_signed_base_and_extended_cards( 'typ': 'JOSE', }, ) + signer(extended_agent_card) + # Use async def to ensure it returns an awaitable + async def get_extended_agent_card_mock_3(*args, **kwargs): + return extended_agent_card + + mock_request_handler.on_get_extended_agent_card.side_effect = ( + get_extended_agent_card_mock_3 # type: ignore[union-attr] + ) agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/', card_modifier=signer ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_request_handler, - extended_agent_card=extended_agent_card, - extended_card_modifier=lambda card, ctx: signer(card), - rpc_url='/', + request_handler=mock_request_handler, rpc_url='/' ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient( @@ -1116,11 
+1129,21 @@ async def test_validate_decorator_push_notifications_disabled( """Integration test for @validate decorator with push notifications disabled.""" client = error_handling_setups.client - agent_card.capabilities.push_notifications = False + real_handler = LegacyRequestHandler( + agent_executor=AsyncMock(), + task_store=AsyncMock(), + agent_card=agent_card, + ) - params = TaskPushNotificationConfig(task_id='123') + error_handling_setups.handler.on_create_task_push_notification_config.side_effect = real_handler.on_create_task_push_notification_config - with pytest.raises(UnsupportedOperationError): + params = TaskPushNotificationConfig( + task_id='123', + id='pnc-123', + url='http://example.com', + ) + + with pytest.raises(PushNotificationNotSupportedError): await client.create_task_push_notification_config(request=params) await client.close() @@ -1136,8 +1159,25 @@ async def test_validate_streaming_disabled( agent_card.capabilities.streaming = False + real_handler = LegacyRequestHandler( + agent_executor=AsyncMock(), + task_store=AsyncMock(), + agent_card=agent_card, + ) + + error_handling_setups.handler.on_message_send_stream.side_effect = ( + real_handler.on_message_send_stream + ) + error_handling_setups.handler.on_subscribe_to_task.side_effect = ( + real_handler.on_subscribe_to_task + ) + params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, parts=[Part(text='hi')]) + message=Message( + role=Role.ROLE_USER, + parts=[Part(text='hi')], + message_id='msg-123', + ) ) stream = transport.send_message_streaming(request=params) diff --git a/tests/integration/test_copying_observability.py b/tests/integration/test_copying_observability.py index a207c9b24..d5171097a 100644 --- a/tests/integration/test_copying_observability.py +++ b/tests/integration/test_copying_observability.py @@ -94,15 +94,15 @@ def setup_client(agent_card: AgentCard, use_copying: bool) -> ClientSetup: handler = DefaultRequestHandler( agent_executor=MockMutatingAgentExecutor(), 
task_store=task_store, + agent_card=agent_card, queue_manager=InMemoryQueueManager(), + extended_agent_card=agent_card, ) agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, rpc_url='/', ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 4987acdb5..1043a7d72 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -166,11 +166,12 @@ class ClientSetup(NamedTuple): @pytest.fixture -def base_e2e_setup(): +def base_e2e_setup(agent_card): task_store = InMemoryTaskStore() handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=task_store, + agent_card=agent_card, queue_manager=InMemoryQueueManager(), ) return task_store, handler @@ -179,9 +180,7 @@ def base_e2e_setup(): @pytest.fixture def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=handler - ) + rest_routes = create_rest_routes(request_handler=handler) agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/' ) @@ -209,9 +208,7 @@ def jsonrpc_setup(agent_card, base_e2e_setup) -> ClientSetup: agent_card=agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, rpc_url='/', ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) @@ -250,8 +247,8 @@ async def grpc_setup( break else: raise ValueError('No gRPC interface found in agent card') - - servicer = GrpcHandler(grpc_agent_card, handler) + handler._agent_card = grpc_agent_card + servicer = GrpcHandler(handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) await server.start() diff 
--git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index a7d85a28c..1e2253430 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -141,7 +141,7 @@ async def create_client(handler, agent_card, streaming=False): agent_card.supported_interfaces[0].protocol_binding = TransportProtocol.GRPC servicer = GrpcHandler( - agent_card, handler, context_builder=MockCallContextBuilder() + request_handler=handler, context_builder=MockCallContextBuilder() ) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) await server.start() @@ -165,9 +165,19 @@ def create_handler( task_store = task_store or InMemoryTaskStore() queue_manager = queue_manager or InMemoryQueueManager() return ( - LegacyRequestHandler(agent_executor, task_store, queue_manager) + LegacyRequestHandler( + agent_executor, + task_store, + agent_card(), + queue_manager, + ) if use_legacy - else DefaultRequestHandlerV2(agent_executor, task_store, queue_manager) + else DefaultRequestHandlerV2( + agent_executor, + task_store, + agent_card(), + queue_manager, + ) ) diff --git a/tests/integration/test_stream_generator_cleanup.py b/tests/integration/test_stream_generator_cleanup.py index 47ab5212f..f26f62c6f 100644 --- a/tests/integration/test_stream_generator_cleanup.py +++ b/tests/integration/test_stream_generator_cleanup.py @@ -75,15 +75,14 @@ def client(): handler = DefaultRequestHandler( agent_executor=_MessageExecutor(), task_store=InMemoryTaskStore(), + agent_card=card, queue_manager=InMemoryQueueManager(), ) app = Starlette( routes=[ *create_agent_card_routes(agent_card=card, card_url='/card'), *create_jsonrpc_routes( - agent_card=card, request_handler=handler, - extended_agent_card=card, rpc_url='/', ), ] diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index 6ceb1e070..6b489270b 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -202,9 +202,7 @@ def server_app(self, 
jsonrpc_agent_card, mock_handler): agent_card=jsonrpc_agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=jsonrpc_agent_card, request_handler=mock_handler, - extended_agent_card=jsonrpc_agent_card, rpc_url='/jsonrpc', ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 683c56833..046f4d4cc 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -39,6 +39,7 @@ def test_app(): handler = DefaultRequestHandler( agent_executor=DummyAgentExecutor(), task_store=InMemoryTaskStore(), + agent_card=agent_card, queue_manager=InMemoryQueueManager(), push_config_store=InMemoryPushNotificationConfigStore(), ) @@ -61,19 +62,13 @@ async def mock_on_message_send_stream(*args, **kwargs): agent_card=agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=handler, - rpc_url='/jsonrpc', - enable_v0_3_compat=True, + request_handler=handler, rpc_url='/jsonrpc', enable_v0_3_compat=True ) app.routes.extend(agent_card_routes) app.routes.extend(jsonrpc_routes) rest_routes = create_rest_routes( - agent_card=agent_card, - request_handler=handler, - path_prefix='/rest', - enable_v0_3_compat=True, + request_handler=handler, path_prefix='/rest', enable_v0_3_compat=True ) app.routes.extend(rest_routes) return app @@ -98,7 +93,7 @@ def client(test_app): ('INVALID', 'none'), ], ) -def test_version_header_integration( # noqa: PLR0912, PLR0913, PLR0915 +def test_version_header_integration( client, transport, endpoint_ver, is_streaming, header_val, should_succeed ): headers = {} diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 68945d06d..59e965116 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ 
b/tests/server/request_handlers/test_default_request_handler.py @@ -36,6 +36,7 @@ TaskUpdater, ) from a2a.types import ( + ExtendedAgentCardNotConfiguredError, InternalError, InvalidParamsError, PushNotificationNotSupportedError, @@ -44,10 +45,13 @@ UnsupportedOperationError, ) from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, Artifact, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, GetTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, ListTasksRequest, @@ -113,13 +117,25 @@ def create_server_call_context() -> ServerCallContext: return ServerCallContext(user=UnauthenticatedUser()) -def test_init_default_dependencies(): +@pytest.fixture +def agent_card(): + """Provides a standard AgentCard with streaming and push notifications enabled for tests.""" + return AgentCard( + name='test_agent', + version='1.0', + capabilities=AgentCapabilities(streaming=True, push_notifications=True), + ) + + +def test_init_default_dependencies(agent_card): """Test that default dependencies are created if not provided.""" agent_executor = MockAgentExecutor() task_store = InMemoryTaskStore() handler = DefaultRequestHandler( - agent_executor=agent_executor, task_store=task_store + agent_executor=agent_executor, + task_store=task_store, + agent_card=agent_card, ) assert isinstance(handler._queue_manager, InMemoryQueueManager) @@ -136,13 +152,15 @@ def test_init_default_dependencies(): @pytest.mark.asyncio -async def test_on_get_task_not_found(): +async def test_on_get_task_not_found(agent_card): """Test on_get_task when task_store.get returns None.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = GetTaskRequest(id='non_existent_task') @@ -155,7 
+173,7 @@ async def test_on_get_task_not_found(): @pytest.mark.asyncio -async def test_on_list_tasks_success(): +async def test_on_list_tasks_success(agent_card): """Test on_list_tasks successfully returns a page of tasks .""" mock_task_store = AsyncMock(spec=TaskStore) task2 = create_sample_task(task_id='task2') @@ -177,7 +195,9 @@ async def test_on_list_tasks_success(): ) mock_task_store.list.return_value = mock_page request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(include_artifacts=True, page_size=10) context = create_server_call_context() @@ -190,7 +210,7 @@ async def test_on_list_tasks_success(): @pytest.mark.asyncio -async def test_on_list_tasks_excludes_artifacts(): +async def test_on_list_tasks_excludes_artifacts(agent_card): """Test on_list_tasks excludes artifacts from returned tasks.""" mock_task_store = AsyncMock(spec=TaskStore) task2 = create_sample_task(task_id='task2') @@ -212,7 +232,9 @@ async def test_on_list_tasks_excludes_artifacts(): ) mock_task_store.list.return_value = mock_page request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(include_artifacts=False, page_size=10) context = create_server_call_context() @@ -223,7 +245,7 @@ async def test_on_list_tasks_excludes_artifacts(): @pytest.mark.asyncio -async def test_on_list_tasks_applies_history_length(): +async def test_on_list_tasks_applies_history_length(agent_card): """Test on_list_tasks applies history length filter.""" mock_task_store = AsyncMock(spec=TaskStore) history = [ @@ -241,7 +263,9 @@ async def test_on_list_tasks_applies_history_length(): ) mock_task_store.list.return_value = mock_page 
request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(history_length=1, page_size=10) context = create_server_call_context() @@ -252,11 +276,13 @@ async def test_on_list_tasks_applies_history_length(): @pytest.mark.asyncio -async def test_on_list_tasks_negative_history_length_error(): +async def test_on_list_tasks_negative_history_length_error(agent_card): """Test on_list_tasks raises error for negative history length.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(history_length=-1, page_size=10) context = create_server_call_context() @@ -274,7 +300,9 @@ async def test_on_cancel_task_task_not_found(): mock_task_store.get.return_value = None request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = CancelTaskRequest(id='task_not_found_for_cancel') @@ -288,7 +316,7 @@ async def test_on_cancel_task_task_not_found(): @pytest.mark.asyncio -async def test_on_cancel_task_queue_tap_returns_none(): +async def test_on_cancel_task_queue_tap_returns_none(agent_card): """Test on_cancel_task when queue_manager.tap returns None.""" mock_task_store = AsyncMock(spec=TaskStore) sample_task = create_sample_task(task_id='tap_none_task') @@ -316,6 +344,7 @@ async def test_on_cancel_task_queue_tap_returns_none(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) context = create_server_call_context() @@ -343,7 +372,7 @@ async def 
test_on_cancel_task_queue_tap_returns_none(): @pytest.mark.asyncio -async def test_on_cancel_task_cancels_running_agent(): +async def test_on_cancel_task_cancels_running_agent(agent_card): """Test on_cancel_task cancels a running agent task.""" task_id = 'running_agent_task_to_cancel' sample_task = create_sample_task(task_id=task_id) @@ -368,6 +397,7 @@ async def test_on_cancel_task_cancels_running_agent(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) # Simulate a running agent task @@ -387,7 +417,7 @@ async def test_on_cancel_task_cancels_running_agent(): @pytest.mark.asyncio -async def test_on_cancel_task_completes_during_cancellation(): +async def test_on_cancel_task_completes_during_cancellation(agent_card): """Test on_cancel_task fails to cancel a task due to concurrent task completion.""" task_id = 'running_agent_task_to_cancel' sample_task = create_sample_task(task_id=task_id) @@ -412,6 +442,7 @@ async def test_on_cancel_task_completes_during_cancellation(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) # Simulate a running agent task @@ -433,7 +464,7 @@ async def test_on_cancel_task_completes_during_cancellation(): @pytest.mark.asyncio -async def test_on_cancel_task_invalid_result_type(): +async def test_on_cancel_task_invalid_result_type(agent_card): """Test on_cancel_task when result_aggregator returns a Message instead of a Task.""" task_id = 'cancel_invalid_result_task' sample_task = create_sample_task(task_id=task_id) @@ -458,6 +489,7 @@ async def test_on_cancel_task_invalid_result_type(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) with patch( @@ -477,7 +509,7 @@ async def test_on_cancel_task_invalid_result_type(): @pytest.mark.asyncio -async def test_on_message_send_with_push_notification(): +async def 
test_on_message_send_with_push_notification(agent_card): """Test on_message_send sets push notification info if provided.""" mock_task_store = AsyncMock(spec=TaskStore) mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore) @@ -513,6 +545,7 @@ async def test_on_message_send_with_push_notification(): task_store=mock_task_store, push_config_store=mock_push_notification_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) push_config = TaskPushNotificationConfig(url='http://callback.com/push') @@ -578,7 +611,9 @@ async def mock_current_result(): @pytest.mark.asyncio -async def test_on_message_send_with_push_notification_in_non_blocking_request(): +async def test_on_message_send_with_push_notification_in_non_blocking_request( + agent_card, +): """Test that push notification callback is called during background event processing for non-blocking requests.""" mock_task_store = AsyncMock(spec=TaskStore) mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore) @@ -617,6 +652,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request(): push_config_store=mock_push_notification_store, request_context_builder=mock_request_context_builder, push_sender=mock_push_sender, + agent_card=agent_card, ) # Configure push notification @@ -717,7 +753,9 @@ async def mock_consume_and_break_on_interrupt( @pytest.mark.asyncio -async def test_on_message_send_with_push_notification_no_existing_Task(): +async def test_on_message_send_with_push_notification_no_existing_Task( + agent_card, +): """Test on_message_send for new task sets push notification info if provided.""" mock_task_store = AsyncMock(spec=TaskStore) mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore) @@ -742,6 +780,7 @@ async def test_on_message_send_with_push_notification_no_existing_Task(): task_store=mock_task_store, push_config_store=mock_push_notification_store, 
request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) push_config = TaskPushNotificationConfig(url='http://callback.com/push') @@ -801,8 +840,8 @@ async def mock_current_result(): @pytest.mark.asyncio -async def test_on_message_send_no_result_from_aggregator(): - """Test on_message_send when aggregator returns (None, False).""" +async def test_on_message_send_no_result_from_aggregator(agent_card): + """Test on_message_send when aggregator returns (None, False). Completes unsuccessfully and raises InternalError.""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) @@ -817,6 +856,7 @@ async def test_on_message_send_no_result_from_aggregator(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -850,7 +890,8 @@ async def test_on_message_send_no_result_from_aggregator(): @pytest.mark.asyncio -async def test_on_message_send_task_id_mismatch(): +async def test_on_message_send_task_id_mismatch(agent_card): + """Test on_message_send returns InternalError if aggregator returns mismatched Task ID.""" """Test on_message_send when result task ID doesn't match request context task ID.""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -868,6 +909,7 @@ async def test_on_message_send_task_id_mismatch(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -935,7 +977,7 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue): @pytest.mark.asyncio -async def test_on_message_send_non_blocking(): +async def test_on_message_send_non_blocking(agent_card): task_store = InMemoryTaskStore() push_store 
= InMemoryPushNotificationConfigStore() @@ -943,6 +985,7 @@ async def test_on_message_send_non_blocking(): agent_executor=HelloAgentExecutor(), task_store=task_store, push_config_store=push_store, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -981,7 +1024,7 @@ async def test_on_message_send_non_blocking(): @pytest.mark.asyncio -async def test_on_message_send_limit_history(): +async def test_on_message_send_limit_history(agent_card): task_store = InMemoryTaskStore() push_store = InMemoryPushNotificationConfigStore() @@ -989,6 +1032,7 @@ async def test_on_message_send_limit_history(): agent_executor=HelloAgentExecutor(), task_store=task_store, push_config_store=push_store, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -1018,7 +1062,7 @@ async def test_on_message_send_limit_history(): @pytest.mark.asyncio -async def test_on_get_task_limit_history(): +async def test_on_get_task_limit_history(agent_card): task_store = InMemoryTaskStore() push_store = InMemoryPushNotificationConfigStore() @@ -1026,6 +1070,7 @@ async def test_on_get_task_limit_history(): agent_executor=HelloAgentExecutor(), task_store=task_store, push_config_store=push_store, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -1058,7 +1103,7 @@ async def test_on_get_task_limit_history(): @pytest.mark.asyncio -async def test_on_message_send_interrupted_flow(): +async def test_on_message_send_interrupted_flow(agent_card): """Test on_message_send when flow is interrupted (e.g., auth_required).""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -1074,6 +1119,7 @@ async def test_on_message_send_interrupted_flow(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -1139,7 +1185,7 @@ def capture_create_task(coro): @pytest.mark.asyncio 
-async def test_on_message_send_stream_with_push_notification(): +async def test_on_message_send_stream_with_push_notification(agent_card): """Test on_message_send_stream sets and uses push notification info.""" mock_task_store = AsyncMock(spec=TaskStore) mock_push_config_store = AsyncMock(spec=PushNotificationConfigStore) @@ -1177,6 +1223,7 @@ async def test_on_message_send_stream_with_push_notification(): push_config_store=mock_push_config_store, push_sender=mock_push_sender, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) push_config = TaskPushNotificationConfig( @@ -1286,7 +1333,9 @@ async def to_coro(val): @pytest.mark.asyncio -async def test_stream_disconnect_then_resubscribe_receives_future_events(): +async def test_stream_disconnect_then_resubscribe_receives_future_events( + agent_card, +): """Start streaming, disconnect, then resubscribe and ensure subsequent events are streamed.""" # Arrange mock_task_store = AsyncMock(spec=TaskStore) @@ -1310,6 +1359,7 @@ async def test_stream_disconnect_then_resubscribe_receives_future_events(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=queue_manager, + agent_card=agent_card, ) params = SendMessageRequest( @@ -1377,7 +1427,9 @@ async def exec_side_effect(_request, queue: EventQueue): @pytest.mark.asyncio -async def test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues(): +async def test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues( + agent_card, +): """Simulate client disconnect: stream stops early, cleanup is scheduled in background, producer keeps running, and cleanup completes after producer finishes.""" # Arrange @@ -1408,6 +1460,7 @@ async def test_on_message_send_stream_client_disconnect_triggers_background_clea task_store=mock_task_store, queue_manager=mock_queue_manager, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) 
params = SendMessageRequest( @@ -1516,7 +1569,7 @@ def create_task_spy(coro): @pytest.mark.asyncio -async def test_disconnect_persists_final_task_to_store(): +async def test_disconnect_persists_final_task_to_store(agent_card): """After client disconnect, ensure background consumer persists final Task to store.""" task_store = InMemoryTaskStore() queue_manager = InMemoryQueueManager() @@ -1547,7 +1600,10 @@ async def cancel( agent = FinishingAgent() handler = DefaultRequestHandler( - agent_executor=agent, task_store=task_store, queue_manager=queue_manager + agent_executor=agent, + task_store=task_store, + queue_manager=queue_manager, + agent_card=agent_card, ) params = SendMessageRequest( @@ -1606,7 +1662,7 @@ async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0): @pytest.mark.asyncio -async def test_background_cleanup_task_is_tracked_and_cleared(): +async def test_background_cleanup_task_is_tracked_and_cleared(agent_card): """Ensure background cleanup task is tracked while pending and removed when done.""" # Arrange mock_task_store = AsyncMock(spec=TaskStore) @@ -1635,6 +1691,7 @@ async def test_background_cleanup_task_is_tracked_and_cleared(): task_store=mock_task_store, queue_manager=mock_queue_manager, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( @@ -1724,7 +1781,7 @@ def create_task_spy(coro): @pytest.mark.asyncio -async def test_on_message_send_stream_task_id_mismatch(): +async def test_on_message_send_stream_task_id_mismatch(agent_card): """Test on_message_send_stream raises error if yielded task ID mismatches.""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock( @@ -1743,6 +1800,7 @@ async def test_on_message_send_stream_task_id_mismatch(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -1784,7 
+1842,7 @@ async def event_stream_gen_mismatch(): @pytest.mark.asyncio -async def test_cleanup_producer_task_id_not_in_running_agents(): +async def test_cleanup_producer_task_id_not_in_running_agents(agent_card): """Test _cleanup_producer when task_id is not in _running_agents (e.g., already cleaned up).""" mock_task_store = AsyncMock(spec=TaskStore) mock_queue_manager = AsyncMock(spec=QueueManager) @@ -1792,6 +1850,7 @@ async def test_cleanup_producer_task_id_not_in_running_agents(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) task_id = 'task_already_cleaned' @@ -1821,12 +1880,13 @@ async def noop_coro_for_task(): @pytest.mark.asyncio -async def test_set_task_push_notification_config_no_notifier(): +async def test_set_task_push_notification_config_no_notifier(agent_card): """Test on_create_task_push_notification_config when _push_config_store is None.""" request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), - push_config_store=None, # Explicitly None + push_config_store=None, # Explicitly None, + agent_card=agent_card, ) params = TaskPushNotificationConfig( task_id='task1', @@ -1840,7 +1900,7 @@ async def test_set_task_push_notification_config_no_notifier(): @pytest.mark.asyncio -async def test_set_task_push_notification_config_task_not_found(): +async def test_set_task_push_notification_config_task_not_found(agent_card): """Test on_create_task_push_notification_config when task is not found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found @@ -1852,6 +1912,7 @@ async def test_set_task_push_notification_config_task_not_found(): task_store=mock_task_store, push_config_store=mock_push_store, push_sender=mock_push_sender, + agent_card=agent_card, ) params = TaskPushNotificationConfig( task_id='non_existent_task', @@ -1868,12 +1929,13 @@ async def 
test_set_task_push_notification_config_task_not_found(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_no_store(): +async def test_get_task_push_notification_config_no_store(agent_card): """Test on_get_task_push_notification_config when _push_config_store is None.""" request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), - push_config_store=None, # Explicitly None + push_config_store=None, # Explicitly None, + agent_card=agent_card, ) params = GetTaskPushNotificationConfigRequest( task_id='task1', @@ -1887,7 +1949,7 @@ async def test_get_task_push_notification_config_no_store(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_task_not_found(): +async def test_get_task_push_notification_config_task_not_found(agent_card): """Test on_get_task_push_notification_config when task is not found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found @@ -1897,6 +1959,7 @@ async def test_get_task_push_notification_config_task_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=agent_card, ) params = GetTaskPushNotificationConfigRequest( task_id='non_existent_task', id='task_push_notification_config' @@ -1912,7 +1975,7 @@ async def test_get_task_push_notification_config_task_not_found(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_info_not_found(): +async def test_get_task_push_notification_config_info_not_found(agent_card): """Test on_get_task_push_notification_config when push_config_store.get_info returns None.""" mock_task_store = AsyncMock(spec=TaskStore) @@ -1926,13 +1989,14 @@ async def test_get_task_push_notification_config_info_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=agent_card, ) params = GetTaskPushNotificationConfigRequest( 
task_id='non_existent_task', id='task_push_notification_config' ) context = create_server_call_context() - with pytest.raises(InternalError): + with pytest.raises(TaskNotFoundError): await request_handler.on_get_task_push_notification_config( params, context ) @@ -1943,7 +2007,7 @@ async def test_get_task_push_notification_config_info_not_found(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_info_with_config(): +async def test_get_task_push_notification_config_info_with_config(agent_card): """Test on_get_task_push_notification_config with valid push config id""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') @@ -1954,6 +2018,7 @@ async def test_get_task_push_notification_config_info_with_config(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) set_config_params = TaskPushNotificationConfig( @@ -1981,7 +2046,9 @@ async def test_get_task_push_notification_config_info_with_config(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_info_with_config_no_id(): +async def test_get_task_push_notification_config_info_with_config_no_id( + agent_card, +): """Test on_get_task_push_notification_config with no push config id""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') @@ -1992,6 +2059,7 @@ async def test_get_task_push_notification_config_info_with_config_no_id(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) set_config_params = TaskPushNotificationConfig( @@ -2017,13 +2085,15 @@ async def test_get_task_push_notification_config_info_with_config_no_id(): @pytest.mark.asyncio -async def test_on_subscribe_to_task_task_not_found(): +async def test_on_subscribe_to_task_task_not_found(agent_card): """Test on_subscribe_to_task when the task is not 
found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SubscribeToTaskRequest(id='resub_task_not_found') @@ -2038,7 +2108,7 @@ async def test_on_subscribe_to_task_task_not_found(): @pytest.mark.asyncio -async def test_on_subscribe_to_task_queue_not_found(): +async def test_on_subscribe_to_task_queue_not_found(agent_card): """Test on_subscribe_to_task when the queue is not found by queue_manager.tap.""" mock_task_store = AsyncMock(spec=TaskStore) sample_task = create_sample_task(task_id='resub_queue_not_found') @@ -2051,6 +2121,7 @@ async def test_on_subscribe_to_task_queue_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) params = SubscribeToTaskRequest(id='resub_queue_not_found') @@ -2065,9 +2136,11 @@ async def test_on_subscribe_to_task_queue_not_found(): @pytest.mark.asyncio -async def test_on_message_send_stream(): +async def test_on_message_send_stream(agent_card): request_handler = DefaultRequestHandler( - MockAgentExecutor(), InMemoryTaskStore() + MockAgentExecutor(), + InMemoryTaskStore(), + agent_card=agent_card, ) message_params = SendMessageRequest( message=Message( @@ -2102,12 +2175,13 @@ async def consume_stream(): @pytest.mark.asyncio -async def test_list_task_push_notification_config_no_store(): +async def test_list_task_push_notification_config_no_store(agent_card): """Test on_list_task_push_notification_configs when _push_config_store is None.""" request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), - push_config_store=None, # Explicitly None + push_config_store=None, # Explicitly None, + agent_card=agent_card, ) params = 
ListTaskPushNotificationConfigsRequest(task_id='task1') @@ -2118,7 +2192,7 @@ async def test_list_task_push_notification_config_no_store(): @pytest.mark.asyncio -async def test_list_task_push_notification_config_task_not_found(): +async def test_list_task_push_notification_config_task_not_found(agent_card): """Test on_list_task_push_notification_configs when task is not found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found @@ -2128,6 +2202,7 @@ async def test_list_task_push_notification_config_task_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=agent_card, ) params = ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') @@ -2141,7 +2216,7 @@ async def test_list_task_push_notification_config_task_not_found(): @pytest.mark.asyncio -async def test_list_no_task_push_notification_config_info(): +async def test_list_no_task_push_notification_config_info(agent_card): """Test on_get_task_push_notification_config when push_config_store.get_info returns []""" mock_task_store = AsyncMock(spec=TaskStore) @@ -2154,6 +2229,7 @@ async def test_list_no_task_push_notification_config_info(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) params = ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') @@ -2164,7 +2240,7 @@ async def test_list_no_task_push_notification_config_info(): @pytest.mark.asyncio -async def test_list_task_push_notification_config_info_with_config(): +async def test_list_task_push_notification_config_info_with_config(agent_card): """Test on_list_task_push_notification_configs with push config+id""" mock_task_store = AsyncMock(spec=TaskStore) @@ -2187,6 +2263,7 @@ async def test_list_task_push_notification_config_info_with_config(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + 
agent_card=agent_card, ) params = ListTaskPushNotificationConfigsRequest(task_id='task_1') @@ -2202,7 +2279,9 @@ async def test_list_task_push_notification_config_info_with_config(): @pytest.mark.asyncio -async def test_list_task_push_notification_config_info_with_config_and_no_id(): +async def test_list_task_push_notification_config_info_with_config_and_no_id( + agent_card, +): """Test on_list_task_push_notification_configs with no push config id""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') @@ -2213,6 +2292,7 @@ async def test_list_task_push_notification_config_info_with_config_and_no_id(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) # multiple calls without config id should replace the existing @@ -2245,12 +2325,13 @@ async def test_list_task_push_notification_config_info_with_config_and_no_id(): @pytest.mark.asyncio -async def test_delete_task_push_notification_config_no_store(): +async def test_delete_task_push_notification_config_no_store(agent_card): """Test on_delete_task_push_notification_config when _push_config_store is None.""" request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), - push_config_store=None, # Explicitly None + push_config_store=None, # Explicitly None, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='task1', id='config1' @@ -2263,7 +2344,7 @@ async def test_delete_task_push_notification_config_no_store(): @pytest.mark.asyncio -async def test_delete_task_push_notification_config_task_not_found(): +async def test_delete_task_push_notification_config_task_not_found(agent_card): """Test on_delete_task_push_notification_config when task is not found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found @@ -2273,6 +2354,7 @@ async def 
test_delete_task_push_notification_config_task_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='non_existent_task', id='config1' @@ -2289,7 +2371,7 @@ async def test_delete_task_push_notification_config_task_not_found(): @pytest.mark.asyncio -async def test_delete_no_task_push_notification_config_info(): +async def test_delete_no_task_push_notification_config_info(agent_card): """Test on_delete_task_push_notification_config without config info""" mock_task_store = AsyncMock(spec=TaskStore) @@ -2307,6 +2389,7 @@ async def test_delete_no_task_push_notification_config_info(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='task1', id='config_non_existant' @@ -2328,7 +2411,9 @@ async def test_delete_no_task_push_notification_config_info(): @pytest.mark.asyncio -async def test_delete_task_push_notification_config_info_with_config(): +async def test_delete_task_push_notification_config_info_with_config( + agent_card, +): """Test on_list_task_push_notification_configs with push config+id""" mock_task_store = AsyncMock(spec=TaskStore) @@ -2352,6 +2437,7 @@ async def test_delete_task_push_notification_config_info_with_config(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='task_1', id='config_1' @@ -2374,7 +2460,9 @@ async def test_delete_task_push_notification_config_info_with_config(): @pytest.mark.asyncio -async def test_delete_task_push_notification_config_info_with_config_and_no_id(): +async def test_delete_task_push_notification_config_info_with_config_and_no_id( + agent_card, +): """Test on_list_task_push_notification_configs with no push config id""" 
mock_task_store = AsyncMock(spec=TaskStore) @@ -2393,6 +2481,7 @@ async def test_delete_task_push_notification_config_info_with_config_and_no_id() agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='task_1', id='task_1' @@ -2422,7 +2511,9 @@ async def test_delete_task_push_notification_config_info_with_config_and_no_id() @pytest.mark.asyncio @pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) -async def test_on_message_send_task_in_terminal_state(terminal_state): +async def test_on_message_send_task_in_terminal_state( + terminal_state, agent_card +): """Test on_message_send when task is already in a terminal state.""" state_name = TaskState.Name(terminal_state) task_id = f'terminal_task_{state_name}' @@ -2436,7 +2527,9 @@ async def test_on_message_send_task_in_terminal_state(terminal_state): # So we should patch that instead. request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2466,7 +2559,9 @@ async def test_on_message_send_task_in_terminal_state(terminal_state): @pytest.mark.asyncio @pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) -async def test_on_message_send_stream_task_in_terminal_state(terminal_state): +async def test_on_message_send_stream_task_in_terminal_state( + terminal_state, agent_card +): """Test on_message_send_stream when task is already in a terminal state.""" state_name = TaskState.Name(terminal_state) task_id = f'terminal_stream_task_{state_name}' @@ -2477,7 +2572,9 @@ async def test_on_message_send_stream_task_in_terminal_state(terminal_state): mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + 
agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2507,7 +2604,9 @@ async def test_on_message_send_stream_task_in_terminal_state(terminal_state): @pytest.mark.asyncio @pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) -async def test_on_subscribe_to_task_in_terminal_state(terminal_state): +async def test_on_subscribe_to_task_in_terminal_state( + terminal_state, agent_card +): """Test on_subscribe_to_task when task is in a terminal state.""" state_name = TaskState.Name(terminal_state) task_id = f'resub_terminal_task_{state_name}' @@ -2522,6 +2621,7 @@ async def test_on_subscribe_to_task_in_terminal_state(terminal_state): agent_executor=MockAgentExecutor(), task_store=mock_task_store, queue_manager=AsyncMock(spec=QueueManager), + agent_card=agent_card, ) params = SubscribeToTaskRequest(id=f'{task_id}') @@ -2539,13 +2639,15 @@ async def test_on_subscribe_to_task_in_terminal_state(terminal_state): @pytest.mark.asyncio -async def test_on_message_send_task_id_provided_but_task_not_found(): +async def test_on_message_send_task_id_provided_but_task_not_found(agent_card): """Test on_message_send when task_id is provided but task doesn't exist.""" task_id = 'nonexistent_task' mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2575,13 +2677,17 @@ async def test_on_message_send_task_id_provided_but_task_not_found(): @pytest.mark.asyncio -async def test_on_message_send_stream_task_id_provided_but_task_not_found(): +async def test_on_message_send_stream_task_id_provided_but_task_not_found( + agent_card, +): """Test on_message_send_stream when task_id is provided but task doesn't exist.""" task_id = 'nonexistent_stream_task' mock_task_store = AsyncMock(spec=TaskStore) 
request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2639,14 +2745,16 @@ async def cancel( # we should reconsider the approach. @pytest.mark.asyncio @pytest.mark.timeout(1) -async def test_on_message_send_error_does_not_hang(): +async def test_on_message_send_error_does_not_hang(agent_card): """Test that if the consumer raises an exception during blocking wait, the producer is cancelled and no deadlock occurs.""" agent = HelloWorldAgentExecutor() task_store = AsyncMock(spec=TaskStore) task_store.save.side_effect = RuntimeError('This is an Error!') request_handler = DefaultRequestHandler( - agent_executor=agent, task_store=task_store + agent_executor=agent, + task_store=task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2664,11 +2772,13 @@ async def test_on_message_send_error_does_not_hang(): @pytest.mark.asyncio -async def test_on_get_task_negative_history_length_error(): +async def test_on_get_task_negative_history_length_error(agent_card): """Test on_get_task raises error for negative history length.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) # GetTaskRequest also has history_length params = GetTaskRequest(id='task1', history_length=-1) @@ -2681,11 +2791,13 @@ async def test_on_get_task_negative_history_length_error(): @pytest.mark.asyncio -async def test_on_list_tasks_page_size_too_small(): +async def test_on_list_tasks_page_size_too_small(agent_card): """Test on_list_tasks raises error for page_size < 1.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), 
task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(page_size=0) context = create_server_call_context() @@ -2697,11 +2809,13 @@ async def test_on_list_tasks_page_size_too_small(): @pytest.mark.asyncio -async def test_on_list_tasks_page_size_too_large(): +async def test_on_list_tasks_page_size_too_large(agent_card): """Test on_list_tasks raises error for page_size > 100.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(page_size=101) context = create_server_call_context() @@ -2713,12 +2827,14 @@ async def test_on_list_tasks_page_size_too_large(): @pytest.mark.asyncio -async def test_on_message_send_negative_history_length_error(): +async def test_on_message_send_negative_history_length_error(agent_card): """Test on_message_send raises error for negative history length in configuration.""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock(spec=AgentExecutor) request_handler = DefaultRequestHandler( - agent_executor=mock_agent_executor, task_store=mock_task_store + agent_executor=mock_agent_executor, + task_store=mock_task_store, + agent_card=agent_card, ) message_config = SendMessageConfiguration( @@ -2737,3 +2853,119 @@ async def test_on_message_send_negative_history_length_error(): await request_handler.on_message_send(params, context) assert 'history length must be non-negative' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_get_extended_agent_card_success(agent_card): + """Test on_get_extended_agent_card when extended_agent_card is supported.""" + agent_card.capabilities.extended_agent_card = True + + extended_agent_card = AgentCard( + name='Extended Agent', + 
description='An extended agent', + version='1.0.0', + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ), + ) + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + agent_card=agent_card, + extended_agent_card=extended_agent_card, + ) + + params = GetExtendedAgentCardRequest() + context = create_server_call_context() + + result = await request_handler.on_get_extended_agent_card(params, context) + + assert result == extended_agent_card + + +@pytest.mark.asyncio +async def test_on_message_send_stream_unsupported(agent_card): + """Test on_message_send_stream when streaming is unsupported.""" + agent_card.capabilities.streaming = False + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + agent_card=agent_card, + ) + + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg-unsupported', + parts=[Part(text='hi')], + ) + ) + + context = create_server_call_context() + + with pytest.raises(UnsupportedOperationError): + async for _ in request_handler.on_message_send_stream(params, context): + pass + + +@pytest.mark.asyncio +async def test_on_get_extended_agent_card_unsupported(agent_card): + """Test on_get_extended_agent_card when extended_agent_card is unsupported.""" + agent_card.capabilities.extended_agent_card = False + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + agent_card=agent_card, + ) + + params = GetExtendedAgentCardRequest() + context = create_server_call_context() + + with pytest.raises(UnsupportedOperationError): + await request_handler.on_get_extended_agent_card(params, context) + + +@pytest.mark.asyncio +async def test_on_create_task_push_notification_config_unsupported(agent_card): + """Test 
on_create_task_push_notification_config when push_notifications is unsupported.""" + agent_card.capabilities.push_notifications = False + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + agent_card=agent_card, + ) + + params = TaskPushNotificationConfig(url='http://callback.com/push') + + context = create_server_call_context() + + with pytest.raises(PushNotificationNotSupportedError): + await request_handler.on_create_task_push_notification_config( + params, context + ) + + +@pytest.mark.asyncio +async def test_on_subscribe_to_task_unsupported(agent_card): + """Test on_subscribe_to_task when streaming is unsupported.""" + agent_card.capabilities.streaming = False + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + agent_card=agent_card, + ) + + params = SubscribeToTaskRequest(id='some_task') + context = create_server_call_context() + + with pytest.raises(UnsupportedOperationError): + # We need to exhaust the generator to trigger the decorator evaluation + async for _ in request_handler.on_subscribe_to_task(params, context): + pass diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index abe35bf64..605078201 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -30,9 +30,11 @@ InternalError, InvalidParamsError, TaskNotFoundError, - UnsupportedOperationError, + PushNotificationNotSupportedError, ) from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, Artifact, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, @@ -55,6 +57,15 @@ from a2a.utils import new_agent_text_message, new_task +def create_default_agent_card(): + """Provides a standard AgentCard with streaming and push notifications 
enabled for tests.""" + return AgentCard( + name='test_agent', + version='1.0', + capabilities=AgentCapabilities(streaming=True, push_notifications=True), + ) + + class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): task_updater = TaskUpdater( @@ -99,7 +110,9 @@ def test_init_default_dependencies(): agent_executor = MockAgentExecutor() task_store = InMemoryTaskStore() handler = DefaultRequestHandlerV2( - agent_executor=agent_executor, task_store=task_store + agent_executor=agent_executor, + task_store=task_store, + agent_card=create_default_agent_card(), ) assert isinstance(handler._active_task_registry, ActiveTaskRegistry) assert isinstance( @@ -120,7 +133,9 @@ async def test_on_get_task_not_found(): mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None request_handler = DefaultRequestHandlerV2( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = GetTaskRequest(id='non_existent_task') context = create_server_call_context() @@ -149,7 +164,9 @@ async def test_on_list_tasks_success(): ) mock_task_store.list.return_value = mock_page request_handler = DefaultRequestHandlerV2( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = ListTasksRequest(include_artifacts=True, page_size=10) context = create_server_call_context() @@ -179,7 +196,9 @@ async def test_on_list_tasks_excludes_artifacts(): ) mock_task_store.list.return_value = mock_page request_handler = DefaultRequestHandlerV2( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = 
ListTasksRequest(include_artifacts=False, page_size=10) context = create_server_call_context() @@ -203,7 +222,9 @@ async def test_on_list_tasks_applies_history_length(): ) mock_task_store.list.return_value = mock_page request_handler = DefaultRequestHandlerV2( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = ListTasksRequest(history_length=1, page_size=10) context = create_server_call_context() @@ -216,7 +237,9 @@ async def test_on_list_tasks_negative_history_length_error(): """Test on_list_tasks raises error for negative history length.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandlerV2( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = ListTasksRequest(history_length=-1, page_size=10) context = create_server_call_context() @@ -231,7 +254,9 @@ async def test_on_cancel_task_task_not_found(): mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None request_handler = DefaultRequestHandlerV2( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = CancelTaskRequest(id='task_not_found_for_cancel') context = create_server_call_context() @@ -278,6 +303,7 @@ async def test_on_get_task_limit_history(): agent_executor=HelloAgentExecutor(), task_store=task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) params = SendMessageRequest( message=Message( @@ -323,11 +349,12 @@ async def test_set_task_push_notification_config_no_notifier(): agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), push_config_store=None, + 
agent_card=create_default_agent_card(), ) params = TaskPushNotificationConfig( task_id='task1', url='http://example.com' ) - with pytest.raises(UnsupportedOperationError): + with pytest.raises(PushNotificationNotSupportedError): await request_handler.on_create_task_push_notification_config( params, create_server_call_context() ) @@ -345,6 +372,7 @@ async def test_set_task_push_notification_config_task_not_found(): task_store=mock_task_store, push_config_store=mock_push_store, push_sender=mock_push_sender, + agent_card=create_default_agent_card(), ) params = TaskPushNotificationConfig( task_id='non_existent_task', url='http://example.com' @@ -365,11 +393,12 @@ async def test_get_task_push_notification_config_no_store(): agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), push_config_store=None, + agent_card=create_default_agent_card(), ) params = GetTaskPushNotificationConfigRequest( task_id='task1', id='task_push_notification_config' ) - with pytest.raises(UnsupportedOperationError): + with pytest.raises(PushNotificationNotSupportedError): await request_handler.on_get_task_push_notification_config( params, create_server_call_context() ) @@ -385,6 +414,7 @@ async def test_get_task_push_notification_config_task_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=create_default_agent_card(), ) params = GetTaskPushNotificationConfigRequest( task_id='non_existent_task', id='task_push_notification_config' @@ -410,12 +440,13 @@ async def test_get_task_push_notification_config_info_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=create_default_agent_card(), ) params = GetTaskPushNotificationConfigRequest( task_id='non_existent_task', id='task_push_notification_config' ) context = create_server_call_context() - with pytest.raises(InternalError): + with pytest.raises(TaskNotFoundError): await 
request_handler.on_get_task_push_notification_config( params, context ) @@ -435,6 +466,7 @@ async def test_get_task_push_notification_config_info_with_config(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) set_config_params = TaskPushNotificationConfig( task_id='task_1', id='config_id', url='http://1.example.com' @@ -467,6 +499,7 @@ async def test_get_task_push_notification_config_info_with_config_no_id(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) set_config_params = TaskPushNotificationConfig( task_id='task_1', url='http://1.example.com' @@ -492,7 +525,9 @@ async def test_on_subscribe_to_task_task_not_found(): mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None request_handler = DefaultRequestHandlerV2( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = SubscribeToTaskRequest(id='resub_task_not_found') context = create_server_call_context() @@ -507,7 +542,9 @@ async def test_on_subscribe_to_task_task_not_found(): @pytest.mark.asyncio async def test_on_message_send_stream(): request_handler = DefaultRequestHandlerV2( - MockAgentExecutor(), InMemoryTaskStore() + MockAgentExecutor(), + InMemoryTaskStore(), + create_default_agent_card(), ) message_params = SendMessageRequest( message=Message( @@ -543,9 +580,10 @@ async def test_list_task_push_notification_config_no_store(): agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), push_config_store=None, + agent_card=create_default_agent_card(), ) params = ListTaskPushNotificationConfigsRequest(task_id='task1') - with pytest.raises(UnsupportedOperationError): + with pytest.raises(PushNotificationNotSupportedError): await 
request_handler.on_list_task_push_notification_configs( params, create_server_call_context() ) @@ -561,6 +599,7 @@ async def test_list_task_push_notification_config_task_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=create_default_agent_card(), ) params = ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') context = create_server_call_context() @@ -583,6 +622,7 @@ async def test_list_no_task_push_notification_config_info(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) params = ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') result = await request_handler.on_list_task_push_notification_configs( @@ -612,6 +652,7 @@ async def test_list_task_push_notification_config_info_with_config(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) params = ListTaskPushNotificationConfigsRequest(task_id='task_1') result = await request_handler.on_list_task_push_notification_configs( @@ -634,6 +675,7 @@ async def test_list_task_push_notification_config_info_with_config_and_no_id(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) set_config_params1 = TaskPushNotificationConfig( task_id='task_1', url='http://1.example.com' @@ -664,15 +706,16 @@ async def test_delete_task_push_notification_config_no_store(): agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), push_config_store=None, + agent_card=create_default_agent_card(), ) params = DeleteTaskPushNotificationConfigRequest( task_id='task1', id='config1' ) - with pytest.raises(UnsupportedOperationError) as exc_info: + with pytest.raises(PushNotificationNotSupportedError) as exc_info: await 
request_handler.on_delete_task_push_notification_config( params, create_server_call_context() ) - assert isinstance(exc_info.value, UnsupportedOperationError) + assert isinstance(exc_info.value, PushNotificationNotSupportedError) @pytest.mark.asyncio @@ -685,6 +728,7 @@ async def test_delete_task_push_notification_config_task_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=create_default_agent_card(), ) params = DeleteTaskPushNotificationConfigRequest( task_id='non_existent_task', id='config1' @@ -714,6 +758,7 @@ async def test_delete_no_task_push_notification_config_info(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) params = DeleteTaskPushNotificationConfigRequest( task_id='task1', id='config_non_existant' @@ -752,6 +797,7 @@ async def test_delete_task_push_notification_config_info_with_config(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) params = DeleteTaskPushNotificationConfigRequest( task_id='task_1', id='config_1' @@ -784,6 +830,7 @@ async def test_delete_task_push_notification_config_info_with_config_and_no_id() agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) params = DeleteTaskPushNotificationConfigRequest( task_id='task_1', id='task_1' @@ -818,7 +865,9 @@ async def test_on_message_send_task_in_terminal_state(terminal_state): ) mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandlerV2( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = SendMessageRequest( message=Message( @@ -855,7 +904,9 @@ async def 
test_on_message_send_stream_task_in_terminal_state(terminal_state): ) mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandlerV2( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = SendMessageRequest( message=Message( @@ -924,7 +975,9 @@ async def test_on_message_send_error_does_not_hang(): task_store.save.side_effect = RuntimeError('This is an Error!') request_handler = DefaultRequestHandlerV2( - agent_executor=agent, task_store=task_store + agent_executor=agent, + task_store=task_store, + agent_card=create_default_agent_card(), ) params = SendMessageRequest( @@ -945,7 +998,9 @@ async def test_on_get_task_negative_history_length_error(): """Test on_get_task raises error for negative history length.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandlerV2( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = GetTaskRequest(id='task1', history_length=-1) context = create_server_call_context() @@ -959,7 +1014,9 @@ async def test_on_list_tasks_page_size_too_small(): """Test on_list_tasks raises error for page_size < 1.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandlerV2( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = ListTasksRequest(page_size=0) context = create_server_call_context() @@ -973,7 +1030,9 @@ async def test_on_list_tasks_page_size_too_large(): """Test on_list_tasks raises error for page_size > 100.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandlerV2( - 
agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) params = ListTasksRequest(page_size=101) context = create_server_call_context() @@ -988,7 +1047,9 @@ async def test_on_message_send_negative_history_length_error(): mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock(spec=AgentExecutor) request_handler = DefaultRequestHandlerV2( - agent_executor=mock_agent_executor, task_store=mock_task_store + agent_executor=mock_agent_executor, + task_store=mock_task_store, + agent_card=create_default_agent_card(), ) message_config = SendMessageConfiguration( history_length=-1, accepted_output_modes=['text/plain'] @@ -1014,6 +1075,7 @@ async def test_on_message_send_limit_history(): agent_executor=HelloAgentExecutor(), task_store=task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) params = SendMessageRequest( message=Message( @@ -1059,6 +1121,7 @@ async def test_on_message_send_task_id_mismatch(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + agent_card=create_default_agent_card(), ) params = SendMessageRequest( message=Message( @@ -1107,6 +1170,7 @@ async def test_on_message_send_stream_task_id_mismatch(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + agent_card=create_default_agent_card(), ) params = SendMessageRequest( message=Message( @@ -1155,6 +1219,7 @@ async def test_on_message_send_non_blocking(): agent_executor=HelloAgentExecutor(), task_store=task_store, push_config_store=push_store, + agent_card=create_default_agent_card(), ) params = SendMessageRequest( message=Message( @@ -1185,6 +1250,7 @@ async def test_on_message_send_with_push_notification(): agent_executor=HelloAgentExecutor(), task_store=task_store, 
push_config_store=push_store, + agent_card=create_default_agent_card(), ) push_config = TaskPushNotificationConfig(url='http://example.com/webhook') params = SendMessageRequest( diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 11ceaf7bb..2b1a37385 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -53,9 +53,8 @@ def sample_agent_card() -> types.AgentCard: def grpc_handler( mock_request_handler: AsyncMock, sample_agent_card: types.AgentCard ) -> GrpcHandler: - return GrpcHandler( - agent_card=sample_agent_card, request_handler=mock_request_handler - ) + mock_request_handler._agent_card = sample_agent_card + return GrpcHandler(request_handler=mock_request_handler) # --- Test Cases --- @@ -182,13 +181,19 @@ async def test_get_extended_agent_card( grpc_handler: GrpcHandler, sample_agent_card: types.AgentCard, mock_grpc_context: AsyncMock, + mock_request_handler: AsyncMock, ) -> None: """Test GetExtendedAgentCard call.""" + + async def to_coro(*args, **kwargs): + return sample_agent_card + + mock_request_handler.on_get_extended_agent_card.side_effect = to_coro request_proto = a2a_pb2.GetExtendedAgentCardRequest() response = await grpc_handler.GetExtendedAgentCard( request_proto, mock_grpc_context ) - + mock_request_handler.on_get_extended_agent_card.assert_awaited_once() assert response.name == sample_agent_card.name assert response.version == sample_agent_card.version @@ -207,17 +212,20 @@ async def modifier(card: types.AgentCard) -> types.AgentCard: modified_card.name = 'Modified gRPC Agent' return modified_card - grpc_handler_modified = GrpcHandler( - agent_card=sample_agent_card, - request_handler=mock_request_handler, - card_modifier=modifier, - ) + # Use side_effect to ensure it returns an awaitable + async def side_effect_func(*_args, **_kwargs): + return await modifier(sample_agent_card) + 
mock_request_handler.on_get_extended_agent_card.side_effect = ( + side_effect_func + ) + mock_request_handler._agent_card = sample_agent_card + grpc_handler_modified = GrpcHandler(request_handler=mock_request_handler) request_proto = a2a_pb2.GetExtendedAgentCardRequest() response = await grpc_handler_modified.GetExtendedAgentCard( request_proto, mock_grpc_context ) - + mock_request_handler.on_get_extended_agent_card.assert_awaited_once() assert response.name == 'Modified gRPC Agent' assert response.version == sample_agent_card.version @@ -237,17 +245,17 @@ def modifier(card: types.AgentCard) -> types.AgentCard: modified_card.name = 'Modified gRPC Agent' return modified_card - grpc_handler_modified = GrpcHandler( - agent_card=sample_agent_card, - request_handler=mock_request_handler, - card_modifier=modifier, - ) + async def async_modifier(*args, **kwargs): + return modifier(sample_agent_card) + mock_request_handler.on_get_extended_agent_card.side_effect = async_modifier + mock_request_handler._agent_card = sample_agent_card + grpc_handler_modified = GrpcHandler(request_handler=mock_request_handler) request_proto = a2a_pb2.GetExtendedAgentCardRequest() response = await grpc_handler_modified.GetExtendedAgentCard( request_proto, mock_grpc_context ) - + mock_request_handler.on_get_extended_agent_card.assert_awaited_once() assert response.name == 'Modified gRPC Agent' assert response.version == sample_agent_card.version @@ -346,7 +354,7 @@ async def test_list_tasks_success( ), ], ) -async def test_abort_context_error_mapping( # noqa: PLR0913 +async def test_abort_context_error_mapping( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py index f884bb38e..15d3349cd 100644 --- a/tests/server/routes/test_jsonrpc_dispatcher.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -61,7 +61,7 @@ def test_app(mock_handler): 
mock_agent_card.capabilities.streaming = False jsonrpc_routes = create_jsonrpc_routes( - agent_card=mock_agent_card, request_handler=mock_handler, rpc_url='/' + request_handler=mock_handler, rpc_url='/' ) from starlette.applications import Starlette @@ -101,7 +101,8 @@ def mock_app_params(self) -> dict: mock_handler = MagicMock(spec=RequestHandler) mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://example.com' - return {'agent_card': mock_agent_card, 'request_handler': mock_handler} + mock_handler._agent_card = mock_agent_card + return {'request_handler': mock_handler} @pytest.fixture(scope='class') def mark_pkg_starlette_not_installed(self): @@ -228,13 +229,12 @@ def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): mock_agent_card.capabilities = MagicMock() mock_agent_card.capabilities.streaming = False + mock_handler._agent_card = mock_agent_card + from starlette.applications import Starlette jsonrpc_routes = create_jsonrpc_routes( - agent_card=mock_agent_card, - request_handler=mock_handler, - enable_v0_3_compat=True, - rpc_url='/', + request_handler=mock_handler, enable_v0_3_compat=True, rpc_url='/' ) app = Starlette(routes=jsonrpc_routes) client = TestClient(app) @@ -328,9 +328,7 @@ def agent_card(self): @pytest.fixture def client(self, handler, agent_card): jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, rpc_url='/', ) from starlette.applications import Starlette @@ -480,11 +478,9 @@ async def capture_modifier(card, context): captured['method'] = context.state.get('method') return card + handler.on_get_extended_agent_card.return_value = agent_card jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, - extended_card_modifier=capture_modifier, rpc_url='/', ) from starlette.applications import Starlette @@ -500,7 +496,7 @@ async def capture_modifier(card, context): data = 
response.json() assert 'result' in data assert data['result']['name'] == 'TestAgent' - assert captured['method'] == 'GetExtendedAgentCard' + handler.on_get_extended_agent_card.assert_called_once() # --- Streaming method routing tests --- @@ -526,7 +522,6 @@ async def stream_generator(): ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, rpc_url='/', ) @@ -588,7 +583,6 @@ async def stream_generator(): ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, rpc_url='/', ) diff --git a/tests/server/routes/test_jsonrpc_routes.py b/tests/server/routes/test_jsonrpc_routes.py index 3330d14c8..ff1b81f3f 100644 --- a/tests/server/routes/test_jsonrpc_routes.py +++ b/tests/server/routes/test_jsonrpc_routes.py @@ -23,9 +23,7 @@ def mock_handler(): def test_routes_creation(agent_card, mock_handler): """Tests that create_jsonrpc_routes creates Route objects list.""" routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_handler, - rpc_url='/a2a/jsonrpc', + request_handler=mock_handler, rpc_url='/a2a/jsonrpc' ) assert isinstance(routes, list) @@ -41,7 +39,7 @@ def test_jsonrpc_custom_url(agent_card, mock_handler): """Tests that custom rpc_url is respected for routing.""" custom_url = '/custom/api/jsonrpc' routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=mock_handler, rpc_url=custom_url + request_handler=mock_handler, rpc_url=custom_url ) app = Starlette(routes=routes) diff --git a/tests/server/routes/test_rest_dispatcher.py b/tests/server/routes/test_rest_dispatcher.py index be5870cc4..5284db617 100644 --- a/tests/server/routes/test_rest_dispatcher.py +++ b/tests/server/routes/test_rest_dispatcher.py @@ -31,12 +31,25 @@ @pytest.fixture -def mock_handler(): +def agent_card(): + card = MagicMock(spec=AgentCard) + card.capabilities = AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ) + return card + + 
+@pytest.fixture +def mock_handler(agent_card): handler = AsyncMock(spec=RequestHandler) # Default success cases + handler._agent_card = agent_card handler.on_message_send.return_value = Message(message_id='test_msg') handler.on_cancel_task.return_value = Task(id='test_task') handler.on_get_task.return_value = Task(id='test_task') + handler.on_get_extended_agent_card.return_value = agent_card() handler.on_list_tasks.return_value = ListTasksResponse() handler.on_get_task_push_notification_config.return_value = ( TaskPushNotificationConfig(url='http://test') @@ -59,19 +72,8 @@ async def mock_stream(*args, **kwargs) -> AsyncIterator[Task]: @pytest.fixture -def agent_card(): - card = MagicMock(spec=AgentCard) - card.capabilities = AgentCapabilities( - streaming=True, - push_notifications=True, - extended_agent_card=True, - ) - return card - - -@pytest.fixture -def rest_dispatcher_instance(agent_card, mock_handler): - return RestDispatcher(agent_card=agent_card, request_handler=mock_handler) +def rest_dispatcher_instance(mock_handler): + return RestDispatcher(request_handler=mock_handler) from starlette.datastructures import Headers @@ -117,13 +119,13 @@ def mark_pkg_starlette_not_installed(self): ) def test_missing_starlette_raises_importerror( - self, mark_pkg_starlette_not_installed, agent_card, mock_handler + self, mark_pkg_starlette_not_installed, mock_handler ): with pytest.raises( ImportError, match='Packages `starlette` and `sse-starlette` are required', ): - RestDispatcher(agent_card=agent_card, request_handler=mock_handler) + RestDispatcher(request_handler=mock_handler) @pytest.mark.asyncio @@ -237,18 +239,6 @@ async def test_delete_push_notification( response = await rest_dispatcher_instance.delete_push_notification(req) assert response.status_code == 200 - async def test_set_push_notification_disabled_raises( - self, agent_card, mock_handler - ): - agent_card.capabilities.push_notifications = False - dispatcher = RestDispatcher( - agent_card=agent_card, 
request_handler=mock_handler - ) - req = make_mock_request(method='POST', path_params={'id': 'task1'}) - - response = await dispatcher.set_push_notification(req) - assert response.status_code == 400 # UnsupportedOperation maps to 400 - async def test_handle_authenticated_agent_card( self, rest_dispatcher_instance ): @@ -258,45 +248,9 @@ async def test_handle_authenticated_agent_card( ) assert response.status_code == 200 - async def test_handle_authenticated_agent_card_unsupported( - self, agent_card, mock_handler - ): - agent_card.capabilities.extended_agent_card = False - dispatcher = RestDispatcher( - agent_card=agent_card, request_handler=mock_handler - ) - req = make_mock_request() - - response = await dispatcher.handle_authenticated_agent_card(req) - assert response.status_code == 400 - @pytest.mark.asyncio class TestRestDispatcherStreaming: - async def test_on_message_send_stream_unsupported( - self, agent_card, mock_handler - ): - agent_card.capabilities.streaming = False - dispatcher = RestDispatcher( - agent_card=agent_card, request_handler=mock_handler - ) - req = make_mock_request(method='POST') - - response = await dispatcher.on_message_send_stream(req) - assert response.status_code == 400 - - async def test_on_subscribe_to_task_unsupported( - self, agent_card, mock_handler - ): - agent_card.capabilities.streaming = False - dispatcher = RestDispatcher( - agent_card=agent_card, request_handler=mock_handler - ) - req = make_mock_request(method='GET', path_params={'id': 't1'}) - - response = await dispatcher.on_subscribe_to_task(req) - assert response.status_code == 400 - async def test_on_message_send_stream_success( self, rest_dispatcher_instance ): @@ -327,3 +281,16 @@ async def test_on_subscribe_to_task_success(self, rest_dispatcher_instance): assert len(chunks) == 2 assert 'chunk1' in str(chunks[0]) assert 'chunk2' in str(chunks[1]) + + async def test_on_message_send_stream_handler_error(self, mock_handler): + from a2a.utils.errors import 
UnsupportedOperationError + + mock_handler.on_message_send_stream.side_effect = ( + UnsupportedOperationError('Mocked error') + ) + + dispatcher = RestDispatcher(request_handler=mock_handler) + req = make_mock_request(method='POST') + + response = await dispatcher.on_message_send_stream(req) + assert response.status_code == 400 diff --git a/tests/server/routes/test_rest_routes.py b/tests/server/routes/test_rest_routes.py index 98bf4130d..2b3477c6b 100644 --- a/tests/server/routes/test_rest_routes.py +++ b/tests/server/routes/test_rest_routes.py @@ -22,26 +22,21 @@ def mock_handler(): def test_routes_creation(agent_card, mock_handler): """Tests that create_rest_routes creates Route objects list.""" - routes = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler - ) + routes = create_rest_routes(request_handler=mock_handler) assert isinstance(routes, list) assert len(routes) > 0 - assert all(isinstance(r, BaseRoute) for r in routes) + assert all((isinstance(r, BaseRoute) for r in routes)) def test_routes_creation_v03_compat(agent_card, mock_handler): """Tests that create_rest_routes creates more routes with enable_v0_3_compat.""" + mock_handler._agent_card = agent_card routes_without_compat = create_rest_routes( - agent_card=agent_card, - request_handler=mock_handler, - enable_v0_3_compat=False, + request_handler=mock_handler, enable_v0_3_compat=False ) routes_with_compat = create_rest_routes( - agent_card=agent_card, - request_handler=mock_handler, - enable_v0_3_compat=True, + request_handler=mock_handler, enable_v0_3_compat=True ) assert len(routes_with_compat) > len(routes_without_compat) @@ -51,9 +46,7 @@ def test_rest_endpoints_routing(agent_card, mock_handler): """Tests that mounted routes route to the handler endpoints.""" mock_handler.on_message_send.return_value = Task(id='123') - routes = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler - ) + routes = create_rest_routes(request_handler=mock_handler) app = 
Starlette(routes=routes) client = TestClient(app) @@ -70,9 +63,7 @@ def test_rest_endpoints_routing_tenant(agent_card, mock_handler): """Tests that mounted routes with {tenant} route to the handler endpoints.""" mock_handler.on_message_send.return_value = Task(id='123') - routes = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler - ) + routes = create_rest_routes(request_handler=mock_handler) app = Starlette(routes=routes) client = TestClient(app) @@ -94,9 +85,7 @@ def test_rest_list_tasks(agent_card, mock_handler): """Tests that list tasks endpoint is routed to the handler.""" mock_handler.on_list_tasks.return_value = ListTasksResponse() - routes = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler - ) + routes = create_rest_routes(request_handler=mock_handler) app = Starlette(routes=routes) client = TestClient(app) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index f879e8078..ddab2661a 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -165,9 +165,7 @@ def build( app_instance.routes.extend(card_routes) # JSON-RPC router - rpc_routes = create_jsonrpc_routes( - self.agent_card, self.handler, rpc_url=rpc_url - ) + rpc_routes = create_jsonrpc_routes(self.handler, rpc_url=rpc_url) app_instance.routes.extend(rpc_routes) return app_instance diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index c157bb986..427e33aff 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -29,7 +29,6 @@ build_text_artifact, canonicalize_agent_card, create_task_obj, - validate, ) @@ -249,27 +248,6 @@ def test_build_text_artifact(): assert artifact.parts[0].text == text -# Test validate decorator -def test_validate_decorator(): - class TestClass: - condition = True - - @validate(lambda self: self.condition, 'Condition not met') - def test_method(self) -> str: - return 'Success' - - obj = TestClass() - - # Test passing condition - 
assert obj.test_method() == 'Success' - - # Test failing condition - obj.condition = False - with pytest.raises(UnsupportedOperationError) as exc_info: - obj.test_method() - assert 'Condition not met' in str(exc_info.value) - - # Tests for are_modalities_compatible def test_are_modalities_compatible_client_none(): assert ( From cc094aa51caba8107b63982e9b79256f7c2d331a Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Wed, 8 Apr 2026 11:05:43 +0200 Subject: [PATCH 08/67] feat: merge metadata of new and old artifact when append=True (#945) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description When a new TaskArtifactUpdateEvent is emitted with append=True, if an artifact with the same id exists on the Task saved on the TaskStore, the metadata from the new artifact are merged with the ones of the existing one. Fixes #735 🦕 --- src/a2a/utils/helpers.py | 3 +++ tests/utils/test_helpers.py | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index ba55da86e..fe69bf26d 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -110,6 +110,9 @@ def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None: task.id, ) existing_artifact.parts.extend(new_artifact_data.parts) + existing_artifact.metadata.update( + dict(new_artifact_data.metadata.items()) + ) else: # We received a chunk to append, but we don't have an existing artifact. 
# we will ignore this chunk diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 427e33aff..d8a85fcd9 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -177,6 +177,7 @@ def test_append_artifact_to_task(): artifact_id='artifact-123', name='updated name', parts=[Part(text='Updated')], + metadata={'existing_key': 'existing_value'}, ) append_event_2 = TaskArtifactUpdateEvent( artifact=artifact_2, append=False, task_id='123', context_id='123' @@ -187,10 +188,13 @@ def test_append_artifact_to_task(): assert task.artifacts[0].name == 'updated name' assert len(task.artifacts[0].parts) == 1 assert task.artifacts[0].parts[0].text == 'Updated' + assert task.artifacts[0].metadata['existing_key'] == 'existing_value' # Test appending parts to an existing artifact artifact_with_parts = Artifact( - artifact_id='artifact-123', parts=[Part(text='Part 2')] + artifact_id='artifact-123', + parts=[Part(text='Part 2')], + metadata={'new_key': 'new_value'}, ) append_event_3 = TaskArtifactUpdateEvent( artifact=artifact_with_parts, @@ -202,6 +206,8 @@ def test_append_artifact_to_task(): assert len(task.artifacts[0].parts) == 2 assert task.artifacts[0].parts[0].text == 'Updated' assert task.artifacts[0].parts[1].text == 'Part 2' + assert task.artifacts[0].metadata['existing_key'] == 'existing_value' + assert task.artifacts[0].metadata['new_key'] == 'new_value' # Test adding another new artifact another_artifact_with_parts = Artifact( From 617fdf3f06f88ddfd187fbc628f3c50458c1b75a Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Wed, 8 Apr 2026 12:29:43 +0200 Subject: [PATCH 09/67] refactor: adapt wrong imports in tck and sample (#948) # Description This PR fixes the wrong imports after the introduction of new default_request_handler_V2 --- samples/hello_world_agent.py | 5 +---- tck/sut_agent.py | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 
909e6550d..8db34dc03 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -12,10 +12,7 @@ from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue -from a2a.server.request_handlers import GrpcHandler -from a2a.server.request_handlers.default_request_handler import ( - DefaultRequestHandler, -) +from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler from a2a.server.routes import ( create_agent_card_routes, create_jsonrpc_routes, diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 96eca850f..0ca3a1450 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -17,9 +17,7 @@ from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue -from a2a.server.request_handlers.default_request_handler import ( - DefaultRequestHandler, -) +from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.request_handlers.grpc_handler import GrpcHandler from a2a.server.routes import ( create_agent_card_routes, From 94537c382be4160332279a44d83254feeb0b8037 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Wed, 8 Apr 2026 12:39:56 +0200 Subject: [PATCH 10/67] fix(client): do not mutate SendMessageRequest in BaseClient.send_message (#949) Updating passed parameter by reference is not great. 
--- src/a2a/client/base_client.py | 27 +++++++++++++++--------- tests/client/test_base_client.py | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 53fd38cdb..342e01f06 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -66,7 +66,7 @@ async def send_message( Yields: An async iterator of `StreamResponse` """ - self._apply_client_config(request) + request = self._apply_client_config(request) if not self._config.streaming or not self._card.capabilities.streaming: response = await self._execute_with_interceptors( input_data=request, @@ -100,22 +100,29 @@ async def send_message( ): yield event - def _apply_client_config(self, request: SendMessageRequest) -> None: - request.configuration.return_immediately |= self._config.polling - if ( - not request.configuration.HasField('task_push_notification_config') - and self._config.push_notification_configs + def _apply_client_config( + self, request: SendMessageRequest + ) -> SendMessageRequest: + modified_request = SendMessageRequest() + modified_request.CopyFrom(request) + if self._config.polling: + modified_request.configuration.return_immediately = True + if self._config.push_notification_configs and ( + not modified_request.configuration.HasField( + 'task_push_notification_config' + ) ): - request.configuration.task_push_notification_config.CopyFrom( + modified_request.configuration.task_push_notification_config.CopyFrom( self._config.push_notification_configs[0] ) if ( - not request.configuration.accepted_output_modes - and self._config.accepted_output_modes + self._config.accepted_output_modes + and not modified_request.configuration.accepted_output_modes ): - request.configuration.accepted_output_modes.extend( + modified_request.configuration.accepted_output_modes.extend( self._config.accepted_output_modes ) + return modified_request async def _process_stream( self, diff 
--git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index ed49469a7..d37e3deb4 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -208,6 +208,41 @@ async def test_send_message_non_streaming_agent_capability_false( response = events[0] assert response.task.id == 'task-789' + @pytest.mark.asyncio + async def test_send_message_does_not_mutate_request( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ): + base_client._config.streaming = False + base_client._config.polling = True + base_client._config.accepted_output_modes = ['application/json'] + base_client._config.push_notification_configs = [ + TaskPushNotificationConfig( + task_id='task-1', + ) + ] + + task = Task( + id='task-no-mutate', + context_id='ctx-no-mutate', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + response = SendMessageResponse() + response.task.CopyFrom(task) + mock_transport.send_message.return_value = response + + request = SendMessageRequest(message=sample_message) + + original = SendMessageRequest() + original.CopyFrom(request) + + events = [event async for event in base_client.send_message(request)] + assert len(events) == 1 + + assert request == original + @pytest.mark.asyncio async def test_send_message_callsite_config_overrides_non_streaming( self, From 546fb868cf18696bef818a2e355d3544745f1ddb Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Wed, 8 Apr 2026 16:27:09 +0200 Subject: [PATCH 11/67] chore: revert #949 (#950) This reverts commit 94537c382be4160332279a44d83254feeb0b8037 (#949). Seems like it breaks ITK tests. 
--- src/a2a/client/base_client.py | 27 +++++++++--------------- tests/client/test_base_client.py | 35 -------------------------------- 2 files changed, 10 insertions(+), 52 deletions(-) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 342e01f06..53fd38cdb 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -66,7 +66,7 @@ async def send_message( Yields: An async iterator of `StreamResponse` """ - request = self._apply_client_config(request) + self._apply_client_config(request) if not self._config.streaming or not self._card.capabilities.streaming: response = await self._execute_with_interceptors( input_data=request, @@ -100,29 +100,22 @@ async def send_message( ): yield event - def _apply_client_config( - self, request: SendMessageRequest - ) -> SendMessageRequest: - modified_request = SendMessageRequest() - modified_request.CopyFrom(request) - if self._config.polling: - modified_request.configuration.return_immediately = True - if self._config.push_notification_configs and ( - not modified_request.configuration.HasField( - 'task_push_notification_config' - ) + def _apply_client_config(self, request: SendMessageRequest) -> None: + request.configuration.return_immediately |= self._config.polling + if ( + not request.configuration.HasField('task_push_notification_config') + and self._config.push_notification_configs ): - modified_request.configuration.task_push_notification_config.CopyFrom( + request.configuration.task_push_notification_config.CopyFrom( self._config.push_notification_configs[0] ) if ( - self._config.accepted_output_modes - and not modified_request.configuration.accepted_output_modes + not request.configuration.accepted_output_modes + and self._config.accepted_output_modes ): - modified_request.configuration.accepted_output_modes.extend( + request.configuration.accepted_output_modes.extend( self._config.accepted_output_modes ) - return modified_request async def _process_stream( self, diff 
--git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index d37e3deb4..ed49469a7 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -208,41 +208,6 @@ async def test_send_message_non_streaming_agent_capability_false( response = events[0] assert response.task.id == 'task-789' - @pytest.mark.asyncio - async def test_send_message_does_not_mutate_request( - self, - base_client: BaseClient, - mock_transport: MagicMock, - sample_message: Message, - ): - base_client._config.streaming = False - base_client._config.polling = True - base_client._config.accepted_output_modes = ['application/json'] - base_client._config.push_notification_configs = [ - TaskPushNotificationConfig( - task_id='task-1', - ) - ] - - task = Task( - id='task-no-mutate', - context_id='ctx-no-mutate', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - response = SendMessageResponse() - response.task.CopyFrom(task) - mock_transport.send_message.return_value = response - - request = SendMessageRequest(message=sample_message) - - original = SendMessageRequest() - original.CopyFrom(request) - - events = [event async for event in base_client.send_message(request)] - assert len(events) == 1 - - assert request == original - @pytest.mark.asyncio async def test_send_message_callsite_config_overrides_non_streaming( self, From 538715a5dec3ca317a2cff39f3e9dcd2ce34c92e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 9 Apr 2026 09:45:37 +0200 Subject: [PATCH 12/67] chore(deps): bump cryptography from 46.0.6 to 46.0.7 (#953) Bumps [cryptography](https://github.com/pyca/cryptography) from 46.0.6 to 46.0.7.
Changelog

Sourced from cryptography's changelog.

46.0.7 - 2026-04-07


* **SECURITY ISSUE**: Fixed an issue where non-contiguous buffers could
be
  passed to APIs that accept Python buffers, which could lead to buffer
  overflow. **CVE-2026-39892**
* Updated Windows, macOS, and Linux wheels to be compiled with OpenSSL
3.5.6.

.. _v46-0-6:

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=cryptography&package-manager=uv&previous-version=46.0.6&new-version=46.0.7)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/a2aproject/a2a-python/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- uv.lock | 102 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/uv.lock b/uv.lock index 85d655891..778082c04 100644 --- a/uv.lock +++ b/uv.lock @@ -621,62 +621,62 @@ toml = [ [[package]] name = "cryptography" -version = "46.0.6" +version = "46.0.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a4/ba/04b1bd4218cbc58dc90ce967106d51582371b898690f3ae0402876cc4f34/cryptography-46.0.6.tar.gz", hash = "sha256:27550628a518c5c6c903d84f637fbecf287f6cb9ced3804838a1295dc1fd0759", size = 750542, upload-time = "2026-03-25T23:34:53.396Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/47/23/9285e15e3bc57325b0a72e592921983a701efc1ee8f91c06c5f0235d86d9/cryptography-46.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:64235194bad039a10bb6d2d930ab3323baaec67e2ce36215fd0952fad0930ca8", size = 7176401, upload-time = "2026-03-25T23:33:22.096Z" }, - { url = "https://files.pythonhosted.org/packages/60/f8/e61f8f13950ab6195b31913b42d39f0f9afc7d93f76710f299b5ec286ae6/cryptography-46.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:26031f1e5ca62fcb9d1fcb34b2b60b390d1aacaa15dc8b895a9ed00968b97b30", size = 4275275, upload-time = "2026-03-25T23:33:23.844Z" }, - { url = "https://files.pythonhosted.org/packages/19/69/732a736d12c2631e140be2348b4ad3d226302df63ef64d30dfdb8db7ad1c/cryptography-46.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9a693028b9cbe51b5a1136232ee8f2bc242e4e19d456ded3fa7c86e43c713b4a", size = 4425320, upload-time = "2026-03-25T23:33:25.703Z" }, - { url = 
"https://files.pythonhosted.org/packages/d4/12/123be7292674abf76b21ac1fc0e1af50661f0e5b8f0ec8285faac18eb99e/cryptography-46.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:67177e8a9f421aa2d3a170c3e56eca4e0128883cf52a071a7cbf53297f18b175", size = 4278082, upload-time = "2026-03-25T23:33:27.423Z" }, - { url = "https://files.pythonhosted.org/packages/5b/ba/d5e27f8d68c24951b0a484924a84c7cdaed7502bac9f18601cd357f8b1d2/cryptography-46.0.6-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:d9528b535a6c4f8ff37847144b8986a9a143585f0540fbcb1a98115b543aa463", size = 4926514, upload-time = "2026-03-25T23:33:29.206Z" }, - { url = "https://files.pythonhosted.org/packages/34/71/1ea5a7352ae516d5512d17babe7e1b87d9db5150b21f794b1377eac1edc0/cryptography-46.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:22259338084d6ae497a19bae5d4c66b7ca1387d3264d1c2c0e72d9e9b6a77b97", size = 4457766, upload-time = "2026-03-25T23:33:30.834Z" }, - { url = "https://files.pythonhosted.org/packages/01/59/562be1e653accee4fdad92c7a2e88fced26b3fdfce144047519bbebc299e/cryptography-46.0.6-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:760997a4b950ff00d418398ad73fbc91aa2894b5c1db7ccb45b4f68b42a63b3c", size = 3986535, upload-time = "2026-03-25T23:33:33.02Z" }, - { url = "https://files.pythonhosted.org/packages/d6/8b/b1ebfeb788bf4624d36e45ed2662b8bd43a05ff62157093c1539c1288a18/cryptography-46.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3dfa6567f2e9e4c5dceb8ccb5a708158a2a871052fa75c8b78cb0977063f1507", size = 4277618, upload-time = "2026-03-25T23:33:34.567Z" }, - { url = "https://files.pythonhosted.org/packages/dd/52/a005f8eabdb28df57c20f84c44d397a755782d6ff6d455f05baa2785bd91/cryptography-46.0.6-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:cdcd3edcbc5d55757e5f5f3d330dd00007ae463a7e7aa5bf132d1f22a4b62b19", size = 4890802, upload-time = "2026-03-25T23:33:37.034Z" }, - { url = 
"https://files.pythonhosted.org/packages/ec/4d/8e7d7245c79c617d08724e2efa397737715ca0ec830ecb3c91e547302555/cryptography-46.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:d4e4aadb7fc1f88687f47ca20bb7227981b03afaae69287029da08096853b738", size = 4457425, upload-time = "2026-03-25T23:33:38.904Z" }, - { url = "https://files.pythonhosted.org/packages/1d/5c/f6c3596a1430cec6f949085f0e1a970638d76f81c3ea56d93d564d04c340/cryptography-46.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2b417edbe8877cda9022dde3a008e2deb50be9c407eef034aeeb3a8b11d9db3c", size = 4405530, upload-time = "2026-03-25T23:33:40.842Z" }, - { url = "https://files.pythonhosted.org/packages/7e/c9/9f9cea13ee2dbde070424e0c4f621c091a91ffcc504ffea5e74f0e1daeff/cryptography-46.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:380343e0653b1c9d7e1f55b52aaa2dbb2fdf2730088d48c43ca1c7c0abb7cc2f", size = 4667896, upload-time = "2026-03-25T23:33:42.781Z" }, - { url = "https://files.pythonhosted.org/packages/ad/b5/1895bc0821226f129bc74d00eccfc6a5969e2028f8617c09790bf89c185e/cryptography-46.0.6-cp311-abi3-win32.whl", hash = "sha256:bcb87663e1f7b075e48c3be3ecb5f0b46c8fc50b50a97cf264e7f60242dca3f2", size = 3026348, upload-time = "2026-03-25T23:33:45.021Z" }, - { url = "https://files.pythonhosted.org/packages/c3/f8/c9bcbf0d3e6ad288b9d9aa0b1dee04b063d19e8c4f871855a03ab3a297ab/cryptography-46.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:6739d56300662c468fddb0e5e291f9b4d084bead381667b9e654c7dd81705124", size = 3483896, upload-time = "2026-03-25T23:33:46.649Z" }, - { url = "https://files.pythonhosted.org/packages/01/41/3a578f7fd5c70611c0aacba52cd13cb364a5dee895a5c1d467208a9380b0/cryptography-46.0.6-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:2ef9e69886cbb137c2aef9772c2e7138dc581fad4fcbcf13cc181eb5a3ab6275", size = 7117147, upload-time = "2026-03-25T23:33:48.249Z" }, - { url = 
"https://files.pythonhosted.org/packages/fa/87/887f35a6fca9dde90cad08e0de0c89263a8e59b2d2ff904fd9fcd8025b6f/cryptography-46.0.6-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7f417f034f91dcec1cb6c5c35b07cdbb2ef262557f701b4ecd803ee8cefed4f4", size = 4266221, upload-time = "2026-03-25T23:33:49.874Z" }, - { url = "https://files.pythonhosted.org/packages/aa/a8/0a90c4f0b0871e0e3d1ed126aed101328a8a57fd9fd17f00fb67e82a51ca/cryptography-46.0.6-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d24c13369e856b94892a89ddf70b332e0b70ad4a5c43cf3e9cb71d6d7ffa1f7b", size = 4408952, upload-time = "2026-03-25T23:33:52.128Z" }, - { url = "https://files.pythonhosted.org/packages/16/0b/b239701eb946523e4e9f329336e4ff32b1247e109cbab32d1a7b61da8ed7/cryptography-46.0.6-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:aad75154a7ac9039936d50cf431719a2f8d4ed3d3c277ac03f3339ded1a5e707", size = 4270141, upload-time = "2026-03-25T23:33:54.11Z" }, - { url = "https://files.pythonhosted.org/packages/0f/a8/976acdd4f0f30df7b25605f4b9d3d89295351665c2091d18224f7ad5cdbf/cryptography-46.0.6-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:3c21d92ed15e9cfc6eb64c1f5a0326db22ca9c2566ca46d845119b45b4400361", size = 4904178, upload-time = "2026-03-25T23:33:55.725Z" }, - { url = "https://files.pythonhosted.org/packages/b1/1b/bf0e01a88efd0e59679b69f42d4afd5bced8700bb5e80617b2d63a3741af/cryptography-46.0.6-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:4668298aef7cddeaf5c6ecc244c2302a2b8e40f384255505c22875eebb47888b", size = 4441812, upload-time = "2026-03-25T23:33:57.364Z" }, - { url = "https://files.pythonhosted.org/packages/bb/8b/11df86de2ea389c65aa1806f331cae145f2ed18011f30234cc10ca253de8/cryptography-46.0.6-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:8ce35b77aaf02f3b59c90b2c8a05c73bac12cea5b4e8f3fbece1f5fddea5f0ca", size = 3963923, upload-time = "2026-03-25T23:33:59.361Z" }, - { url = 
"https://files.pythonhosted.org/packages/91/e0/207fb177c3a9ef6a8108f234208c3e9e76a6aa8cf20d51932916bd43bda0/cryptography-46.0.6-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:c89eb37fae9216985d8734c1afd172ba4927f5a05cfd9bf0e4863c6d5465b013", size = 4269695, upload-time = "2026-03-25T23:34:00.909Z" }, - { url = "https://files.pythonhosted.org/packages/21/5e/19f3260ed1e95bced52ace7501fabcd266df67077eeb382b79c81729d2d3/cryptography-46.0.6-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:ed418c37d095aeddf5336898a132fba01091f0ac5844e3e8018506f014b6d2c4", size = 4869785, upload-time = "2026-03-25T23:34:02.796Z" }, - { url = "https://files.pythonhosted.org/packages/10/38/cd7864d79aa1d92ef6f1a584281433419b955ad5a5ba8d1eb6c872165bcb/cryptography-46.0.6-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:69cf0056d6947edc6e6760e5f17afe4bea06b56a9ac8a06de9d2bd6b532d4f3a", size = 4441404, upload-time = "2026-03-25T23:34:04.35Z" }, - { url = "https://files.pythonhosted.org/packages/09/0a/4fe7a8d25fed74419f91835cf5829ade6408fd1963c9eae9c4bce390ecbb/cryptography-46.0.6-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8e7304c4f4e9490e11efe56af6713983460ee0780f16c63f219984dab3af9d2d", size = 4397549, upload-time = "2026-03-25T23:34:06.342Z" }, - { url = "https://files.pythonhosted.org/packages/5f/a0/7d738944eac6513cd60a8da98b65951f4a3b279b93479a7e8926d9cd730b/cryptography-46.0.6-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:b928a3ca837c77a10e81a814a693f2295200adb3352395fad024559b7be7a736", size = 4651874, upload-time = "2026-03-25T23:34:07.916Z" }, - { url = "https://files.pythonhosted.org/packages/cb/f1/c2326781ca05208845efca38bf714f76939ae446cd492d7613808badedf1/cryptography-46.0.6-cp314-cp314t-win32.whl", hash = "sha256:97c8115b27e19e592a05c45d0dd89c57f81f841cc9880e353e0d3bf25b2139ed", size = 3001511, upload-time = "2026-03-25T23:34:09.892Z" }, - { url = 
"https://files.pythonhosted.org/packages/c9/57/fe4a23eb549ac9d903bd4698ffda13383808ef0876cc912bcb2838799ece/cryptography-46.0.6-cp314-cp314t-win_amd64.whl", hash = "sha256:c797e2517cb7880f8297e2c0f43bb910e91381339336f75d2c1c2cbf811b70b4", size = 3471692, upload-time = "2026-03-25T23:34:11.613Z" }, - { url = "https://files.pythonhosted.org/packages/c4/cc/f330e982852403da79008552de9906804568ae9230da8432f7496ce02b71/cryptography-46.0.6-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:12cae594e9473bca1a7aceb90536060643128bb274fcea0fc459ab90f7d1ae7a", size = 7162776, upload-time = "2026-03-25T23:34:13.308Z" }, - { url = "https://files.pythonhosted.org/packages/49/b3/dc27efd8dcc4bff583b3f01d4a3943cd8b5821777a58b3a6a5f054d61b79/cryptography-46.0.6-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:639301950939d844a9e1c4464d7e07f902fe9a7f6b215bb0d4f28584729935d8", size = 4270529, upload-time = "2026-03-25T23:34:15.019Z" }, - { url = "https://files.pythonhosted.org/packages/e6/05/e8d0e6eb4f0d83365b3cb0e00eb3c484f7348db0266652ccd84632a3d58d/cryptography-46.0.6-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ed3775295fb91f70b4027aeba878d79b3e55c0b3e97eaa4de71f8f23a9f2eb77", size = 4414827, upload-time = "2026-03-25T23:34:16.604Z" }, - { url = "https://files.pythonhosted.org/packages/2f/97/daba0f5d2dc6d855e2dcb70733c812558a7977a55dd4a6722756628c44d1/cryptography-46.0.6-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8927ccfbe967c7df312ade694f987e7e9e22b2425976ddbf28271d7e58845290", size = 4271265, upload-time = "2026-03-25T23:34:18.586Z" }, - { url = "https://files.pythonhosted.org/packages/89/06/fe1fce39a37ac452e58d04b43b0855261dac320a2ebf8f5260dd55b201a9/cryptography-46.0.6-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:b12c6b1e1651e42ab5de8b1e00dc3b6354fdfd778e7fa60541ddacc27cd21410", size = 4916800, upload-time = "2026-03-25T23:34:20.561Z" }, - { url = 
"https://files.pythonhosted.org/packages/ff/8a/b14f3101fe9c3592603339eb5d94046c3ce5f7fc76d6512a2d40efd9724e/cryptography-46.0.6-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:063b67749f338ca9c5a0b7fe438a52c25f9526b851e24e6c9310e7195aad3b4d", size = 4448771, upload-time = "2026-03-25T23:34:22.406Z" }, - { url = "https://files.pythonhosted.org/packages/01/b3/0796998056a66d1973fd52ee89dc1bb3b6581960a91ad4ac705f182d398f/cryptography-46.0.6-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:02fad249cb0e090b574e30b276a3da6a149e04ee2f049725b1f69e7b8351ec70", size = 3978333, upload-time = "2026-03-25T23:34:24.281Z" }, - { url = "https://files.pythonhosted.org/packages/c5/3d/db200af5a4ffd08918cd55c08399dc6c9c50b0bc72c00a3246e099d3a849/cryptography-46.0.6-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:7e6142674f2a9291463e5e150090b95a8519b2fb6e6aaec8917dd8d094ce750d", size = 4271069, upload-time = "2026-03-25T23:34:25.895Z" }, - { url = "https://files.pythonhosted.org/packages/d7/18/61acfd5b414309d74ee838be321c636fe71815436f53c9f0334bf19064fa/cryptography-46.0.6-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:456b3215172aeefb9284550b162801d62f5f264a081049a3e94307fe20792cfa", size = 4878358, upload-time = "2026-03-25T23:34:27.67Z" }, - { url = "https://files.pythonhosted.org/packages/8b/65/5bf43286d566f8171917cae23ac6add941654ccf085d739195a4eacf1674/cryptography-46.0.6-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:341359d6c9e68834e204ceaf25936dffeafea3829ab80e9503860dcc4f4dac58", size = 4448061, upload-time = "2026-03-25T23:34:29.375Z" }, - { url = "https://files.pythonhosted.org/packages/e0/25/7e49c0fa7205cf3597e525d156a6bce5b5c9de1fd7e8cb01120e459f205a/cryptography-46.0.6-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9a9c42a2723999a710445bc0d974e345c32adfd8d2fac6d8a251fa829ad31cfb", size = 4399103, upload-time = "2026-03-25T23:34:32.036Z" }, - { url = 
"https://files.pythonhosted.org/packages/44/46/466269e833f1c4718d6cd496ffe20c56c9c8d013486ff66b4f69c302a68d/cryptography-46.0.6-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6617f67b1606dfd9fe4dbfa354a9508d4a6d37afe30306fe6c101b7ce3274b72", size = 4659255, upload-time = "2026-03-25T23:34:33.679Z" }, - { url = "https://files.pythonhosted.org/packages/0a/09/ddc5f630cc32287d2c953fc5d32705e63ec73e37308e5120955316f53827/cryptography-46.0.6-cp38-abi3-win32.whl", hash = "sha256:7f6690b6c55e9c5332c0b59b9c8a3fb232ebf059094c17f9019a51e9827df91c", size = 3010660, upload-time = "2026-03-25T23:34:35.418Z" }, - { url = "https://files.pythonhosted.org/packages/1b/82/ca4893968aeb2709aacfb57a30dec6fa2ab25b10fa9f064b8882ce33f599/cryptography-46.0.6-cp38-abi3-win_amd64.whl", hash = "sha256:79e865c642cfc5c0b3eb12af83c35c5aeff4fa5c672dc28c43721c2c9fdd2f0f", size = 3471160, upload-time = "2026-03-25T23:34:37.191Z" }, - { url = "https://files.pythonhosted.org/packages/2e/84/7ccff00ced5bac74b775ce0beb7d1be4e8637536b522b5df9b73ada42da2/cryptography-46.0.6-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:2ea0f37e9a9cf0df2952893ad145fd9627d326a59daec9b0802480fa3bcd2ead", size = 3475444, upload-time = "2026-03-25T23:34:38.944Z" }, - { url = "https://files.pythonhosted.org/packages/bc/1f/4c926f50df7749f000f20eede0c896769509895e2648db5da0ed55db711d/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a3e84d5ec9ba01f8fd03802b2147ba77f0c8f2617b2aff254cedd551844209c8", size = 4218227, upload-time = "2026-03-25T23:34:40.871Z" }, - { url = "https://files.pythonhosted.org/packages/c6/65/707be3ffbd5f786028665c3223e86e11c4cda86023adbc56bd72b1b6bab5/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:12f0fa16cc247b13c43d56d7b35287ff1569b5b1f4c5e87e92cc4fcc00cd10c0", size = 4381399, upload-time = "2026-03-25T23:34:42.609Z" }, - { url = 
"https://files.pythonhosted.org/packages/f3/6d/73557ed0ef7d73d04d9aba745d2c8e95218213687ee5e76b7d236a5030fc/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:50575a76e2951fe7dbd1f56d181f8c5ceeeb075e9ff88e7ad997d2f42af06e7b", size = 4217595, upload-time = "2026-03-25T23:34:44.205Z" }, - { url = "https://files.pythonhosted.org/packages/9e/c5/e1594c4eec66a567c3ac4400008108a415808be2ce13dcb9a9045c92f1a0/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:90e5f0a7b3be5f40c3a0a0eafb32c681d8d2c181fc2a1bdabe9b3f611d9f6b1a", size = 4380912, upload-time = "2026-03-25T23:34:46.328Z" }, - { url = "https://files.pythonhosted.org/packages/1a/89/843b53614b47f97fe1abc13f9a86efa5ec9e275292c457af1d4a60dc80e0/cryptography-46.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6728c49e3b2c180ef26f8e9f0a883a2c585638db64cf265b49c9ba10652d430e", size = 3409955, upload-time = "2026-03-25T23:34:48.465Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/47/93/ac8f3d5ff04d54bc814e961a43ae5b0b146154c89c61b47bb07557679b18/cryptography-46.0.7.tar.gz", hash = "sha256:e4cfd68c5f3e0bfdad0d38e023239b96a2fe84146481852dffbcca442c245aa5", size = 750652, upload-time = "2026-04-08T01:57:54.692Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/5d/4a8f770695d73be252331e60e526291e3df0c9b27556a90a6b47bccca4c2/cryptography-46.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:ea42cbe97209df307fdc3b155f1b6fa2577c0defa8f1f7d3be7d31d189108ad4", size = 7179869, upload-time = "2026-04-08T01:56:17.157Z" }, + { url = "https://files.pythonhosted.org/packages/5f/45/6d80dc379b0bbc1f9d1e429f42e4cb9e1d319c7a8201beffd967c516ea01/cryptography-46.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b36a4695e29fe69215d75960b22577197aca3f7a25b9cf9d165dcfe9d80bc325", size = 4275492, upload-time = "2026-04-08T01:56:19.36Z" }, + { url = 
"https://files.pythonhosted.org/packages/4a/9a/1765afe9f572e239c3469f2cb429f3ba7b31878c893b246b4b2994ffe2fe/cryptography-46.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ad9ef796328c5e3c4ceed237a183f5d41d21150f972455a9d926593a1dcb308", size = 4426670, upload-time = "2026-04-08T01:56:21.415Z" }, + { url = "https://files.pythonhosted.org/packages/8f/3e/af9246aaf23cd4ee060699adab1e47ced3f5f7e7a8ffdd339f817b446462/cryptography-46.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:73510b83623e080a2c35c62c15298096e2a5dc8d51c3b4e1740211839d0dea77", size = 4280275, upload-time = "2026-04-08T01:56:23.539Z" }, + { url = "https://files.pythonhosted.org/packages/0f/54/6bbbfc5efe86f9d71041827b793c24811a017c6ac0fd12883e4caa86b8ed/cryptography-46.0.7-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cbd5fb06b62bd0721e1170273d3f4d5a277044c47ca27ee257025146c34cbdd1", size = 4928402, upload-time = "2026-04-08T01:56:25.624Z" }, + { url = "https://files.pythonhosted.org/packages/2d/cf/054b9d8220f81509939599c8bdbc0c408dbd2bdd41688616a20731371fe0/cryptography-46.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:420b1e4109cc95f0e5700eed79908cef9268265c773d3a66f7af1eef53d409ef", size = 4459985, upload-time = "2026-04-08T01:56:27.309Z" }, + { url = "https://files.pythonhosted.org/packages/f9/46/4e4e9c6040fb01c7467d47217d2f882daddeb8828f7df800cb806d8a2288/cryptography-46.0.7-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:24402210aa54baae71d99441d15bb5a1919c195398a87b563df84468160a65de", size = 3990652, upload-time = "2026-04-08T01:56:29.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/5f/313586c3be5a2fbe87e4c9a254207b860155a8e1f3cca99f9910008e7d08/cryptography-46.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8a469028a86f12eb7d2fe97162d0634026d92a21f3ae0ac87ed1c4a447886c83", size = 4279805, upload-time = "2026-04-08T01:56:30.928Z" }, + { url = 
"https://files.pythonhosted.org/packages/69/33/60dfc4595f334a2082749673386a4d05e4f0cf4df8248e63b2c3437585f2/cryptography-46.0.7-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9694078c5d44c157ef3162e3bf3946510b857df5a3955458381d1c7cfc143ddb", size = 4892883, upload-time = "2026-04-08T01:56:32.614Z" }, + { url = "https://files.pythonhosted.org/packages/c7/0b/333ddab4270c4f5b972f980adef4faa66951a4aaf646ca067af597f15563/cryptography-46.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:42a1e5f98abb6391717978baf9f90dc28a743b7d9be7f0751a6f56a75d14065b", size = 4459756, upload-time = "2026-04-08T01:56:34.306Z" }, + { url = "https://files.pythonhosted.org/packages/d2/14/633913398b43b75f1234834170947957c6b623d1701ffc7a9600da907e89/cryptography-46.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91bbcb08347344f810cbe49065914fe048949648f6bd5c2519f34619142bbe85", size = 4410244, upload-time = "2026-04-08T01:56:35.977Z" }, + { url = "https://files.pythonhosted.org/packages/10/f2/19ceb3b3dc14009373432af0c13f46aa08e3ce334ec6eff13492e1812ccd/cryptography-46.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5d1c02a14ceb9148cc7816249f64f623fbfee39e8c03b3650d842ad3f34d637e", size = 4674868, upload-time = "2026-04-08T01:56:38.034Z" }, + { url = "https://files.pythonhosted.org/packages/1a/bb/a5c213c19ee94b15dfccc48f363738633a493812687f5567addbcbba9f6f/cryptography-46.0.7-cp311-abi3-win32.whl", hash = "sha256:d23c8ca48e44ee015cd0a54aeccdf9f09004eba9fc96f38c911011d9ff1bd457", size = 3026504, upload-time = "2026-04-08T01:56:39.666Z" }, + { url = "https://files.pythonhosted.org/packages/2b/02/7788f9fefa1d060ca68717c3901ae7fffa21ee087a90b7f23c7a603c32ae/cryptography-46.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:397655da831414d165029da9bc483bed2fe0e75dde6a1523ec2fe63f3c46046b", size = 3488363, upload-time = "2026-04-08T01:56:41.893Z" }, + { url = 
"https://files.pythonhosted.org/packages/7b/56/15619b210e689c5403bb0540e4cb7dbf11a6bf42e483b7644e471a2812b3/cryptography-46.0.7-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:d151173275e1728cf7839aaa80c34fe550c04ddb27b34f48c232193df8db5842", size = 7119671, upload-time = "2026-04-08T01:56:44Z" }, + { url = "https://files.pythonhosted.org/packages/74/66/e3ce040721b0b5599e175ba91ab08884c75928fbeb74597dd10ef13505d2/cryptography-46.0.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:db0f493b9181c7820c8134437eb8b0b4792085d37dbb24da050476ccb664e59c", size = 4268551, upload-time = "2026-04-08T01:56:46.071Z" }, + { url = "https://files.pythonhosted.org/packages/03/11/5e395f961d6868269835dee1bafec6a1ac176505a167f68b7d8818431068/cryptography-46.0.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ebd6daf519b9f189f85c479427bbd6e9c9037862cf8fe89ee35503bd209ed902", size = 4408887, upload-time = "2026-04-08T01:56:47.718Z" }, + { url = "https://files.pythonhosted.org/packages/40/53/8ed1cf4c3b9c8e611e7122fb56f1c32d09e1fff0f1d77e78d9ff7c82653e/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:b7b412817be92117ec5ed95f880defe9cf18a832e8cafacf0a22337dc1981b4d", size = 4271354, upload-time = "2026-04-08T01:56:49.312Z" }, + { url = "https://files.pythonhosted.org/packages/50/46/cf71e26025c2e767c5609162c866a78e8a2915bbcfa408b7ca495c6140c4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:fbfd0e5f273877695cb93baf14b185f4878128b250cc9f8e617ea0c025dfb022", size = 4905845, upload-time = "2026-04-08T01:56:50.916Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ea/01276740375bac6249d0a971ebdf6b4dc9ead0ee0a34ef3b5a88c1a9b0d4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ffca7aa1d00cf7d6469b988c581598f2259e46215e0140af408966a24cf086ce", size = 4444641, upload-time = "2026-04-08T01:56:52.882Z" }, + { url = 
"https://files.pythonhosted.org/packages/3d/4c/7d258f169ae71230f25d9f3d06caabcff8c3baf0978e2b7d65e0acac3827/cryptography-46.0.7-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:60627cf07e0d9274338521205899337c5d18249db56865f943cbe753aa96f40f", size = 3967749, upload-time = "2026-04-08T01:56:54.597Z" }, + { url = "https://files.pythonhosted.org/packages/b5/2a/2ea0767cad19e71b3530e4cad9605d0b5e338b6a1e72c37c9c1ceb86c333/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:80406c3065e2c55d7f49a9550fe0c49b3f12e5bfff5dedb727e319e1afb9bf99", size = 4270942, upload-time = "2026-04-08T01:56:56.416Z" }, + { url = "https://files.pythonhosted.org/packages/41/3d/fe14df95a83319af25717677e956567a105bb6ab25641acaa093db79975d/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:c5b1ccd1239f48b7151a65bc6dd54bcfcc15e028c8ac126d3fada09db0e07ef1", size = 4871079, upload-time = "2026-04-08T01:56:58.31Z" }, + { url = "https://files.pythonhosted.org/packages/9c/59/4a479e0f36f8f378d397f4eab4c850b4ffb79a2f0d58704b8fa0703ddc11/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:d5f7520159cd9c2154eb61eb67548ca05c5774d39e9c2c4339fd793fe7d097b2", size = 4443999, upload-time = "2026-04-08T01:57:00.508Z" }, + { url = "https://files.pythonhosted.org/packages/28/17/b59a741645822ec6d04732b43c5d35e4ef58be7bfa84a81e5ae6f05a1d33/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fcd8eac50d9138c1d7fc53a653ba60a2bee81a505f9f8850b6b2888555a45d0e", size = 4399191, upload-time = "2026-04-08T01:57:02.654Z" }, + { url = "https://files.pythonhosted.org/packages/59/6a/bb2e166d6d0e0955f1e9ff70f10ec4b2824c9cfcdb4da772c7dd69cc7d80/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:65814c60f8cc400c63131584e3e1fad01235edba2614b61fbfbfa954082db0ee", size = 4655782, upload-time = "2026-04-08T01:57:04.592Z" }, + { url = 
"https://files.pythonhosted.org/packages/95/b6/3da51d48415bcb63b00dc17c2eff3a651b7c4fed484308d0f19b30e8cb2c/cryptography-46.0.7-cp314-cp314t-win32.whl", hash = "sha256:fdd1736fed309b4300346f88f74cd120c27c56852c3838cab416e7a166f67298", size = 3002227, upload-time = "2026-04-08T01:57:06.91Z" }, + { url = "https://files.pythonhosted.org/packages/32/a8/9f0e4ed57ec9cebe506e58db11ae472972ecb0c659e4d52bbaee80ca340a/cryptography-46.0.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e06acf3c99be55aa3b516397fe42f5855597f430add9c17fa46bf2e0fb34c9bb", size = 3475332, upload-time = "2026-04-08T01:57:08.807Z" }, + { url = "https://files.pythonhosted.org/packages/a7/7f/cd42fc3614386bc0c12f0cb3c4ae1fc2bbca5c9662dfed031514911d513d/cryptography-46.0.7-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:462ad5cb1c148a22b2e3bcc5ad52504dff325d17daf5df8d88c17dda1f75f2a4", size = 7165618, upload-time = "2026-04-08T01:57:10.645Z" }, + { url = "https://files.pythonhosted.org/packages/a5/d0/36a49f0262d2319139d2829f773f1b97ef8aef7f97e6e5bd21455e5a8fb5/cryptography-46.0.7-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:84d4cced91f0f159a7ddacad249cc077e63195c36aac40b4150e7a57e84fffe7", size = 4270628, upload-time = "2026-04-08T01:57:12.885Z" }, + { url = "https://files.pythonhosted.org/packages/8a/6c/1a42450f464dda6ffbe578a911f773e54dd48c10f9895a23a7e88b3e7db5/cryptography-46.0.7-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:128c5edfe5e5938b86b03941e94fac9ee793a94452ad1365c9fc3f4f62216832", size = 4415405, upload-time = "2026-04-08T01:57:14.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/4ed714dbe93a066dc1f4b4581a464d2d7dbec9046f7c8b7016f5286329e2/cryptography-46.0.7-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5e51be372b26ef4ba3de3c167cd3d1022934bc838ae9eaad7e644986d2a3d163", size = 4272715, upload-time = "2026-04-08T01:57:16.638Z" }, + { url = 
"https://files.pythonhosted.org/packages/b7/e6/a26b84096eddd51494bba19111f8fffe976f6a09f132706f8f1bf03f51f7/cryptography-46.0.7-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cdf1a610ef82abb396451862739e3fc93b071c844399e15b90726ef7470eeaf2", size = 4918400, upload-time = "2026-04-08T01:57:19.021Z" }, + { url = "https://files.pythonhosted.org/packages/c7/08/ffd537b605568a148543ac3c2b239708ae0bd635064bab41359252ef88ed/cryptography-46.0.7-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1d25aee46d0c6f1a501adcddb2d2fee4b979381346a78558ed13e50aa8a59067", size = 4450634, upload-time = "2026-04-08T01:57:21.185Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/0cd51dd86ab5b9befe0d031e276510491976c3a80e9f6e31810cce46c4ad/cryptography-46.0.7-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:cdfbe22376065ffcf8be74dc9a909f032df19bc58a699456a21712d6e5eabfd0", size = 3985233, upload-time = "2026-04-08T01:57:22.862Z" }, + { url = "https://files.pythonhosted.org/packages/92/49/819d6ed3a7d9349c2939f81b500a738cb733ab62fbecdbc1e38e83d45e12/cryptography-46.0.7-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:abad9dac36cbf55de6eb49badd4016806b3165d396f64925bf2999bcb67837ba", size = 4271955, upload-time = "2026-04-08T01:57:24.814Z" }, + { url = "https://files.pythonhosted.org/packages/80/07/ad9b3c56ebb95ed2473d46df0847357e01583f4c52a85754d1a55e29e4d0/cryptography-46.0.7-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:935ce7e3cfdb53e3536119a542b839bb94ec1ad081013e9ab9b7cfd478b05006", size = 4879888, upload-time = "2026-04-08T01:57:26.88Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c7/201d3d58f30c4c2bdbe9b03844c291feb77c20511cc3586daf7edc12a47b/cryptography-46.0.7-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:35719dc79d4730d30f1c2b6474bd6acda36ae2dfae1e3c16f2051f215df33ce0", size = 4449961, upload-time = "2026-04-08T01:57:29.068Z" }, + { url = 
"https://files.pythonhosted.org/packages/a5/ef/649750cbf96f3033c3c976e112265c33906f8e462291a33d77f90356548c/cryptography-46.0.7-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bbc6ccf49d05ac8f7d7b5e2e2c33830d4fe2061def88210a126d130d7f71a85", size = 4401696, upload-time = "2026-04-08T01:57:31.029Z" }, + { url = "https://files.pythonhosted.org/packages/41/52/a8908dcb1a389a459a29008c29966c1d552588d4ae6d43f3a1a4512e0ebe/cryptography-46.0.7-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a1529d614f44b863a7b480c6d000fe93b59acee9c82ffa027cfadc77521a9f5e", size = 4664256, upload-time = "2026-04-08T01:57:33.144Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fa/f0ab06238e899cc3fb332623f337a7364f36f4bb3f2534c2bb95a35b132c/cryptography-46.0.7-cp38-abi3-win32.whl", hash = "sha256:f247c8c1a1fb45e12586afbb436ef21ff1e80670b2861a90353d9b025583d246", size = 3013001, upload-time = "2026-04-08T01:57:34.933Z" }, + { url = "https://files.pythonhosted.org/packages/d2/f1/00ce3bde3ca542d1acd8f8cfa38e446840945aa6363f9b74746394b14127/cryptography-46.0.7-cp38-abi3-win_amd64.whl", hash = "sha256:506c4ff91eff4f82bdac7633318a526b1d1309fc07ca76a3ad182cb5b686d6d3", size = 3472985, upload-time = "2026-04-08T01:57:36.714Z" }, + { url = "https://files.pythonhosted.org/packages/63/0c/dca8abb64e7ca4f6b2978769f6fea5ad06686a190cec381f0a796fdcaaba/cryptography-46.0.7-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:fc9ab8856ae6cf7c9358430e49b368f3108f050031442eaeb6b9d87e4dcf4e4f", size = 3476879, upload-time = "2026-04-08T01:57:38.664Z" }, + { url = "https://files.pythonhosted.org/packages/3a/ea/075aac6a84b7c271578d81a2f9968acb6e273002408729f2ddff517fed4a/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d3b99c535a9de0adced13d159c5a9cf65c325601aa30f4be08afd680643e9c15", size = 4219700, upload-time = "2026-04-08T01:57:40.625Z" }, + { url = 
"https://files.pythonhosted.org/packages/6c/7b/1c55db7242b5e5612b29fc7a630e91ee7a6e3c8e7bf5406d22e206875fbd/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d02c738dacda7dc2a74d1b2b3177042009d5cab7c7079db74afc19e56ca1b455", size = 4385982, upload-time = "2026-04-08T01:57:42.725Z" }, + { url = "https://files.pythonhosted.org/packages/cb/da/9870eec4b69c63ef5925bf7d8342b7e13bc2ee3d47791461c4e49ca212f4/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:04959522f938493042d595a736e7dbdff6eb6cc2339c11465b3ff89343b65f65", size = 4219115, upload-time = "2026-04-08T01:57:44.939Z" }, + { url = "https://files.pythonhosted.org/packages/f4/72/05aa5832b82dd341969e9a734d1812a6aadb088d9eb6f0430fc337cc5a8f/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3986ac1dee6def53797289999eabe84798ad7817f3e97779b5061a95b0ee4968", size = 4385479, upload-time = "2026-04-08T01:57:46.86Z" }, + { url = "https://files.pythonhosted.org/packages/20/2a/1b016902351a523aa2bd446b50a5bc1175d7a7d1cf90fe2ef904f9b84ebc/cryptography-46.0.7-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:258514877e15963bd43b558917bc9f54cf7cf866c38aa576ebf47a77ddbc43a4", size = 3412829, upload-time = "2026-04-08T01:57:48.874Z" }, ] [[package]] From 01b3b2c0e196b0aab4f1f0dc22a95c09c7ee914d Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 9 Apr 2026 09:57:46 +0200 Subject: [PATCH 13/67] refactor(client)!: reorganize ClientFactory API (#947) #### Replace `connect` class method with `create_from_url` instance method 1. `connect` implies some persistent connection, in fact the only difference with `create` is I/O during agent card resolution (also optional, as it accepted both URL or agent card itself). 2. Contained logic which was useful for a pre-configured factory instance (like agent card resolution). 3. It's a separate `async` method and `create` is kept without I/O. 
#### Added a utility `create_client` module function One-line entry point similar to the former `connect` to simplify migration, but doesn't contain any domain logic and just does dispatching between URL and agent card. --- itk/main.py | 7 +- samples/cli.py | 4 +- src/a2a/client/__init__.py | 7 +- src/a2a/client/client_factory.py | 161 ++++++++++-------- tests/client/test_client_factory.py | 146 ++++++++++++---- .../cross_version/client_server/client_1_0.py | 6 +- 6 files changed, 214 insertions(+), 117 deletions(-) diff --git a/itk/main.py b/itk/main.py index 22cfef2a4..7be7a5a20 100644 --- a/itk/main.py +++ b/itk/main.py @@ -12,7 +12,7 @@ from pyproto import instruction_pb2 -from a2a.client import ClientConfig, ClientFactory +from a2a.client import ClientConfig, create_client from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -128,10 +128,7 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: ) try: - client = await ClientFactory.connect( - call.agent_card_uri, - client_config=config, - ) + client = await create_client(call.agent_card_uri, client_config=config) # Wrap nested instruction async with client: diff --git a/samples/cli.py b/samples/cli.py index 6a4597fa9..8515fd5a9 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -9,7 +9,7 @@ import grpc import httpx -from a2a.client import A2ACardResolver, ClientConfig, ClientFactory +from a2a.client import A2ACardResolver, ClientConfig, create_client from a2a.types import Message, Part, Role, SendMessageRequest, TaskState @@ -79,7 +79,7 @@ async def main() -> None: print('\n✓ Agent Card Found:') print(f' Name: {card.name}') - client = await ClientFactory.connect(card, client_config=config) + client = await create_client(card, client_config=config) actual_transport = getattr(client, '_transport', client) print(f' Picked Transport: 
{actual_transport.__class__.__name__}') diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 188ab4c80..c23041f32 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -12,7 +12,11 @@ ClientCallContext, ClientConfig, ) -from a2a.client.client_factory import ClientFactory, minimal_agent_card +from a2a.client.client_factory import ( + ClientFactory, + create_client, + minimal_agent_card, +) from a2a.client.errors import ( A2AClientError, A2AClientTimeoutError, @@ -36,6 +40,7 @@ 'ClientFactory', 'CredentialService', 'InMemoryContextCredentialStore', + 'create_client', 'create_text_message_object', 'minimal_agent_card', ] diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index c5d5e8aa4..a59189ade 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -3,7 +3,7 @@ import logging from collections.abc import Callable -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import httpx @@ -56,32 +56,35 @@ class ClientFactory: - """ClientFactory is used to generate the appropriate client for the agent. + """Factory for creating clients that communicate with A2A agents. - The factory is configured with a `ClientConfig` and optionally a list of - `Consumer`s to use for all generated `Client`s. The expected use is: - - .. code-block:: python + The factory is configured with a `ClientConfig` and optionally custom + transport producers registered via `register`. 
Example usage: factory = ClientFactory(config) - # Optionally register custom client implementations - factory.register('my_customer_transport', NewCustomTransportClient) - # Then with an agent card make a client with additional interceptors + # Optionally register custom transport implementations + factory.register('my_custom_transport', custom_transport_producer) + # Create a client from an AgentCard client = factory.create(card, interceptors) + # Or resolve an AgentCard from a URL and create a client + client = await factory.create_from_url('https://example.com') - Now the client can be used consistently regardless of the transport. This + The client can be used consistently regardless of the transport. This aligns the client configuration with the server's capabilities. """ def __init__( self, - config: ClientConfig, + config: ClientConfig | None = None, ): - client = config.httpx_client or httpx.AsyncClient() - client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT) - config.httpx_client = client + config = config or ClientConfig() + httpx_client = config.httpx_client or httpx.AsyncClient() + httpx_client.headers.setdefault( + VERSION_HEADER, PROTOCOL_VERSION_CURRENT + ) self._config = config + self._httpx_client = httpx_client self._registry: dict[str, TransportProducer] = {} self._register_defaults(config.supported_protocol_bindings) @@ -112,13 +115,13 @@ def jsonrpc_transport_producer( ) return CompatJsonRpcTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) return JsonRpcTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) @@ -151,13 +154,13 @@ def rest_transport_producer( ) return CompatRestTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) return RestTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) @@ -252,73 +255,45 @@ def _find_best_interface( return best_gt_1_0 or 
best_ge_0_3 or best_no_version - @classmethod - async def connect( # noqa: PLR0913 - cls, - agent: str | AgentCard, - client_config: ClientConfig | None = None, + async def create_from_url( + self, + url: str, interceptors: list[ClientCallInterceptor] | None = None, relative_card_path: str | None = None, resolver_http_kwargs: dict[str, Any] | None = None, - extra_transports: dict[str, TransportProducer] | None = None, signature_verifier: Callable[[AgentCard], None] | None = None, ) -> Client: - """Convenience method for constructing a client. - - Constructs a client that connects to the specified agent. Note that - creating multiple clients via this method is less efficient than - constructing an instance of ClientFactory and reusing that. - - .. code-block:: python + """Create a `Client` by resolving an `AgentCard` from a URL. - # This will search for an AgentCard at /.well-known/agent-card.json - my_agent_url = 'https://travel.agents.example.com' - client = await ClientFactory.connect(my_agent_url) + Resolves the agent card from the given URL using the factory's + configured httpx client, then creates a client via `create`. + If the agent card is already available, use `create` directly + instead. Args: - agent: The base URL of the agent, or the AgentCard to connect to. - client_config: The ClientConfig to use when connecting to the agent. - - interceptors: A list of interceptors to use for each request. These - are used for things like attaching credentials or http headers - to all outbound requests. - relative_card_path: If the agent field is a URL, this value is used as - the relative path when resolving the agent card. See - A2AAgentCardResolver.get_agent_card for more details. - resolver_http_kwargs: Dictionary of arguments to provide to the httpx - client when resolving the agent card. This value is provided to - A2AAgentCardResolver.get_agent_card as the http_kwargs parameter. 
- extra_transports: Additional transport protocols to enable when - constructing the client. - signature_verifier: A callable used to verify the agent card's signatures. + url: The base URL of the agent. The agent card will be fetched + from `/.well-known/agent-card.json` by default. + interceptors: A list of interceptors to use for each request. + These are used for things like attaching credentials or http + headers to all outbound requests. + relative_card_path: The relative path when resolving the agent + card. See `A2ACardResolver.get_agent_card` for details. + resolver_http_kwargs: Dictionary of arguments to provide to the + httpx client when resolving the agent card. + signature_verifier: A callable used to verify the agent card's + signatures. Returns: A `Client` object. """ - client_config = client_config or ClientConfig() - if isinstance(agent, str): - if not client_config.httpx_client: - async with httpx.AsyncClient() as client: - resolver = A2ACardResolver(client, agent) - card = await resolver.get_agent_card( - relative_card_path=relative_card_path, - http_kwargs=resolver_http_kwargs, - signature_verifier=signature_verifier, - ) - else: - resolver = A2ACardResolver(client_config.httpx_client, agent) - card = await resolver.get_agent_card( - relative_card_path=relative_card_path, - http_kwargs=resolver_http_kwargs, - signature_verifier=signature_verifier, - ) - else: - card = agent - factory = cls(client_config) - for label, generator in (extra_transports or {}).items(): - factory.register(label, generator) - return factory.create(card, interceptors) + resolver = A2ACardResolver(self._httpx_client, url) + card = await resolver.get_agent_card( + relative_card_path=relative_card_path, + http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, + ) + return self.create(card, interceptors) def register(self, label: str, generator: TransportProducer) -> None: """Register a new transport producer for a given transport label.""" @@ -389,6 
+364,48 @@ def create( ) +async def create_client( # noqa: PLR0913 + agent: str | AgentCard, + client_config: ClientConfig | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + relative_card_path: str | None = None, + resolver_http_kwargs: dict[str, Any] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, +) -> Client: + """Create a `Client` for an agent from a URL or `AgentCard`. + + Convenience function that constructs a `ClientFactory` internally. + For reusing a factory across multiple agents or registering custom + transports, use `ClientFactory` directly instead. + + Args: + agent: The base URL of the agent, or an `AgentCard` to use + directly. + client_config: Optional `ClientConfig`. A default config is + created if not provided. + interceptors: A list of interceptors to use for each request. + relative_card_path: The relative path when resolving the agent + card. Only used when `agent` is a URL. + resolver_http_kwargs: Dictionary of arguments to provide to the + httpx client when resolving the agent card. + signature_verifier: A callable used to verify the agent card's + signatures. + + Returns: + A `Client` object. 
+ """ + factory = ClientFactory(client_config) + if isinstance(agent, str): + return await factory.create_from_url( + agent, + interceptors=interceptors, + relative_card_path=relative_card_path, + resolver_http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, + ) + return factory.create(agent, interceptors) + + def minimal_agent_card( url: str, transports: list[str] | None = None ) -> AgentCard: diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index a5366e0d3..b30d57d12 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -1,18 +1,16 @@ """Tests for the ClientFactory.""" -from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import typing import httpx import pytest -from a2a.client import ClientConfig, ClientFactory +from a2a.client import ClientConfig, ClientFactory, create_client from a2a.client.client_factory import TransportProducer from a2a.client.transports import ( JsonRpcTransport, RestTransport, - ClientTransport, ) from a2a.client.transports.tenant_decorator import TenantTransportDecorator from a2a.types.a2a_pb2 import ( @@ -127,26 +125,27 @@ def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): factory.create(base_agent_card) -@pytest.mark.asyncio -async def test_client_factory_connect_with_agent_card( +def test_client_factory_create_with_default_config( base_agent_card: AgentCard, ): - """Verify that connect works correctly when provided with an AgentCard.""" - client = await ClientFactory.connect(base_agent_card) + """Verify that create works correctly with a default ClientConfig.""" + factory = ClientFactory() + client = factory.create(base_agent_card) assert isinstance(client._transport, JsonRpcTransport) # type: ignore[attr-defined] assert client._transport.url == 'http://primary-url.com' # type: ignore[attr-defined] @pytest.mark.asyncio -async def 
test_client_factory_connect_with_url(base_agent_card: AgentCard): - """Verify that connect works correctly when provided with a URL.""" +async def test_client_factory_create_from_url(base_agent_card: AgentCard): + """Verify that create_from_url resolves the card and creates a client.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card ) agent_url = 'http://example.com' - client = await ClientFactory.connect(agent_url) + factory = ClientFactory() + client = await factory.create_from_url(agent_url) mock_resolver.assert_called_once() assert mock_resolver.call_args[0][1] == agent_url @@ -157,10 +156,10 @@ async def test_client_factory_connect_with_url(base_agent_card: AgentCard): @pytest.mark.asyncio -async def test_client_factory_connect_with_url_and_client_config( +async def test_client_factory_create_from_url_uses_factory_httpx_client( base_agent_card: AgentCard, ): - """Verify connect with a URL and a pre-configured httpx client.""" + """Verify create_from_url uses the factory's configured httpx client.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card @@ -170,7 +169,8 @@ async def test_client_factory_connect_with_url_and_client_config( mock_httpx_client = httpx.AsyncClient() config = ClientConfig(httpx_client=mock_httpx_client) - client = await ClientFactory.connect(agent_url, client_config=config) + factory = ClientFactory(config) + client = await factory.create_from_url(agent_url) mock_resolver.assert_called_once_with(mock_httpx_client, agent_url) mock_resolver.return_value.get_agent_card.assert_awaited_once() @@ -180,10 +180,10 @@ async def test_client_factory_connect_with_url_and_client_config( @pytest.mark.asyncio -async def test_client_factory_connect_with_resolver_args( +async def test_client_factory_create_from_url_passes_resolver_args( 
base_agent_card: AgentCard, ): - """Verify connect passes resolver arguments correctly.""" + """Verify create_from_url passes resolver arguments correctly.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card @@ -193,12 +193,11 @@ async def test_client_factory_connect_with_resolver_args( relative_path = '/extendedAgentCard' http_kwargs = {'headers': {'X-Test': 'true'}} - # The resolver args are only passed if an httpx_client is provided in config config = ClientConfig(httpx_client=httpx.AsyncClient()) + factory = ClientFactory(config) - await ClientFactory.connect( + await factory.create_from_url( agent_url, - client_config=config, relative_card_path=relative_path, resolver_http_kwargs=http_kwargs, ) @@ -211,10 +210,10 @@ async def test_client_factory_connect_with_resolver_args( @pytest.mark.asyncio -async def test_client_factory_connect_resolver_args_without_client( +async def test_client_factory_create_from_url_with_default_config( base_agent_card: AgentCard, ): - """Verify resolver args are ignored if no httpx_client is provided.""" + """Verify create_from_url works with a default ClientConfig.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card @@ -224,12 +223,16 @@ async def test_client_factory_connect_resolver_args_without_client( relative_path = '/extendedAgentCard' http_kwargs = {'headers': {'X-Test': 'true'}} - await ClientFactory.connect( + factory = ClientFactory() + + await factory.create_from_url( agent_url, relative_card_path=relative_path, resolver_http_kwargs=http_kwargs, ) + # Factory always creates an httpx client, so resolver gets it + mock_resolver.assert_called_once() mock_resolver.return_value.get_agent_card.assert_awaited_once_with( relative_card_path=relative_path, http_kwargs=http_kwargs, @@ -237,16 +240,17 @@ async def 
test_client_factory_connect_resolver_args_without_client( ) -@pytest.mark.asyncio -async def test_client_factory_connect_with_extra_transports( +def test_client_factory_register_and_create_custom_transport( base_agent_card: AgentCard, ): - """Verify that connect can register and use extra transports.""" + """Verify that register() + create() uses custom transports.""" class CustomTransport: pass - def custom_transport_producer(*args, **kwargs): + def custom_transport_producer( + *args: typing.Any, **kwargs: typing.Any + ) -> CustomTransport: return CustomTransport() base_agent_card.supported_interfaces.insert( @@ -255,27 +259,60 @@ def custom_transport_producer(*args, **kwargs): ) config = ClientConfig(supported_protocol_bindings=['custom']) - - client = await ClientFactory.connect( - base_agent_card, - client_config=config, - extra_transports=typing.cast( - dict[str, TransportProducer], {'custom': custom_transport_producer} - ), + factory = ClientFactory(config) + factory.register( + 'custom', + typing.cast(TransportProducer, custom_transport_producer), ) + client = factory.create(base_agent_card) assert isinstance(client._transport, CustomTransport) # type: ignore[attr-defined] @pytest.mark.asyncio -async def test_client_factory_connect_with_interceptors( +async def test_client_factory_create_from_url_uses_registered_transports( + base_agent_card: AgentCard, +): + """Verify that create_from_url() respects custom transports from register().""" + + class CustomTransport: + pass + + def custom_transport_producer( + *args: typing.Any, **kwargs: typing.Any + ) -> CustomTransport: + return CustomTransport() + + base_agent_card.supported_interfaces.insert( + 0, + AgentInterface(protocol_binding='custom', url='custom://foo'), + ) + + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: + mock_resolver.return_value.get_agent_card = AsyncMock( + return_value=base_agent_card + ) + + config = ClientConfig(supported_protocol_bindings=['custom']) + factory 
= ClientFactory(config) + factory.register( + 'custom', + typing.cast(TransportProducer, custom_transport_producer), + ) + + client = await factory.create_from_url('http://example.com') + assert isinstance(client._transport, CustomTransport) # type: ignore[attr-defined] + + +def test_client_factory_create_with_interceptors( base_agent_card: AgentCard, ): """Verify interceptors are passed through correctly.""" interceptor1 = MagicMock() with patch('a2a.client.client_factory.BaseClient') as mock_base_client: - await ClientFactory.connect( + factory = ClientFactory() + factory.create( base_agent_card, interceptors=[interceptor1], ) @@ -298,3 +335,44 @@ def test_client_factory_applies_tenant_decorator(base_agent_card: AgentCard): assert isinstance(client._transport, TenantTransportDecorator) # type: ignore[attr-defined] assert client._transport._tenant == 'my-tenant' # type: ignore[attr-defined] assert isinstance(client._transport._base, JsonRpcTransport) # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_create_client_with_agent_card(base_agent_card: AgentCard): + """Verify create_client works when given an AgentCard directly.""" + client = await create_client(base_agent_card) + assert isinstance(client._transport, JsonRpcTransport) # type: ignore[attr-defined] + assert client._transport.url == 'http://primary-url.com' # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_create_client_with_url(base_agent_card: AgentCard): + """Verify create_client resolves a URL and creates a client.""" + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: + mock_resolver.return_value.get_agent_card = AsyncMock( + return_value=base_agent_card + ) + + client = await create_client('http://example.com') + + mock_resolver.assert_called_once() + assert mock_resolver.call_args[0][1] == 'http://example.com' + assert isinstance(client._transport, JsonRpcTransport) # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def 
test_create_client_with_url_and_config(base_agent_card: AgentCard): + """Verify create_client passes client_config to the factory.""" + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: + mock_resolver.return_value.get_agent_card = AsyncMock( + return_value=base_agent_card + ) + + mock_httpx_client = httpx.AsyncClient() + config = ClientConfig(httpx_client=mock_httpx_client) + + await create_client('http://example.com', client_config=config) + + mock_resolver.assert_called_once_with( + mock_httpx_client, 'http://example.com' + ) diff --git a/tests/integration/cross_version/client_server/client_1_0.py b/tests/integration/cross_version/client_server/client_1_0.py index 5a5e192cf..6630bddad 100644 --- a/tests/integration/cross_version/client_server/client_1_0.py +++ b/tests/integration/cross_version/client_server/client_1_0.py @@ -5,7 +5,7 @@ import sys from uuid import uuid4 -from a2a.client import ClientFactory, ClientConfig +from a2a.client import ClientConfig, create_client from a2a.utils import TransportProtocol from a2a.types import ( Message, @@ -80,7 +80,7 @@ async def test_send_message_sync(url, protocol_enum): config.supported_protocol_bindings = [protocol_enum] config.streaming = False - client = await ClientFactory.connect(url, client_config=config) + client = await create_client(url, client_config=config) msg = Message( role=Role.ROLE_USER, message_id=f'sync-{uuid4()}', @@ -296,7 +296,7 @@ async def run_client(url: str, protocol: str): config.supported_protocol_bindings = [protocol_enum] config.streaming = True - client = await ClientFactory.connect(url, client_config=config) + client = await create_client(url, client_config=config) # 1. 
Get Extended Agent Card server_name = await test_get_extended_agent_card(client) From 3a68d8f916d96ae135748ee2b9b907f8dace4fa7 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 9 Apr 2026 12:07:52 +0200 Subject: [PATCH 14/67] fix: handle SSE errors occurred after stream started (#894) The spec doesn't defined this behavior: https://github.com/a2aproject/A2A/issues/1262, but currently it'd close the connection. --- src/a2a/client/transports/http_helpers.py | 21 ++++- src/a2a/client/transports/jsonrpc.py | 10 ++- src/a2a/client/transports/rest.py | 83 ++++++++++++------- src/a2a/server/routes/jsonrpc_dispatcher.py | 26 +++++- src/a2a/server/routes/rest_dispatcher.py | 19 ++++- src/a2a/utils/error_handlers.py | 80 +++++++++--------- .../test_client_server_integration.py | 76 +++++++++++++++++ tests/server/routes/test_rest_dispatcher.py | 9 +- 8 files changed, 242 insertions(+), 82 deletions(-) diff --git a/src/a2a/client/transports/http_helpers.py b/src/a2a/client/transports/http_helpers.py index eca386bd4..0a73ed83c 100644 --- a/src/a2a/client/transports/http_helpers.py +++ b/src/a2a/client/transports/http_helpers.py @@ -12,6 +12,10 @@ from a2a.client.errors import A2AClientError, A2AClientTimeoutError +def _default_sse_error_handler(sse_data: str) -> NoReturn: + raise A2AClientError(f'SSE stream error event received: {sse_data}') + + @contextmanager def handle_http_exceptions( status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn] @@ -71,9 +75,22 @@ async def send_http_stream_request( url: str, status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn] | None = None, + sse_error_handler: Callable[[str], NoReturn] = _default_sse_error_handler, **kwargs: Any, ) -> AsyncGenerator[str]: - """Sends a streaming HTTP request, yielding SSE data strings and handling exceptions.""" + """Sends a streaming HTTP request, yielding SSE data strings and handling exceptions. + + Args: + httpx_client: The async HTTP client. + method: The HTTP method (e.g. 
'POST', 'GET'). + url: The URL to send the request to. + status_error_handler: Handler for HTTP status errors. Should raise an + appropriate domain-specific exception. + sse_error_handler: Handler for SSE error events. Called with the + raw SSE data string when an ``event: error`` SSE event is received. + Should raise an appropriate domain-specific exception. + **kwargs: Additional keyword arguments forwarded to ``aconnect_sse``. + """ with handle_http_exceptions(status_error_handler): async with _SSEEventSource( httpx_client, method, url, **kwargs @@ -97,6 +114,8 @@ async def send_http_stream_request( async for sse in event_source.aiter_sse(): if not sse.data: continue + if sse.event == 'error': + sse_error_handler(sse.data) yield sse.data diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index eca6c4897..252ea439d 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -1,7 +1,7 @@ import logging from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, NoReturn from uuid import uuid4 import httpx @@ -350,6 +350,7 @@ async def _send_stream_request( 'POST', self.url, None, + self._handle_sse_error, json=rpc_request_payload, **http_kwargs, ): @@ -360,3 +361,10 @@ async def _send_stream_request( json_rpc_response.result, StreamResponse() ) yield response + + def _handle_sse_error(self, sse_data: str) -> NoReturn: + """Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error.""" + json_rpc_response = JSONRPC20Response.from_json(sse_data) + if json_rpc_response.error: + raise self._create_jsonrpc_error(json_rpc_response.error) + raise A2AClientError(f'SSE stream error: {sse_data}') diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index ed40d31c7..3dfe95927 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -41,6 +41,47 @@ logger = 
logging.getLogger(__name__) +def _parse_rest_error( + error_payload: dict[str, Any], + fallback_message: str, +) -> Exception | None: + """Parses a REST error payload and returns the appropriate A2AError. + + Args: + error_payload: The parsed JSON error payload. + fallback_message: Message to use if the payload has no ``message``. + + Returns: + The mapped A2AError if a known reason was found, otherwise ``None``. + """ + error_data = error_payload.get('error', {}) + message = error_data.get('message', fallback_message) + details = error_data.get('details', []) + if not isinstance(details, list): + return None + + # The `details` array can contain multiple different error objects. + # We extract the first `ErrorInfo` object because it contains the + # specific `reason` code needed to map this back to a Python A2AError. + for d in details: + if ( + isinstance(d, dict) + and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo' + ): + reason = d.get('reason') + metadata = d.get('metadata') or {} + if isinstance(reason, str): + exception_cls = A2A_REASON_TO_ERROR.get(reason) + if exception_cls: + exc = exception_cls(message) + if metadata: + exc.data = metadata + return exc + break + + return None + + @trace_class(kind=SpanKind.CLIENT) class RestTransport(ClientTransport): """A REST transport for the A2A client.""" @@ -294,39 +335,12 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: """Handles HTTP status errors and raises the appropriate A2AError.""" try: error_payload = e.response.json() - error_data = error_payload.get('error', {}) - - message = error_data.get('message', str(e)) - details = error_data.get('details', []) - if not isinstance(details, list): - details = [] - - # The `details` array can contain multiple different error objects. - # We extract the first `ErrorInfo` object because it contains the - # specific `reason` code needed to map this back to a Python A2AError. 
- error_info = {} - for d in details: - if ( - isinstance(d, dict) - and d.get('@type') - == 'type.googleapis.com/google.rpc.ErrorInfo' - ): - error_info = d - break - reason = error_info.get('reason') - metadata = error_info.get('metadata') or {} - - if isinstance(reason, str): - exception_cls = A2A_REASON_TO_ERROR.get(reason) - if exception_cls: - exc = exception_cls(message) - if metadata: - exc.data = metadata - raise exc from e + mapped = _parse_rest_error(error_payload, str(e)) + if mapped: + raise mapped from e except (json.JSONDecodeError, ValueError): pass - # Fallback mappings for status codes if 'type' is missing or unknown status_code = e.response.status_code if status_code == httpx.codes.NOT_FOUND: raise MethodNotFoundError( @@ -335,6 +349,14 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: raise A2AClientError(f'HTTP Error {status_code}: {e}') from e + def _handle_sse_error(self, sse_data: str) -> NoReturn: + """Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError.""" + error_payload = json.loads(sse_data) + mapped = _parse_rest_error(error_payload, sse_data) + if mapped: + raise mapped + raise A2AClientError(sse_data) + async def _send_stream_request( self, method: str, @@ -352,6 +374,7 @@ async def _send_stream_request( method, f'{self.url}{path}', self._handle_http_error, + self._handle_sse_error, json=json, **http_kwargs, ): diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index de20610f6..d9ea4ff1a 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -565,8 +565,30 @@ def _create_response( async def event_generator( stream: AsyncGenerator[dict[str, Any]], ) -> AsyncGenerator[dict[str, str]]: - async for item in stream: - yield {'data': json.dumps(item)} + try: + async for item in stream: + event: dict[str, str] = { + 'data': json.dumps(item), + } + if 'error' in item: + 
event['event'] = 'error' + yield event + except Exception as e: + logger.exception( + 'Unhandled error during JSON-RPC SSE stream' + ) + rpc_error: A2AError | JSONRPCError = ( + e + if isinstance(e, A2AError | JSONRPCError) + else InternalError(message=str(e)) + ) + error_response = build_error_response( + context.state.get('request_id'), rpc_error + ) + yield { + 'event': 'error', + 'data': json.dumps(error_response), + } return EventSourceResponse( event_generator(handler_result), headers=headers diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py index fa9a12af8..8af384893 100644 --- a/src/a2a/server/routes/rest_dispatcher.py +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -20,6 +20,7 @@ ) from a2a.utils import constants, proto_utils from a2a.utils.error_handlers import ( + build_rest_error_payload, rest_error_handler, rest_stream_error_handler, ) @@ -32,6 +33,7 @@ if TYPE_CHECKING: + from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -39,6 +41,7 @@ _package_starlette_installed = True else: try: + from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -46,6 +49,7 @@ _package_starlette_installed = True except ImportError: EventSourceResponse = Any + ServerSentEvent = Any Request = Any JSONResponse = Any Response = Any @@ -135,10 +139,17 @@ async def _handle_streaming( except StopAsyncIteration: return EventSourceResponse(iter([])) - async def event_generator() -> AsyncIterator[str]: - yield json.dumps(first_item) - async for item in stream: - yield json.dumps(item) + async def event_generator() -> AsyncIterator[ServerSentEvent]: + yield ServerSentEvent(data=json.dumps(first_item)) + try: + async for item in stream: + yield 
ServerSentEvent(data=json.dumps(item)) + except Exception as e: + logger.exception('Error during REST SSE stream') + yield ServerSentEvent( + data=json.dumps(build_rest_error_payload(e)), + event='error', + ) return EventSourceResponse(event_generator()) diff --git a/src/a2a/utils/error_handlers.py b/src/a2a/utils/error_handlers.py index d21a9e24c..ea544d79d 100644 --- a/src/a2a/utils/error_handlers.py +++ b/src/a2a/utils/error_handlers.py @@ -54,16 +54,43 @@ def _build_error_payload( return {'error': payload} -def _create_error_response(error: Exception) -> Response: - """Helper function to create a JSONResponse for an error.""" +def build_rest_error_payload(error: Exception) -> dict[str, Any]: + """Build a REST error payload dict from an exception. + + Returns: + A dict with the error payload in the standard REST error format. + """ if isinstance(error, A2AError): mapping = A2A_REST_ERROR_MAPPING.get( type(error), RestErrorMap(500, 'INTERNAL', 'INTERNAL_ERROR') ) - http_code = mapping.http_code - grpc_status = mapping.grpc_status - reason = mapping.reason + # SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response. 
+ metadata = getattr(error, 'data', None) or {} + return _build_error_payload( + code=mapping.http_code, + status=mapping.grpc_status, + message=getattr(error, 'message', str(error)), + reason=mapping.reason, + metadata=metadata, + ) + if isinstance(error, ParseError): + return _build_error_payload( + code=400, + status='INVALID_ARGUMENT', + message=str(error), + reason='INVALID_REQUEST', + metadata={}, + ) + return _build_error_payload( + code=500, + status='INTERNAL', + message='unknown exception', + ) + +def _create_error_response(error: Exception) -> Response: + """Helper function to create a JSONResponse for an error.""" + if isinstance(error, A2AError): log_level = ( logging.ERROR if isinstance(error, InternalError) @@ -76,42 +103,17 @@ def _create_error_response(error: Exception) -> Response: getattr(error, 'message', str(error)), f', Data={error.data}' if error.data else '', ) - - # SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response. 
- metadata = getattr(error, 'data', None) or {} - - return JSONResponse( - content=_build_error_payload( - code=http_code, - status=grpc_status, - message=getattr(error, 'message', str(error)), - reason=reason, - metadata=metadata, - ), - status_code=http_code, - media_type='application/json', - ) - if isinstance(error, ParseError): + elif isinstance(error, ParseError): logger.warning('Parse error: %s', str(error)) - return JSONResponse( - content=_build_error_payload( - code=400, - status='INVALID_ARGUMENT', - message=str(error), - reason='INVALID_REQUEST', - metadata={}, - ), - status_code=400, - media_type='application/json', - ) - logger.exception('Unknown error occurred') + else: + logger.exception('Unknown error occurred') + + payload = build_rest_error_payload(error) + # Extract HTTP status code from the payload + http_code = payload.get('error', {}).get('code', 500) return JSONResponse( - content=_build_error_payload( - code=500, - status='INTERNAL', - message='unknown exception', - ), - status_code=500, + content=payload, + status_code=http_code, media_type='application/json', ) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 36565205a..c7fa29ea5 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1187,3 +1187,79 @@ async def test_validate_streaming_disabled( pass await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'error_cls', + [ + TaskNotFoundError, + TaskNotCancelableError, + PushNotificationNotSupportedError, + UnsupportedOperationError, + ContentTypeNotSupportedError, + InvalidAgentResponseError, + ExtendedAgentCardNotConfiguredError, + ExtensionSupportRequiredError, + VersionNotSupportedError, + ], +) +@pytest.mark.parametrize( + 'handler_attr, client_method, request_params', + [ + pytest.param( + 'on_message_send_stream', + 'send_message', + SendMessageRequest( + 
message=Message( + role=Role.ROLE_USER, + message_id='msg-midstream-test', + parts=[Part(text='Hello, mid-stream test!')], + ) + ), + id='stream', + ), + pytest.param( + 'on_subscribe_to_task', + 'subscribe', + SubscribeToTaskRequest(id='some-id'), + id='subscribe', + ), + ], +) +async def test_client_handles_mid_stream_a2a_errors( + transport_setups, + error_cls, + handler_attr, + client_method, + request_params, +) -> None: + """Integration test for mid-stream errors sent as SSE error events. + + The handler yields one event successfully, then raises an A2AError. + The client must receive the first event and then get the error as the + exact error_cls exception. This mirrors test_client_handles_a2a_errors_streaming + but verifies the error occurs *after* the stream has started producing events. + """ + client = transport_setups.client + handler = transport_setups.handler + + async def mock_generator(*args, **kwargs): + yield TASK_FROM_STREAM + raise error_cls('Mid-stream error') + + getattr(handler, handler_attr).side_effect = mock_generator + + received_events = [] + with pytest.raises(error_cls) as exc_info: + async for event in getattr(client, client_method)( + request=request_params + ): + received_events.append(event) # noqa: PERF401 + + assert 'Mid-stream error' in str(exc_info.value) + assert len(received_events) == 1 + + getattr(handler, handler_attr).side_effect = None + + await client.close() diff --git a/tests/server/routes/test_rest_dispatcher.py b/tests/server/routes/test_rest_dispatcher.py index 5284db617..a1d2c27cd 100644 --- a/tests/server/routes/test_rest_dispatcher.py +++ b/tests/server/routes/test_rest_dispatcher.py @@ -264,9 +264,8 @@ async def test_on_message_send_stream_success( chunks.append(chunk) assert len(chunks) == 2 - # sse-starlette yields strings or bytes formatted as Server-Sent Events - assert 'chunk1' in str(chunks[0]) - assert 'chunk2' in str(chunks[1]) + assert 'chunk1' in chunks[0].data + assert 'chunk2' in chunks[1].data 
async def test_on_subscribe_to_task_success(self, rest_dispatcher_instance): req = make_mock_request(method='GET', path_params={'id': 'test_task'}) @@ -279,8 +278,8 @@ async def test_on_subscribe_to_task_success(self, rest_dispatcher_instance): chunks.append(chunk) assert len(chunks) == 2 - assert 'chunk1' in str(chunks[0]) - assert 'chunk2' in str(chunks[1]) + assert 'chunk1' in chunks[0].data + assert 'chunk2' in chunks[1].data async def test_on_message_send_stream_handler_error(self, mock_handler): from a2a.utils.errors import UnsupportedOperationError From f0e1d74802e78a4e9f4c22cbc85db104137e0cd2 Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Thu, 9 Apr 2026 12:18:10 +0200 Subject: [PATCH 15/67] feat: EventQueue is now a simple interface with single enqueue_event method. (#944) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #869 🦕 --- src/a2a/server/agent_execution/active_task.py | 21 +++--- src/a2a/server/events/event_consumer.py | 4 +- src/a2a/server/events/event_queue.py | 69 +---------------- src/a2a/server/events/event_queue_v2.py | 75 +++++++++++++++---- .../server/events/in_memory_queue_manager.py | 18 ++--- src/a2a/server/events/queue_manager.py | 10 +-- .../default_request_handler.py | 7 +- tests/server/events/test_event_consumer.py | 4 +- .../events/test_inmemory_queue_manager.py | 8 +- .../test_default_request_handler.py | 17 +++-- 10 files changed, 112 insertions(+), 121 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index bf9e129a6..defdd5244 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -374,30 +374,33 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 await self._task_manager.process(event) # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states - res = await self._task_manager.get_task() + new_task = await 
self._task_manager.get_task() + if new_task is None: + raise RuntimeError( + f'Task {self.task_id} not found' + ) is_interrupted = ( - res - and res.status.state + new_task.status.state in INTERRUPTED_TASK_STATES ) is_terminal = ( - res - and res.status.state in TERMINAL_TASK_STATES + new_task.status.state + in TERMINAL_TASK_STATES ) # If we hit a breakpoint or terminal state, lock in the result. - if (is_interrupted or is_terminal) and res: + if is_interrupted or is_terminal: logger.debug( 'Consumer[%s]: Setting first result as Task (state=%s)', self._task_id, - res.status.state, + new_task.status.state, ) if is_terminal: logger.debug( 'Consumer[%s]: Reached terminal state %s', self._task_id, - res.status.state if res else 'unknown', + new_task.status.state, ) if not self._is_finished.is_set(): async with self._lock: @@ -413,7 +416,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 logger.debug( 'Consumer[%s]: Interrupted with state %s', self._task_id, - res.status.state if res else 'unknown', + new_task.status.state, ) if ( diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index a29394795..8414e2d17 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -5,7 +5,7 @@ from pydantic import ValidationError -from a2a.server.events.event_queue import Event, EventQueue, QueueShutDown +from a2a.server.events.event_queue import Event, EventQueueLegacy, QueueShutDown from a2a.types.a2a_pb2 import ( Message, Task, @@ -22,7 +22,7 @@ class EventConsumer: """Consumer to read events from the agent event queue.""" - def __init__(self, queue: EventQueue): + def __init__(self, queue: EventQueueLegacy): """Initializes the EventConsumer. 
Args: diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index 25598d15b..bb4d7b9b4 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -92,73 +92,6 @@ async def enqueue_event(self, event: Event) -> None: Only main queue can enqueue events. Child queues can only dequeue events. """ - @abstractmethod - async def dequeue_event(self) -> Event: - """Pulls an event from the queue.""" - - @abstractmethod - def task_done(self) -> None: - """Signals that a work on dequeued event is complete.""" - - @abstractmethod - async def tap( - self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE - ) -> 'EventQueue': - """Creates a child queue that receives future events. - - Note: The tapped queue may receive some old events if the incoming event - queue is lagging behind and hasn't dispatched them yet. - """ - - @abstractmethod - async def close(self, immediate: bool = False) -> None: - """Closes the queue. - - For parent queue: it closes the main queue and all its child queues. - For child queue: it closes only child queue. - - It is safe to call it multiple times. - If immediate is True, the queue will be closed without waiting for all events to be processed. - If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls). - - WARNING: Closing the parent queue with immediate=False is a deadlock risk if there are unconsumed events - in any of the child sinks and the consumer has crashed without draining its queue. - It is highly recommended to wrap graceful shutdowns with a timeout, e.g., - `asyncio.wait_for(queue.close(immediate=False), timeout=...)`. - """ - - @abstractmethod - def is_closed(self) -> bool: - """[DEPRECATED] Checks if the queue is closed. - - NOTE: Relying on this for enqueue logic introduces race conditions. 
- It is maintained primarily for backwards compatibility, workarounds for - Python 3.10/3.12 async queues in consumers, and for the test suite. - """ - - @abstractmethod - async def __aenter__(self) -> Self: - """Enters the async context manager, returning the queue itself. - - WARNING: See `__aexit__` for important deadlock risks associated with - exiting this context manager if unconsumed events remain. - """ - - @abstractmethod - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exits the async context manager, ensuring close() is called. - - WARNING: The context manager calls `close(immediate=False)` by default. - If a consumer exits the `async with` block early (e.g., due to an exception - or an explicit `break`) while unconsumed events remain in the queue, - `__aexit__` will deadlock waiting for `task_done()` to be called on those events. - """ - @trace_class(kind=SpanKind.SERVER) class EventQueueLegacy(EventQueue): @@ -180,7 +113,7 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: self._queue: AsyncQueue[Event] = _create_async_queue( maxsize=max_queue_size ) - self._children: list[EventQueue] = [] + self._children: list[EventQueueLegacy] = [] self._is_closed = False self._lock = asyncio.Lock() logger.debug('EventQueue initialized.') diff --git a/src/a2a/server/events/event_queue_v2.py b/src/a2a/server/events/event_queue_v2.py index de12c21d1..224cb8e56 100644 --- a/src/a2a/server/events/event_queue_v2.py +++ b/src/a2a/server/events/event_queue_v2.py @@ -193,19 +193,29 @@ async def enqueue_event(self, event: Event) -> None: return async def dequeue_event(self) -> Event: - """Dequeues an event from the default internal sink queue.""" + """Pulls an event from the default internal sink queue.""" if self._default_sink is None: raise ValueError('No default sink available.') return await self._default_sink.dequeue_event() def 
task_done(self) -> None: - """Signals that a formerly enqueued task is complete via the default internal sink queue.""" + """Signals that a work on dequeued event is complete via the default internal sink queue.""" if self._default_sink is None: raise ValueError('No default sink available.') self._default_sink.task_done() async def close(self, immediate: bool = False) -> None: - """Closes the queue for future push events and also closes all child sinks.""" + """Closes the queue and all its child sinks. + + It is safe to call it multiple times. + If immediate is True, the queue will be closed without waiting for all events to be processed. + If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls). + + WARNING: Closing the parent queue with immediate=False is a deadlock risk if there are unconsumed events + in any of the child sinks and the consumer has crashed without draining its queue. + It is highly recommended to wrap graceful shutdowns with a timeout, e.g., + `asyncio.wait_for(queue.close(immediate=False), timeout=...)`. + """ logger.debug('Closing EventQueueSource: immediate=%s', immediate) async with self._lock: # No more tap() allowed. @@ -230,7 +240,12 @@ async def close(self, immediate: bool = False) -> None: ) def is_closed(self) -> bool: - """Checks if the queue is closed.""" + """[DEPRECATED] Checks if the queue is closed. + + NOTE: Relying on this for enqueue logic introduces race conditions. + It is maintained primarily for backwards compatibility, workarounds for + Python 3.10/3.12 async queues in consumers, and for the test suite. 
+ """ return self._is_closed async def test_only_join_incoming_queue(self) -> None: @@ -238,7 +253,11 @@ async def test_only_join_incoming_queue(self) -> None: await self._join_incoming_queue() async def __aenter__(self) -> Self: - """Enters the async context manager, returning the queue itself.""" + """Enters the async context manager, returning the queue itself. + + WARNING: See `__aexit__` for important deadlock risks associated with + exiting this context manager if unconsumed events remain. + """ return self async def __aexit__( @@ -247,7 +266,13 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - """Exits the async context manager, ensuring close() is called.""" + """Exits the async context manager, ensuring close() is called. + + WARNING: The context manager calls `close(immediate=False)` by default. + If a consumer exits the `async with` block early (e.g., due to an exception + or an explicit `break`) while unconsumed events remain in the queue, + `__aexit__` will deadlock waiting for `task_done()` to be called on those events. 
+ """ await self.close() @@ -290,26 +315,35 @@ async def enqueue_event(self, event: Event) -> None: raise RuntimeError('Cannot enqueue to a sink-only queue') async def dequeue_event(self) -> Event: - """Dequeues an event from the sink queue.""" + """Pulls an event from the sink queue.""" logger.debug('Attempting to dequeue event (waiting).') event = await self._queue.get() logger.debug('Dequeued event: %s', event) return event def task_done(self) -> None: - """Signals that a formerly enqueued task is complete in this sink queue.""" + """Signals that a work on dequeued event is complete in this sink queue.""" logger.debug('Marking task as done in EventQueueSink.') self._queue.task_done() async def tap( self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE ) -> 'EventQueueSink': - """Taps the event queue to create a new child queue that receives future events.""" + """Creates a child queue that receives future events. + + Note: The tapped queue may receive some old events if the incoming event + queue is lagging behind and hasn't dispatched them yet. + """ # Delegate tap to the parent source so all sinks are flat under the source return await self._parent.tap(max_queue_size=max_queue_size) async def close(self, immediate: bool = False) -> None: - """Closes the child sink queue.""" + """Closes the child sink queue. + + It is safe to call it multiple times. + If immediate is True, the queue will be closed without waiting for all events to be processed. + If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls). + """ logger.debug('Closing EventQueueSink.') async with self._lock: self._is_closed = True @@ -323,11 +357,20 @@ async def close(self, immediate: bool = False) -> None: await self._queue.join() def is_closed(self) -> bool: - """Checks if the sink queue is closed.""" + """[DEPRECATED] Checks if the queue is closed. + + NOTE: Relying on this for enqueue logic introduces race conditions. 
+ It is maintained primarily for backwards compatibility, workarounds for + Python 3.10/3.12 async queues in consumers, and for the test suite. + """ return self._is_closed async def __aenter__(self) -> Self: - """Enters the async context manager, returning the queue itself.""" + """Enters the async context manager, returning the queue itself. + + WARNING: See `__aexit__` for important deadlock risks associated with + exiting this context manager if unconsumed events remain. + """ return self async def __aexit__( @@ -336,5 +379,11 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - """Exits the async context manager, ensuring close() is called.""" + """Exits the async context manager, ensuring close() is called. + + WARNING: The context manager calls `close(immediate=False)` by default. + If a consumer exits the `async with` block early (e.g., due to an exception + or an explicit `break`) while unconsumed events remain in the queue, + `__aexit__` will deadlock waiting for `task_done()` to be called on those events. 
+ """ await self.close() diff --git a/src/a2a/server/events/in_memory_queue_manager.py b/src/a2a/server/events/in_memory_queue_manager.py index ddff52419..0beb354f9 100644 --- a/src/a2a/server/events/in_memory_queue_manager.py +++ b/src/a2a/server/events/in_memory_queue_manager.py @@ -1,6 +1,6 @@ import asyncio -from a2a.server.events.event_queue import EventQueue, EventQueueLegacy +from a2a.server.events.event_queue import EventQueueLegacy from a2a.server.events.queue_manager import ( NoTaskQueue, QueueManager, @@ -23,10 +23,10 @@ class InMemoryQueueManager(QueueManager): def __init__(self) -> None: """Initializes the InMemoryQueueManager.""" - self._task_queue: dict[str, EventQueue] = {} + self._task_queue: dict[str, EventQueueLegacy] = {} self._lock = asyncio.Lock() - async def add(self, task_id: str, queue: EventQueue) -> None: + async def add(self, task_id: str, queue: EventQueueLegacy) -> None: """Adds a new event queue for a task ID. Raises: @@ -37,22 +37,22 @@ async def add(self, task_id: str, queue: EventQueue) -> None: raise TaskQueueExists self._task_queue[task_id] = queue - async def get(self, task_id: str) -> EventQueue | None: + async def get(self, task_id: str) -> EventQueueLegacy | None: """Retrieves the event queue for a task ID. Returns: - The `EventQueue` instance for the `task_id`, or `None` if not found. + The `EventQueueLegacy` instance for the `task_id`, or `None` if not found. """ async with self._lock: if task_id not in self._task_queue: return None return self._task_queue[task_id] - async def tap(self, task_id: str) -> EventQueue | None: + async def tap(self, task_id: str) -> EventQueueLegacy | None: """Taps the event queue for a task ID to create a child queue. Returns: - A new child `EventQueue` instance, or `None` if the task ID is not found. + A new child `EventQueueLegacy` instance, or `None` if the task ID is not found. 
""" async with self._lock: if task_id not in self._task_queue: @@ -71,11 +71,11 @@ async def close(self, task_id: str) -> None: queue = self._task_queue.pop(task_id) await queue.close() - async def create_or_tap(self, task_id: str) -> EventQueue: + async def create_or_tap(self, task_id: str) -> EventQueueLegacy: """Creates a new event queue for a task ID if one doesn't exist, otherwise taps the existing one. Returns: - A new or child `EventQueue` instance for the `task_id`. + A new or child `EventQueueLegacy` instance for the `task_id`. """ async with self._lock: if task_id not in self._task_queue: diff --git a/src/a2a/server/events/queue_manager.py b/src/a2a/server/events/queue_manager.py index ed69aae68..b3ec204a5 100644 --- a/src/a2a/server/events/queue_manager.py +++ b/src/a2a/server/events/queue_manager.py @@ -1,21 +1,21 @@ from abc import ABC, abstractmethod -from a2a.server.events.event_queue import EventQueue +from a2a.server.events.event_queue import EventQueueLegacy class QueueManager(ABC): """Interface for managing the event queue lifecycles per task.""" @abstractmethod - async def add(self, task_id: str, queue: EventQueue) -> None: + async def add(self, task_id: str, queue: EventQueueLegacy) -> None: """Adds a new event queue associated with a task ID.""" @abstractmethod - async def get(self, task_id: str) -> EventQueue | None: + async def get(self, task_id: str) -> EventQueueLegacy | None: """Retrieves the event queue for a task ID.""" @abstractmethod - async def tap(self, task_id: str) -> EventQueue | None: + async def tap(self, task_id: str) -> EventQueueLegacy | None: """Creates a child event queue (tap) for an existing task ID.""" @abstractmethod @@ -23,7 +23,7 @@ async def close(self, task_id: str) -> None: """Closes and removes the event queue for a task ID.""" @abstractmethod - async def create_or_tap(self, task_id: str) -> EventQueue: + async def create_or_tap(self, task_id: str) -> EventQueueLegacy: """Creates a queue if one doesn't exist, 
otherwise taps the existing one.""" diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index e6b992250..fea5184d6 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -14,7 +14,6 @@ from a2a.server.events import ( Event, EventConsumer, - EventQueue, EventQueueLegacy, InMemoryQueueManager, QueueManager, @@ -241,7 +240,7 @@ async def on_cancel_task( return result async def _run_event_stream( - self, request: RequestContext, queue: EventQueue + self, request: RequestContext, queue: EventQueueLegacy ) -> None: """Runs the agent's `execute` method and closes the queue afterwards. @@ -256,7 +255,9 @@ async def _setup_message_execution( self, params: SendMessageRequest, context: ServerCallContext, - ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: + ) -> tuple[ + TaskManager, str, EventQueueLegacy, ResultAggregator, asyncio.Task + ]: """Common setup logic for both streaming and non-streaming message handling. 
Returns: diff --git a/tests/server/events/test_event_consumer.py b/tests/server/events/test_event_consumer.py index cfd315265..d7d20768b 100644 --- a/tests/server/events/test_event_consumer.py +++ b/tests/server/events/test_event_consumer.py @@ -49,11 +49,11 @@ def create_sample_task( @pytest.fixture def mock_event_queue(): - return AsyncMock(spec=EventQueue) + return AsyncMock(spec=EventQueueLegacy) @pytest.fixture -def event_consumer(mock_event_queue: EventQueue): +def event_consumer(mock_event_queue: EventQueueLegacy): return EventConsumer(queue=mock_event_queue) diff --git a/tests/server/events/test_inmemory_queue_manager.py b/tests/server/events/test_inmemory_queue_manager.py index b51334a95..9716b13bf 100644 --- a/tests/server/events/test_inmemory_queue_manager.py +++ b/tests/server/events/test_inmemory_queue_manager.py @@ -5,7 +5,7 @@ import pytest from a2a.server.events import InMemoryQueueManager -from a2a.server.events.event_queue import EventQueue +from a2a.server.events.event_queue import EventQueueLegacy from a2a.server.events.queue_manager import ( NoTaskQueue, TaskQueueExists, @@ -21,7 +21,7 @@ def queue_manager(self) -> InMemoryQueueManager: @pytest.fixture def event_queue(self) -> MagicMock: """Fixture to create a mock EventQueue.""" - queue = MagicMock(spec=EventQueue) + queue = MagicMock(spec=EventQueueLegacy) # Mock the tap method to return itself queue.tap.return_value = queue @@ -119,7 +119,7 @@ async def test_create_or_tap_new_queue( task_id = 'test_task_id' result = await queue_manager.create_or_tap(task_id) - assert isinstance(result, EventQueue) + assert isinstance(result, EventQueueLegacy) assert queue_manager._task_queue[task_id] == result @pytest.mark.asyncio @@ -142,7 +142,7 @@ async def test_concurrency( """Test concurrent access to the queue manager.""" async def add_task(task_id): - queue = EventQueue() + queue = EventQueueLegacy() await queue_manager.add(task_id, queue) return task_id diff --git 
a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 59e965116..294f5aefe 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -22,7 +22,12 @@ SimpleRequestContextBuilder, ) from a2a.server.context import ServerCallContext -from a2a.server.events import EventQueue, InMemoryQueueManager, QueueManager +from a2a.server.events import ( + EventQueue, + EventQueueLegacy, + InMemoryQueueManager, + QueueManager, +) from a2a.server.request_handlers import ( LegacyRequestHandler as DefaultRequestHandler, ) @@ -380,7 +385,7 @@ async def test_on_cancel_task_cancels_running_agent(agent_card): mock_task_store.get.return_value = sample_task mock_queue_manager = AsyncMock(spec=QueueManager) - mock_event_queue = AsyncMock(spec=EventQueue) + mock_event_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.tap.return_value = mock_event_queue mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -425,7 +430,7 @@ async def test_on_cancel_task_completes_during_cancellation(agent_card): mock_task_store.get.return_value = sample_task mock_queue_manager = AsyncMock(spec=QueueManager) - mock_event_queue = AsyncMock(spec=EventQueue) + mock_event_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.tap.return_value = mock_event_queue mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -472,7 +477,7 @@ async def test_on_cancel_task_invalid_result_type(agent_card): mock_task_store.get.return_value = sample_task mock_queue_manager = AsyncMock(spec=QueueManager) - mock_event_queue = AsyncMock(spec=EventQueue) + mock_event_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.tap.return_value = mock_event_queue mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -1452,7 +1457,7 @@ async def test_on_message_send_stream_client_disconnect_triggers_background_clea 
mock_request_context_builder.build.return_value = mock_request_context # Queue used by _run_event_stream; must support close() - mock_queue = AsyncMock(spec=EventQueue) + mock_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.create_or_tap.return_value = mock_queue request_handler = DefaultRequestHandler( @@ -1683,7 +1688,7 @@ async def test_background_cleanup_task_is_tracked_and_cleared(agent_card): mock_request_context.context_id = context_id mock_request_context_builder.build.return_value = mock_request_context - mock_queue = AsyncMock(spec=EventQueue) + mock_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.create_or_tap.return_value = mock_queue request_handler = DefaultRequestHandler( From 39e32e915e3229d4cd4eeb596af502df519731ca Mon Sep 17 00:00:00 2001 From: kdziedzic70 Date: Thu, 9 Apr 2026 14:49:54 +0200 Subject: [PATCH 16/67] build: fixes local runnability of itk tests and adds readme on how to setup (#946) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description PR fixes ability to run itk tests locally and adds readme with proper instructions on how to set up the environment for that. Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [ ] Make your Pull Request title in the specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. 
- [ ] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [ ] Appropriate docs were updated (if necessary) Fixes # 🦕 Co-authored-by: Krzysztof Dziedzic Co-authored-by: Ivan Shymko --- .github/actions/spelling/allow.txt | 5 +++ itk/README.md | 54 ++++++++++++++++++++++++++++++ itk/run_itk.sh | 5 +-- 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 itk/README.md diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index b3657f2b8..900974409 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -138,3 +138,8 @@ TResponse typ typeerror vulnz +Podman +podman +UIDs +subuids +subgids diff --git a/itk/README.md b/itk/README.md new file mode 100644 index 000000000..63ec68fad --- /dev/null +++ b/itk/README.md @@ -0,0 +1,54 @@ +# Running ITK Tests Locally + +This directory contains scripts to run Integration Test Kit (ITK) tests locally using Podman. + +## Prerequisites + +### 1. Install Podman + +Run the following commands to install Podman and its components: + +```bash +sudo apt update && sudo apt install -y podman podman-docker podman-compose +``` + +### 2. Configure SubUIDs/SubGIDs + +For rootless Podman to function correctly, you need to ensure subuids and subgids are configured for your user. + +If they are not already configured, you can add them using (replace `$USER` with your username if needed): + +```bash +sudo usermod --add-subuids 100000-165535 --add-subgids 100000-165535 $USER +``` + +After adding subuids or if you encounter permission issues, run: + +```bash +podman system migrate +``` + +## Running Tests + +### 1. Set Environment Variable + +You must set the `A2A_SAMPLES_REVISION` environment variable to specify which revision of the `a2a-samples` repository to use for testing. This can be a branch name, tag, or commit hash. + +Example: +```bash +export A2A_SAMPLES_REVISION=itk-v.0.11-alpha +``` + +### 2. 
Execute Tests + +Run the test script from this directory: + +```bash +./run_itk.sh +``` + +The script will: +- Clone `a2a-samples` (if not already present). +- Checkout the specified revision. +- Build the ITK service Docker image. +- Run the tests and output results. diff --git a/itk/run_itk.sh b/itk/run_itk.sh index 908a5fbc5..80e96f9c2 100755 --- a/itk/run_itk.sh +++ b/itk/run_itk.sh @@ -70,8 +70,9 @@ docker run -d --name itk-service \ itk_service # 5.1. Fix dubious ownership for git (needed for uv-dynamic-versioning) -docker exec itk-service git config --global --add safe.directory /app/agents/repo -docker exec itk-service git config --global --add safe.directory /app/agents/repo/itk +docker exec -u root itk-service git config --system --add safe.directory /app/agents/repo +docker exec -u root itk-service git config --system --add safe.directory /app/agents/repo/itk +docker exec -u root itk-service git config --system core.multiPackIndex false # 6. Verify service is up and send post request MAX_RETRIES=30 From be4c5ff17a2f58e20d5d333a5e8e7bfcaa58c6c0 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 9 Apr 2026 15:39:39 +0200 Subject: [PATCH 17/67] refactor(client)!: make ClientConfig.push_notification_config singular (#955) Send message API only contains one and only the first was used. 
--- src/a2a/client/base_client.py | 4 ++-- src/a2a/client/client.py | 6 ++---- .../test_default_push_notification_support.py | 12 +++++------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 53fd38cdb..763f23fb5 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -104,10 +104,10 @@ def _apply_client_config(self, request: SendMessageRequest) -> None: request.configuration.return_immediately |= self._config.polling if ( not request.configuration.HasField('task_push_notification_config') - and self._config.push_notification_configs + and self._config.push_notification_config ): request.configuration.task_push_notification_config.CopyFrom( - self._config.push_notification_configs[0] + self._config.push_notification_config ) if ( not request.configuration.accepted_output_modes diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 1f94a4426..3fbf4f287 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -71,10 +71,8 @@ class ClientConfig: accepted_output_modes: list[str] = dataclasses.field(default_factory=list) """The set of accepted output modes for the client.""" - push_notification_configs: list[TaskPushNotificationConfig] = ( - dataclasses.field(default_factory=list) - ) - """Push notification configurations to use for every request.""" + push_notification_config: TaskPushNotificationConfig | None = None + """Push notification configuration to use for every request.""" class ClientCallContext(BaseModel): diff --git a/tests/e2e/push_notifications/test_default_push_notification_support.py b/tests/e2e/push_notifications/test_default_push_notification_support.py index 3d8d92481..35e4bbeb4 100644 --- a/tests/e2e/push_notifications/test_default_push_notification_support.py +++ b/tests/e2e/push_notifications/test_default_push_notification_support.py @@ -109,13 +109,11 @@ async def 
test_notification_triggering_with_in_message_config_e2e( a2a_client = ClientFactory( ClientConfig( supported_protocol_bindings=[TransportProtocol.HTTP_JSON], - push_notification_configs=[ - TaskPushNotificationConfig( - id='in-message-config', - url=f'{notifications_server}/notifications', - token=token, - ) - ], + push_notification_config=TaskPushNotificationConfig( + id='in-message-config', + url=f'{notifications_server}/notifications', + token=token, + ), ) ).create(minimal_agent_card(agent_server, [TransportProtocol.HTTP_JSON])) From a6695211d92d3dc476e18932c4a778a6ab1b9fbf Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 9 Apr 2026 16:02:08 +0200 Subject: [PATCH 18/67] test: add more scenarios to test_end_to_end (#954) Based on https://a2a-protocol.org/latest/specification/#312-send-streaming-message: 1. `Message` based flow. 2. Emit `Task` as a first event. # TODO: switches to the old request handler as there are known issues in the new one With a new handler failures are caused by 1. `Task` events are not streamed 2. 
`return_immediately` + direct message - V2 returns a phantom `Task` before the executor produces its `Message` --- tests/integration/test_end_to_end.py | 147 ++++++++++++++++++++------- 1 file changed, 112 insertions(+), 35 deletions(-) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 1043a7d72..d5387a047 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -5,17 +5,17 @@ import httpx import pytest import pytest_asyncio +from starlette.applications import Starlette from a2a.client.base_client import BaseClient from a2a.client.client import ClientConfig from a2a.client.client_factory import ClientFactory from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.routes.rest_routes import create_rest_routes -from starlette.applications import Starlette -from a2a.server.routes import create_jsonrpc_routes, create_agent_card_routes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager -from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler +from a2a.server.request_handlers import GrpcHandler, LegacyRequestHandler +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes +from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import TaskUpdater from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import ( @@ -37,7 +37,7 @@ TaskState, a2a_pb2_grpc, ) -from a2a.utils import TransportProtocol +from a2a.utils import TransportProtocol, new_task from a2a.utils.errors import InvalidParamsError @@ -69,7 +69,9 @@ def assert_events_match(events, expected_events): events, expected_events, strict=True ): assert event.HasField(expected_type) - if expected_type == 'status_update': + if expected_type == 'task': + assert event.task.status.state == expected_val + elif expected_type == 'status_update': assert 
event.status_update.status.state == expected_val elif expected_type == 'artifact_update': if expected_val is not None: @@ -83,26 +85,30 @@ def assert_events_match(events, expected_events): class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): - task_updater = TaskUpdater( - event_queue, - context.task_id, - context.context_id, - ) user_input = context.get_user_input() - is_input_required_resumption = ( - context.current_task is not None - and context.current_task.status.state - == TaskState.TASK_STATE_INPUT_REQUIRED - ) - - if not is_input_required_resumption: - await task_updater.update_status( - TaskState.TASK_STATE_SUBMITTED, - message=task_updater.new_agent_message( - [Part(text='task submitted')] - ), + # Direct message response (no task created). + if user_input.startswith('Message:'): + await event_queue.enqueue_event( + Message( + role=Role.ROLE_AGENT, + message_id='direct-reply-1', + parts=[Part(text=f'Direct reply to: {user_input}')], + ) ) + return + + # Task-based response. 
+ task = context.current_task + if not task: + task = new_task(context.message) + await event_queue.enqueue_event(task) + + task_updater = TaskUpdater( + event_queue, + task.id, + task.context_id, + ) await task_updater.update_status( TaskState.TASK_STATE_WORKING, @@ -168,7 +174,8 @@ class ClientSetup(NamedTuple): @pytest.fixture def base_e2e_setup(agent_card): task_store = InMemoryTaskStore() - handler = DefaultRequestHandler( + # TODO(https://github.com/a2aproject/a2a-python/issues/869): Use DefaultRequestHandler once it's fixed + handler = LegacyRequestHandler( agent_executor=MockAgentExecutor(), task_store=task_store, agent_card=agent_card, @@ -328,7 +335,6 @@ async def test_end_to_end_send_message_blocking(transport_setups): response.task.history, [ (Role.ROLE_USER, 'Run dummy agent!'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -386,20 +392,19 @@ async def test_end_to_end_send_message_streaming(transport_setups): assert_events_match( events, [ - ('status_update', TaskState.TASK_STATE_SUBMITTED), + ('task', TaskState.TASK_STATE_SUBMITTED), ('status_update', TaskState.TASK_STATE_WORKING), ('artifact_update', [('test-artifact', 'artifact content')]), ('status_update', TaskState.TASK_STATE_COMPLETED), ], ) - task_id = events[0].status_update.task_id + task_id = events[0].task.id task = await client.get_task(request=GetTaskRequest(id=task_id)) assert_history_matches( task.history, [ (Role.ROLE_USER, 'Run dummy agent!'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -423,7 +428,7 @@ async def test_end_to_end_get_task(transport_setups): ) ] response = events[0] - task_id = response.status_update.task_id + task_id = response.task.id get_request = GetTaskRequest(id=task_id) retrieved_task = await client.get_task(request=get_request) @@ -438,7 +443,6 @@ async def test_end_to_end_get_task(transport_setups): retrieved_task.history, [ (Role.ROLE_USER, 'Test Get Task'), - (Role.ROLE_AGENT, 'task 
submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -465,7 +469,7 @@ async def test_end_to_end_list_tasks(transport_setups): ) ) ) - expected_task_ids.append(response.status_update.task_id) + expected_task_ids.append(response.task.id) list_request = ListTasksRequest(page_size=page_size) @@ -514,13 +518,13 @@ async def test_end_to_end_input_required(transport_setups): assert_events_match( events, [ - ('status_update', TaskState.TASK_STATE_SUBMITTED), + ('task', TaskState.TASK_STATE_SUBMITTED), ('status_update', TaskState.TASK_STATE_WORKING), ('status_update', TaskState.TASK_STATE_INPUT_REQUIRED), ], ) - task_id = events[0].status_update.task_id + task_id = events[0].task.id task = await client.get_task(request=GetTaskRequest(id=task_id)) assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED @@ -528,7 +532,6 @@ async def test_end_to_end_input_required(transport_setups): task.history, [ (Role.ROLE_USER, 'Need input'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -572,7 +575,6 @@ async def test_end_to_end_input_required(transport_setups): task.history, [ (Role.ROLE_USER, 'Need input'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), (Role.ROLE_AGENT, 'Please provide input'), (Role.ROLE_USER, 'Here is the input'), @@ -681,3 +683,78 @@ async def test_end_to_end_subscribe_validation_error( assert {e['field'] for e in errors} == {'id'} await client.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'streaming', + [ + pytest.param(False, id='blocking'), + pytest.param(True, id='streaming'), + ], +) +async def test_end_to_end_direct_message(transport_setups, streaming): + """Test that an executor can return a direct Message without creating a Task.""" + client = transport_setups.client + client._config.streaming = streaming + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-direct', + parts=[Part(text='Message: Hello agent')], + ) + + events = [ + event + async for 
event in client.send_message( + request=SendMessageRequest(message=message_to_send) + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + assert not response.HasField('task') + assert_message_matches( + response.message, + Role.ROLE_AGENT, + 'Direct reply to: Message: Hello agent', + ) + + +@pytest.mark.asyncio +async def test_end_to_end_direct_message_return_immediately(transport_setups): + """Test that return_immediately still returns the Message for direct replies. + + When the executor responds with a direct Message, the response is + inherently immediate -- there is no async task to defer to. The client + should receive the Message regardless of the return_immediately flag. + """ + client = transport_setups.client + client._config.streaming = False + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-direct-return-immediately', + parts=[Part(text='Message: Quick question')], + ) + configuration = SendMessageConfiguration(return_immediately=True) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest( + message=message_to_send, configuration=configuration + ) + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + assert not response.HasField('task') + assert_message_matches( + response.message, + Role.ROLE_AGENT, + 'Direct reply to: Message: Quick question', + ) From ead75f95b852810f1837248b9e9ffc5092c2d463 Mon Sep 17 00:00:00 2001 From: "Agent2Agent (A2A) Bot" Date: Thu, 9 Apr 2026 10:20:16 -0500 Subject: [PATCH 19/67] chore(main): release 0.3.26 (#935) :robot: I have created a release *beep* *boop* --- ## [0.3.26](https://github.com/a2aproject/a2a-python/compare/v0.3.25...v0.3.26) (2026-04-09) ### Features * Add support for more Task Message and Artifact fields in the Vertex Task Store ([#908](https://github.com/a2aproject/a2a-python/issues/908)) 
([5e0dcd7](https://github.com/a2aproject/a2a-python/commit/5e0dcd798fcba16a8092b0b4c2d3d8026ca287de)) ### Bug Fixes * remove the use of deprecated types from VertexTaskStore ([#889](https://github.com/a2aproject/a2a-python/issues/889)) ([6d49122](https://github.com/a2aproject/a2a-python/commit/6d49122238a5e7d497c5d002792732446071dcb2)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --- CHANGELOG.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0be3872ad..01e3469b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## [0.3.26](https://github.com/a2aproject/a2a-python/compare/v0.3.25...v0.3.26) (2026-04-09) + + +### Features + +* Add support for more Task Message and Artifact fields in the Vertex Task Store ([#908](https://github.com/a2aproject/a2a-python/issues/908)) ([5e0dcd7](https://github.com/a2aproject/a2a-python/commit/5e0dcd798fcba16a8092b0b4c2d3d8026ca287de)) + + +### Bug Fixes + +* remove the use of deprecated types from VertexTaskStore ([#889](https://github.com/a2aproject/a2a-python/issues/889)) ([6d49122](https://github.com/a2aproject/a2a-python/commit/6d49122238a5e7d497c5d002792732446071dcb2)) + ## [0.3.25](https://github.com/a2aproject/a2a-python/compare/v0.3.24...v0.3.25) (2026-03-10) From 6c807d51c49ac294a6e3cbec34be101d4f91870d Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Thu, 9 Apr 2026 18:01:27 +0200 Subject: [PATCH 20/67] fix: fix JSONRPC error handling (#957) # Description Do one iteration to catch exceptions occurred beforehand to return an error instead of sending headers for SSE. 
--- .github/actions/spelling/allow.txt | 1 + src/a2a/server/routes/jsonrpc_dispatcher.py | 27 ++++++-- .../test_client_server_integration.py | 65 +++++++++++++++++++ 3 files changed, 86 insertions(+), 7 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 900974409..b3b2d56e8 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -45,6 +45,7 @@ dunders ES256 euo EUR +evt excinfo FastAPI fernet diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index d9ea4ff1a..60620081a 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -15,6 +15,7 @@ HTTP_EXTENSION_HEADER, ) from a2a.server.context import ServerCallContext +from a2a.server.events import Event from a2a.server.jsonrpc_models import ( InternalError, InvalidParamsError, @@ -376,20 +377,32 @@ async def _process_streaming_request( if stream is None: raise UnsupportedOperationError(message='Stream not supported') + # Eagerly fetch the first event to trigger validation/upfront errors + try: + first_event = await anext(stream) + except StopAsyncIteration: + first_event = None + async def _wrap_stream( - st: AsyncGenerator, + st: AsyncGenerator, first_evt: Event | None ) -> AsyncGenerator[dict[str, Any], None]: + def _map_event(evt: Event) -> dict[str, Any]: + stream_response = proto_utils.to_stream_response(evt) + result = MessageToDict( + stream_response, preserving_proto_field_name=False + ) + return JSONRPC20Response(result=result, _id=request_id).data + try: + if first_evt is not None: + yield _map_event(first_evt) + async for event in st: - stream_response = proto_utils.to_stream_response(event) - result = MessageToDict( - stream_response, preserving_proto_field_name=False - ) - yield JSONRPC20Response(result=result, _id=request_id).data + yield _map_event(event) except A2AError as e: yield build_error_response(request_id, 
e) - return _wrap_stream(stream) + return _wrap_stream(stream, first_event) async def _handle_send_message( self, request_obj: SendMessageRequest, context: ServerCallContext diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index c7fa29ea5..1ac8a7162 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1019,6 +1019,71 @@ async def mock_generator(*args, **kwargs): await client.close() +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'error_cls,handler_attr,client_method,request_params', + [ + pytest.param( + UnsupportedOperationError, + 'on_subscribe_to_task', + 'subscribe', + SubscribeToTaskRequest(id='some-id'), + id='subscribe', + ), + ], +) +async def test_server_rejects_stream_on_validation_error( + transport_setups, error_cls, handler_attr, client_method, request_params +) -> None: + """Verify that the server returns an error directly and doesn't open a stream on validation error.""" + client = transport_setups.client + handler = transport_setups.handler + + async def mock_generator(*args, **kwargs): + raise error_cls('Validation failed') + yield + + getattr(handler, handler_attr).side_effect = mock_generator + + transport = client._transport + + if isinstance(transport, (RestTransport, JsonRpcTransport)): + # Spy on httpx client to check response headers + original_send = transport.httpx_client.send + response_headers = {} + + async def mock_send(*args, **kwargs): + resp = await original_send(*args, **kwargs) + response_headers['Content-Type'] = resp.headers.get('Content-Type') + return resp + + transport.httpx_client.send = mock_send + + try: + with pytest.raises(error_cls): + async for _ in getattr(client, client_method)( + request=request_params + ): + pass + finally: + transport.httpx_client.send = original_send + + # Verify that the response content type was NOT text/event-stream + assert not 
response_headers.get('Content-Type', '').startswith( + 'text/event-stream' + ) + else: + # For gRPC, we just verify it raises the error + with pytest.raises(error_cls): + async for _ in getattr(client, client_method)( + request=request_params + ): + pass + + getattr(handler, handler_attr).side_effect = None + await client.close() + + @pytest.mark.asyncio @pytest.mark.parametrize( 'request_kwargs, expected_error_code', From 354fdfb68dd0c7894daaac885a06dfed0ab839c8 Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Fri, 10 Apr 2026 10:20:02 +0200 Subject: [PATCH 21/67] feat: Support Message-only simplified execution without creating Task (#956) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #869 🦕 --- src/a2a/server/agent_execution/active_task.py | 205 +++++++++++++----- .../server/agent_execution/agent_executor.py | 3 + .../default_request_handler_v2.py | 100 ++++----- src/a2a/server/tasks/task_manager.py | 26 ++- tests/integration/test_scenarios.py | 201 ++++++++++++----- .../agent_execution/test_active_task.py | 1 + .../test_default_request_handler_v2.py | 49 ----- 7 files changed, 370 insertions(+), 215 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index defdd5244..a3cd94cbe 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -5,7 +5,7 @@ import logging import uuid -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from a2a.server.agent_execution.context import RequestContext @@ -56,6 +56,12 @@ } +class _RequestStarted: + def __init__(self, request_id: uuid.UUID, request_context: RequestContext): + self.request_id = request_id + self.request_context = request_context + + class _RequestCompleted: def __init__(self, request_id: uuid.UUID): self.request_id = request_id @@ -199,25 +205,13 @@ async def start( logger.debug('TASK (start): %s', task) if 
task: + self._task_created.set() if task.status.state in TERMINAL_TASK_STATES: raise InvalidParamsError( message=f'Task {task.id} is in terminal state: {task.status.state}' ) - else: - if not create_task_if_missing: - raise TaskNotFoundError - - # New task. Create and save it so it's not "missing" if queried immediately - # (especially important for return_immediately=True) - if self._task_manager.context_id is None: - raise ValueError('Context ID is required for new tasks') - task = self._task_manager._init_task_obj( - self._task_id, - self._task_manager.context_id, - ) - await self._task_manager.save_task_event(task) - if self._push_sender: - await self._push_sender.send_notification(task.id, task) + elif not create_task_if_missing: + raise TaskNotFoundError except Exception: logger.debug( @@ -253,9 +247,9 @@ async def _run_producer(self) -> None: Runs as a detached asyncio.Task. Safe to cancel. """ logger.debug('Producer[%s]: Started', self._task_id) + request_context = None try: - active = True - while active: + while True: ( request_context, request_id, @@ -263,22 +257,11 @@ async def _run_producer(self) -> None: await self._request_lock.acquire() # TODO: Should we create task manager every time? 
self._task_manager._call_context = request_context.call_context + request_context.current_task = ( await self._task_manager.get_task() ) - message = request_context.message - if message: - request_context.current_task = ( - self._task_manager.update_with_message( - message, - cast('Task', request_context.current_task), - ) - ) - await self._task_manager.save_task_event( - request_context.current_task - ) - self._task_created.set() logger.debug( 'Producer[%s]: Executing agent task %s', self._task_id, @@ -286,6 +269,13 @@ async def _run_producer(self) -> None: ) try: + await self._event_queue_agent.enqueue_event( + cast( + 'Event', + _RequestStarted(request_id, request_context), + ) + ) + await self._agent_executor.execute( request_context, self._event_queue_agent ) @@ -293,32 +283,36 @@ async def _run_producer(self) -> None: 'Producer[%s]: Execution finished successfully', self._task_id, ) - except QueueShutDown: - logger.debug( - 'Producer[%s]: Request queue shut down', self._task_id - ) - raise - except asyncio.CancelledError: - logger.debug('Producer[%s]: Cancelled', self._task_id) - raise - except Exception as e: - logger.exception( - 'Producer[%s]: Execution failed', - self._task_id, - ) - async with self._lock: - await self._mark_task_as_failed(e) - active = False finally: logger.debug( 'Producer[%s]: Enqueuing request completed event', self._task_id, ) - # TODO: Hide from external consumers await self._event_queue_agent.enqueue_event( cast('Event', _RequestCompleted(request_id)) ) self._request_queue.task_done() + except asyncio.CancelledError: + logger.debug('Producer[%s]: Cancelled', self._task_id) + + except QueueShutDown: + logger.debug('Producer[%s]: Queue shut down', self._task_id) + + except Exception as e: + logger.exception( + 'Producer[%s]: Execution failed', + self._task_id, + ) + # Create task and mark as failed. 
+ if request_context: + await self._task_manager.ensure_task_id( + self._task_id, + request_context.context_id or '', + ) + self._task_created.set() + async with self._lock: + await self._mark_task_as_failed(e) + finally: self._request_queue.shutdown(immediate=True) await self._event_queue_agent.close(immediate=False) @@ -338,6 +332,10 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 `_is_finished`, unblocking all global subscribers and wait() calls. """ logger.debug('Consumer[%s]: Started', self._task_id) + task_mode = None + message_to_save = None + # TODO: Make helper methods + # TODO: Support Task enqueue try: try: try: @@ -347,6 +345,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 'Consumer[%s]: Waiting for event', self._task_id, ) + new_task = None event = await self._event_queue_agent.dequeue_event() logger.debug( 'Consumer[%s]: Dequeued event %s', @@ -361,17 +360,70 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 self._task_id, ) self._request_lock.release() + elif isinstance(event, _RequestStarted): + logger.debug( + 'Consumer[%s]: Request started', + self._task_id, + ) + message_to_save = event.request_context.message + elif isinstance(event, Message): + if task_mode is not None: + if task_mode: + logger.error( + 'Received Message() object in task mode.' + ) + else: + logger.error( + 'Multiple Message() objects received.' 
+ ) + task_mode = False logger.debug( 'Consumer[%s]: Setting result to Message: %s', self._task_id, event, ) else: + if task_mode is False: + logger.error( + 'Received %s in message mode.', + type(event).__name__, + ) + + if isinstance(event, Task): + new_task = event + await self._task_manager.save_task_event( + new_task + ) + # TODO: Avoid duplicated messages + else: + new_task = ( + await self._task_manager.ensure_task_id( + self._task_id, + event.context_id, + ) + ) + + if message_to_save is not None: + new_task = self._task_manager.update_with_message( + message_to_save, + new_task, + ) + await ( + self._task_manager.save_task_event( + new_task + ) + ) + message_to_save = None + + task_mode = True # Save structural events (like TaskStatusUpdate) to DB. - # TODO: Create task manager every time ? + self._task_manager.context_id = event.context_id - await self._task_manager.process(event) + if not isinstance(event, Task): + await self._task_manager.process(event) + + self._task_created.set() # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states new_task = await self._task_manager.get_task() @@ -379,6 +431,8 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 raise RuntimeError( f'Task {self.task_id} not found' ) + if isinstance(event, Task): + event = new_task is_interrupted = ( new_task.status.state in INTERRUPTED_TASK_STATES @@ -432,8 +486,23 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 self._task_id, event ) finally: + if new_task is not None: + new_task_copy = Task() + new_task_copy.CopyFrom(new_task) + new_task = new_task_copy + if isinstance(event, Task): + new_task_copy = Task() + new_task_copy.CopyFrom(event) + event = new_task_copy + + logger.debug( + 'Consumer[%s]: Enqueuing\nEvent: %s\nNew Task: %s\n', + self._task_id, + event, + new_task, + ) await self._event_queue_subscribers.enqueue_event( - event + cast('Any', (event, new_task)) ) self._event_queue_agent.task_done() except QueueShutDown: @@ -459,6 
+528,7 @@ async def subscribe( # noqa: PLR0912, PLR0915 *, request: RequestContext | None = None, include_initial_task: bool = False, + replace_status_update_with_task: bool = False, ) -> AsyncGenerator[Event, None]: """Creates a queue tap and yields events as they are produced. @@ -506,9 +576,25 @@ async def subscribe( # noqa: PLR0912, PLR0915 # Wait for next event or task completion try: - event = await asyncio.wait_for( + dequeued = await asyncio.wait_for( tapped_queue.dequeue_event(), timeout=0.1 ) + event, updated_task = cast('Any', dequeued) + logger.debug( + 'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n', + self._task_id, + event, + updated_task, + ) + if replace_status_update_with_task and isinstance( + event, TaskStatusUpdateEvent + ): + logger.debug( + 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s', + self._task_id, + updated_task, + ) + event = updated_task if self._exception: raise self._exception from None if isinstance(event, _RequestCompleted): @@ -522,6 +608,12 @@ async def subscribe( # noqa: PLR0912, PLR0915 ) return continue + elif isinstance(event, _RequestStarted): + logger.debug( + 'Subscriber[%s]: Request started', + self._task_id, + ) + continue except (asyncio.TimeoutError, TimeoutError): if self._is_finished.is_set(): if self._exception: @@ -545,7 +637,7 @@ async def subscribe( # noqa: PLR0912, PLR0915 # Evaluate if this was the last subscriber on a finished task. await self._maybe_cleanup() - async def cancel(self, call_context: ServerCallContext) -> Task | Message: + async def cancel(self, call_context: ServerCallContext) -> Task: """Cancels the running active task. Concurrency Guarantee: @@ -558,11 +650,11 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message: # TODO: Conflicts with call_context on the pending request. 
self._task_manager._call_context = call_context - task = await self.get_task() + task = await self._task_manager.get_task() request_context = RequestContext( call_context=call_context, task_id=self._task_id, - context_id=task.context_id, + context_id=task.context_id if task else None, task=task, ) @@ -591,7 +683,10 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message: ) await self._is_finished.wait() - return await self.get_task() + task = await self._task_manager.get_task() + if not task: + raise RuntimeError('Task should have been created') + return task async def _maybe_cleanup(self) -> None: """Triggers cleanup if task is finished and has no subscribers. diff --git a/src/a2a/server/agent_execution/agent_executor.py b/src/a2a/server/agent_execution/agent_executor.py index 764bef4b2..2da8ddfd7 100644 --- a/src/a2a/server/agent_execution/agent_executor.py +++ b/src/a2a/server/agent_execution/agent_executor.py @@ -34,6 +34,9 @@ async def execute( - Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc) - Explain if execute can wait for cancel and if cancel can wait for execute. - Explain behaviour of streaming / not-immediate when execute() returns in active state. + - Possible workflows: + - Enqueue a SINGLE Message object + - Enqueue TaskStatusUpdateEvent (TASK_STATE_SUBMITTED or TASK_STATE_REJECTED) and continue with TaskStatusUpdateEvent / TaskArtifactUpdateEvent. Args: context: The request context containing the message, task ID, etc. 
diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index ccc9cdd0e..1a8464687 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -242,63 +242,56 @@ async def on_message_send( # noqa: D102 active_task, request_context = await self._setup_active_task( params, context ) + task_id = cast('str', request_context.task_id) - if params.configuration and params.configuration.return_immediately: - await active_task.enqueue_request(request_context) - - task = await active_task.get_task() - if params.configuration: - task = apply_history_length(task, params.configuration) - return task + result: Message | Task | None = None - try: - result_states = TERMINAL_TASK_STATES | INTERRUPTED_TASK_STATES - - result = None - async for event in active_task.subscribe(request=request_context): - logger.debug( - 'Processing[%s] event [%s] %s', - request_context.task_id, - type(event).__name__, - event, - ) - if isinstance(event, Message) or ( - isinstance(event, Task) - and event.status.state in result_states - ): - result = event - break - if ( - isinstance(event, TaskStatusUpdateEvent) - and event.status.state in result_states - ): - result = await self.task_store.get(event.task_id, context) - break - - if result is None: + async for raw_event in active_task.subscribe( + request=request_context, + include_initial_task=False, + replace_status_update_with_task=True, + ): + event = raw_event + logger.debug( + 'Processing[%s] event [%s] %s', + params.message.task_id, + type(event).__name__, + event, + ) + if isinstance(event, TaskStatusUpdateEvent): + self._validate_task_id_match(task_id, event.task_id) + event = await active_task.get_task() logger.debug( - 'Missing result for task %s', request_context.task_id + 'Replaced TaskStatusUpdateEvent with Task: %s', event ) - result = await active_task.get_task() - 
logger.debug( - 'Processing[%s] result: %s', request_context.task_id, result - ) + if isinstance(event, Task) and ( + params.configuration.return_immediately + or event.status.state + in (TERMINAL_TASK_STATES | INTERRUPTED_TASK_STATES) + ): + self._validate_task_id_match(task_id, event.id) + result = event + break + + if isinstance(event, Message): + result = event + break - except Exception: - logger.exception('Agent execution failed') - raise + if result is None: + logger.debug('Missing result for task %s', request_context.task_id) + result = await active_task.get_task() if isinstance(result, Task): - self._validate_task_id_match( - cast('str', request_context.task_id), result.id - ) - if params.configuration: - result = apply_history_length(result, params.configuration) + result = apply_history_length(result, params.configuration) + logger.debug( + 'Returning result for task %s: %s', + request_context.task_id, + result, + ) return result - # TODO: Unify with on_message_send @validate_request_params @validate( lambda self: self._agent_card.capabilities.streaming, @@ -313,19 +306,20 @@ async def on_message_send_stream( # noqa: D102 params, context ) - include_initial_task = bool( - params.configuration and params.configuration.return_immediately - ) - task_id = cast('str', request_context.task_id) async for event in active_task.subscribe( - request=request_context, include_initial_task=include_initial_task + request=request_context, + include_initial_task=False, ): if isinstance(event, Task): self._validate_task_id_match(task_id, event.id) - logger.debug('Sending event [%s] %s', type(event).__name__, event) - yield event + yield apply_history_length(event, params.configuration) + else: + yield event + + if isinstance(event, Message): + break @validate_request_params @validate( diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index 905b11af3..143413d5b 100644 --- a/src/a2a/server/tasks/task_manager.py +++ 
b/src/a2a/server/tasks/task_manager.py @@ -147,13 +147,12 @@ async def save_task_event( await self._save_task(task) return task - async def ensure_task( - self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent - ) -> Task: + async def ensure_task_id(self, task_id: str, context_id: str) -> Task: """Ensures a Task object exists in memory, loading from store or creating new if needed. Args: - event: The task-related event triggering the need for a Task object. + task_id: The ID for the new task. + context_id: The context ID for the new task. Returns: An existing or newly created `Task` object. @@ -168,16 +167,29 @@ async def ensure_task( if not task: logger.info( 'Task not found or task_id not set. Creating new task for event (task_id: %s, context_id: %s).', - event.task_id, - event.context_id, + task_id, + context_id, ) # streaming agent did not previously stream task object. # Create a task object with the available information and persist the event - task = self._init_task_obj(event.task_id, event.context_id) + task = self._init_task_obj(task_id, context_id) await self._save_task(task) return task + async def ensure_task( + self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ) -> Task: + """Ensures a Task object exists in memory, loading from store or creating new if needed. + + Args: + event: The task-related event triggering the need for a Task object. + + Returns: + An existing or newly created `Task` object. + """ + return await self.ensure_task_id(event.task_id, event.context_id) + async def process(self, event: Event) -> Event: """Processes an event, updates the task state if applicable, stores it, and returns the event. 
diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index 1e2253430..4683dc3e9 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -16,11 +16,14 @@ from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager -from a2a.server.request_handlers import DefaultRequestHandlerV2, GrpcHandler +from a2a.server.request_handlers import ( + DefaultRequestHandlerV2, + GrpcHandler, + GrpcServerCallContextBuilder, +) from a2a.server.request_handlers.default_request_handler import ( LegacyRequestHandler, ) -from a2a.server.request_handlers import GrpcServerCallContextBuilder from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( @@ -701,24 +704,12 @@ async def send_message_and_get_first_response(): ) return await asyncio.wait_for(it.__anext__(), timeout=0.1) - if use_legacy: - # Legacy client hangs forever. - with pytest.raises(asyncio.TimeoutError): - await send_message_and_get_first_response() - else: - event = await send_message_and_get_first_response() - task = event.task - assert task.status.state == TaskState.TASK_STATE_SUBMITTED - (message,) = task.history - assert message.message_id == 'test-msg' + # First response should not be there yet. + with pytest.raises(asyncio.TimeoutError): + await send_message_and_get_first_response() tasks = (await client.list_tasks(ListTasksRequest())).tasks - if use_legacy: - # Legacy didn't create a task - assert len(tasks) == 0 - else: - (task,) = tasks - assert task.status.state == TaskState.TASK_STATE_SUBMITTED + assert len(tasks) == 0 # Scenario 17: Cancellation of a working task. 
@@ -1090,39 +1081,13 @@ async def cancel( ) states = [get_state(event) async for event in it] - if use_legacy: - if streaming: - assert states == [ - TaskState.TASK_STATE_WORKING, - TaskState.TASK_STATE_COMPLETED, - ] - else: - assert states == [TaskState.TASK_STATE_WORKING] - elif streaming: - assert states == [ - TaskState.TASK_STATE_SUBMITTED, - TaskState.TASK_STATE_WORKING, - TaskState.TASK_STATE_COMPLETED, - ] - else: - assert states == [TaskState.TASK_STATE_SUBMITTED] - - # Test blocking return. - it = client.send_message( - SendMessageRequest( - message=msg, - configuration=SendMessageConfiguration(return_immediately=False), - ) - ) - states = [get_state(event) async for event in it] - if streaming: assert states == [ TaskState.TASK_STATE_WORKING, TaskState.TASK_STATE_COMPLETED, ] else: - assert states == [TaskState.TASK_STATE_COMPLETED] + assert states == [TaskState.TASK_STATE_WORKING] # Scenario: Test TASK_STATE_INPUT_REQUIRED. @@ -1305,7 +1270,7 @@ async def cancel( @pytest.mark.timeout(5.0) @pytest.mark.asyncio @pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) -async def test_scenario_parallel_subscribe_attach_detach(use_legacy): +async def test_scenario_parallel_subscribe_attach_detach(use_legacy): # noqa: PLR0915 events = collections.defaultdict(asyncio.Event) class EmitAgent(AgentExecutor): @@ -1434,11 +1399,11 @@ async def collect(): await events['emitted_phase_4'].wait() def get_artifact_updates(evs): - txts = [] - for sr in evs: - if sr.HasField('artifact_update'): - txts.append([p.text for p in sr.artifact_update.artifact.parts]) - return txts + return [ + [p.text for p in sr.artifact_update.artifact.parts] + for sr in evs + if sr.HasField('artifact_update') + ] assert get_artifact_updates(await sub1_task) == [ ['artifact_1'], @@ -1459,3 +1424,137 @@ def get_artifact_updates(evs): ] monitor_task.cancel() + + +# Return message directly. 
+@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize( + 'return_immediately', + [False, True], + ids=['no_return_immediately', 'return_immediately'], +) +async def test_scenario_publish_message( + use_legacy, streaming, return_immediately +): + class MessageAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + Message( + task_id=context.task_id, + context_id=context.context_id, + message_id='msg-1', + role=Role.ROLE_AGENT, + parts=[Part(text='response text')], + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(MessageAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration( + return_immediately=return_immediately + ), + ) + ) + events = [event async for event in it] + + (event,) = events + assert event.HasField('message') + assert event.message.parts[0].text == 'response text' + + tasks = (await client.list_tasks(ListTasksRequest())).tasks + assert len(tasks) == 0 + + +# Scenario: Publish ArtifactUpdateEvent +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_publish_artifact(use_legacy, streaming): + class ArtifactAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + 
TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact( + artifact_id='art-1', parts=[Part(text='artifact data')] + ), + ) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ArtifactAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events = [event async for event in it] + + if streaming: + last_event = events[-1] + assert get_state(last_event) == TaskState.TASK_STATE_COMPLETED + + artifact_events = [e for e in events if e.HasField('artifact_update')] + assert len(artifact_events) > 0, ( + 'Bug: Streaming should return the artifact update event' + ) + assert ( + artifact_events[0].artifact_update.artifact.artifact_id == 'art-1' + ) + else: + last_event = events[-1] + assert last_event.HasField('task') + assert last_event.task.status.state == TaskState.TASK_STATE_COMPLETED + + assert len(last_event.task.artifacts) > 0, ( + 'Bug: Task should include the published artifact' + ) + assert last_event.task.artifacts[0].artifact_id == 'art-1' diff --git a/tests/server/agent_execution/test_active_task.py b/tests/server/agent_execution/test_active_task.py index d3cc95dc3..3a4a24ff6 100644 --- a/tests/server/agent_execution/test_active_task.py +++ b/tests/server/agent_execution/test_active_task.py @@ -1047,6 +1047,7 @@ async def execute_mock(req, q): assert events[0] == initial_task +@pytest.mark.timeout(1) @pytest.mark.asyncio async def 
test_active_task_subscribe_request_parameter(): agent_executor = Mock() diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index 605078201..d48b82461 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -1104,55 +1104,6 @@ async def test_on_message_send_limit_history(): assert task.history is not None and len(task.history) > 1 -@pytest.mark.asyncio -async def test_on_message_send_task_id_mismatch(): - mock_task_store = AsyncMock(spec=TaskStore) - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) - - context_task_id = 'context_task_id_1' - result_task_id = 'DIFFERENT_task_id_1' - - mock_request_context = MagicMock() - mock_request_context.task_id = context_task_id - mock_request_context_builder.build.return_value = mock_request_context - - request_handler = DefaultRequestHandlerV2( - agent_executor=mock_agent_executor, - task_store=mock_task_store, - request_context_builder=mock_request_context_builder, - agent_card=create_default_agent_card(), - ) - params = SendMessageRequest( - message=Message( - role=Role.ROLE_USER, - message_id='msg_id_mismatch', - parts=[Part(text='hello')], - ) - ) - - mock_active_task = MagicMock() - mismatched_task = create_sample_task(task_id=result_task_id) - mock_active_task.wait = AsyncMock(return_value=mismatched_task) - mock_active_task.start = AsyncMock() - mock_active_task.enqueue_request = AsyncMock() - mock_active_task.get_task = AsyncMock(return_value=mismatched_task) - with ( - patch.object( - request_handler._active_task_registry, - 'get_or_create', - return_value=mock_active_task, - ), - patch( - 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', - return_value=None, - ), - ): - with pytest.raises(InternalError) as exc_info: - await 
request_handler.on_message_send(params, context=MagicMock()) - assert 'Task ID mismatch' in exc_info.value.message - - @pytest.mark.asyncio async def test_on_message_send_stream_task_id_mismatch(): mock_task_store = AsyncMock(spec=TaskStore) From 62e5e59a30b11b9b493f7bf969aa13173ce51b9c Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Fri, 10 Apr 2026 13:17:38 +0200 Subject: [PATCH 22/67] feat: Simplify ActiveTask.subscribe() (#958) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Simplify ActiveTask.subscribe() and remove race condition between _is_finished and slow enqueue. Fixes #869 🦕 --- src/a2a/server/agent_execution/active_task.py | 90 +++++++++---------- 1 file changed, 42 insertions(+), 48 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index a3cd94cbe..71e38768f 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -511,12 +511,14 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 ) except Exception as e: logger.exception('Consumer[%s]: Failed', self._task_id) + # TODO: Make the task in database as failed. async with self._lock: await self._mark_task_as_failed(e) finally: # The consumer is dead. The ActiveTask is permanently finished. 
self._is_finished.set() self._request_queue.shutdown(immediate=True) + await self._event_queue_agent.close(immediate=True) logger.debug('Consumer[%s]: Finishing', self._task_id) await self._maybe_cleanup() @@ -574,53 +576,42 @@ async def subscribe( # noqa: PLR0912, PLR0915 if self._exception: raise self._exception - # Wait for next event or task completion - try: - dequeued = await asyncio.wait_for( - tapped_queue.dequeue_event(), timeout=0.1 - ) - event, updated_task = cast('Any', dequeued) + dequeued = await tapped_queue.dequeue_event() + event, updated_task = cast('Any', dequeued) + logger.debug( + 'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n', + self._task_id, + event, + updated_task, + ) + if replace_status_update_with_task and isinstance( + event, TaskStatusUpdateEvent + ): logger.debug( - 'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n', + 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s', self._task_id, - event, updated_task, ) - if replace_status_update_with_task and isinstance( - event, TaskStatusUpdateEvent + event = updated_task + if self._exception: + raise self._exception from None + if isinstance(event, _RequestCompleted): + if ( + request_id is not None + and event.request_id == request_id ): logger.debug( - 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s', + 'Subscriber[%s]: Request completed', self._task_id, - updated_task, ) - event = updated_task - if self._exception: - raise self._exception from None - if isinstance(event, _RequestCompleted): - if ( - request_id is not None - and event.request_id == request_id - ): - logger.debug( - 'Subscriber[%s]: Request completed', - self._task_id, - ) - return - continue - elif isinstance(event, _RequestStarted): - logger.debug( - 'Subscriber[%s]: Request started', - self._task_id, - ) - continue - except (asyncio.TimeoutError, TimeoutError): - if self._is_finished.is_set(): - if self._exception: - raise self._exception from None - break + return + continue + 
elif isinstance(event, _RequestStarted): + logger.debug( + 'Subscriber[%s]: Request started', + self._task_id, + ) continue - try: yield event finally: @@ -715,17 +706,20 @@ async def _mark_task_as_failed(self, exception: Exception) -> None: if self._exception is None: self._exception = exception if self._task_created.is_set(): - task = await self._task_manager.get_task() - if task is not None: - await self._event_queue_agent.enqueue_event( - TaskStatusUpdateEvent( - task_id=task.id, - context_id=task.context_id, - status=TaskStatus( - state=TaskState.TASK_STATE_FAILED, - ), + try: + task = await self._task_manager.get_task() + if task is not None: + await self._event_queue_agent.enqueue_event( + TaskStatusUpdateEvent( + task_id=task.id, + context_id=task.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_FAILED, + ), + ) ) - ) + except QueueShutDown: + pass async def get_task(self) -> Task: """Get task from db.""" From 12ce0179056db9d9ba2abdd559cb5a4bb5a20ddf Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Fri, 10 Apr 2026 14:14:34 +0200 Subject: [PATCH 23/67] feat: Support AgentExectuor enqueue of a Task object. (#960) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes Task object handling when using new DefaultRequestHandlerV2. 
Fixes #869🦕 --- src/a2a/server/agent_execution/active_task.py | 19 ++++- tests/integration/test_end_to_end.py | 5 +- tests/integration/test_scenarios.py | 84 +++++++++++++++++++ 3 files changed, 101 insertions(+), 7 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index 71e38768f..db7bb5146 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -391,11 +391,22 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 ) if isinstance(event, Task): - new_task = event - await self._task_manager.save_task_event( - new_task + existing_task = ( + await self._task_manager.get_task() ) - # TODO: Avoid duplicated messages + if existing_task: + logger.error( + 'Task %s already exists. Ignoring task replacement.', + self._task_id, + ) + else: + await ( + self._task_manager.save_task_event( + event + ) + ) + # Initial task should already contain the message. + message_to_save = None else: new_task = ( await self._task_manager.ensure_task_id( diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index d5387a047..58dce528d 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -13,7 +13,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager -from a2a.server.request_handlers import GrpcHandler, LegacyRequestHandler +from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import TaskUpdater @@ -174,8 +174,7 @@ class ClientSetup(NamedTuple): @pytest.fixture def base_e2e_setup(agent_card): task_store = InMemoryTaskStore() - # 
TODO(https://github.com/a2aproject/a2a-python/issues/869): Use DefaultRequestHandler once it's fixed - handler = LegacyRequestHandler( + handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=task_store, agent_card=agent_card, diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index 4683dc3e9..cee15bfcb 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -1558,3 +1558,87 @@ async def cancel( 'Bug: Task should include the published artifact' ) assert last_event.task.artifacts[0].artifact_id == 'art-1' + + +# Scenario: Enqueue Task twice +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_enqueue_task_twice(caplog, use_legacy, streaming): + class DoubleTaskAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task1 = Task( + id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_WORKING, + message=Message(parts=[Part(text='First task')]), + ), + ) + await event_queue.enqueue_event(task1) + + # This is undefined behavior, but it should not crash or hang. 
+ task2 = Task( + id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_WORKING, + message=Message(parts=[Part(text='Second task')]), + ), + ) + await event_queue.enqueue_event(task2) + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(DoubleTaskAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events = [event async for event in it] + + (final_task,) = (await client.list_tasks(ListTasksRequest())).tasks + + if use_legacy: + assert [part.text for part in final_task.history[0].parts] == [ + 'Second task' + ] + else: + assert [part.text for part in final_task.history[0].parts] == [ + 'First task' + ] + + # Validate that new version logs with error exactly once 'Ignoring task replacement' + error_logs = [ + record.message + for record in caplog.records + if record.levelname == 'ERROR' + and 'Ignoring task replacement' in record.message + ] + assert len(error_logs) == 1 From 6b5651102326ae4c7e8936c1109a0f09693c9034 Mon Sep 17 00:00:00 2001 From: "Agent2Agent (A2A) Bot" Date: Fri, 10 Apr 2026 07:21:13 -0500 Subject: [PATCH 24/67] chore(1.0-dev): release 1.0.0-alpha.1 (#861) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit :robot: I have created a release *beep* *boop* --- ## [1.0.0-alpha.1](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.0...v1.0.0-alpha.1) (2026-04-10) ### ⚠ BREAKING CHANGES * **client:** make 
ClientConfig.push_notification_config singular ([#955](https://github.com/a2aproject/a2a-python/issues/955)) * **client:** reorganize ClientFactory API ([#947](https://github.com/a2aproject/a2a-python/issues/947)) * **server:** add build_user function to DefaultContextBuilder to allow A2A user creation customization ([#925](https://github.com/a2aproject/a2a-python/issues/925)) * **client:** remove `ClientTaskManager` and `Consumers` from client ([#916](https://github.com/a2aproject/a2a-python/issues/916)) * **server:** migrate from Application wrappers to Starlette route-based endpoints for rest ([#892](https://github.com/a2aproject/a2a-python/issues/892)) * **server:** migrate from Application wrappers to Starlette route-based endpoints for jsonrpc ([#873](https://github.com/a2aproject/a2a-python/issues/873)) ### Features * A2A Version Header validation on server side. ([#865](https://github.com/a2aproject/a2a-python/issues/865)) ([b261ceb](https://github.com/a2aproject/a2a-python/commit/b261ceb98bf46cc1e479fcdace52fef8371c8e58)) * Add GetExtendedAgentCard Support to RequestHandlers ([#919](https://github.com/a2aproject/a2a-python/issues/919)) ([2159140](https://github.com/a2aproject/a2a-python/commit/2159140b1c24fe556a41accf97a6af7f54ec6701)) * Add support for more Task Message and Artifact fields in the Vertex Task Store ([#936](https://github.com/a2aproject/a2a-python/issues/936)) ([605fa49](https://github.com/a2aproject/a2a-python/commit/605fa4913ad23539a51a3ee1f5b9ca07f24e1d2d)) * Create EventQueue interface and make tap() async. 
([#914](https://github.com/a2aproject/a2a-python/issues/914)) ([9ccf99c](https://github.com/a2aproject/a2a-python/commit/9ccf99c63d4e556eadea064de6afa0b4fc4e19d6)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) * EventQueue - unify implementation between python versions ([#877](https://github.com/a2aproject/a2a-python/issues/877)) ([7437b88](https://github.com/a2aproject/a2a-python/commit/7437b88328fc71ed07e8e50f22a2eb0df4bf4201)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) * EventQueue is now a simple interface with single enqueue_event method. ([#944](https://github.com/a2aproject/a2a-python/issues/944)) ([f0e1d74](https://github.com/a2aproject/a2a-python/commit/f0e1d74802e78a4e9f4c22cbc85db104137e0cd2)) * Implementation of DefaultRequestHandlerV2 ([#933](https://github.com/a2aproject/a2a-python/issues/933)) ([462eb3c](https://github.com/a2aproject/a2a-python/commit/462eb3cb7b6070c258f5672aa3b0aa59e913037c)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) * InMemoryTaskStore creates a copy of Task by default to make it consistent with database task stores ([#887](https://github.com/a2aproject/a2a-python/issues/887)) ([8c65e84](https://github.com/a2aproject/a2a-python/commit/8c65e84fb844251ce1d8f04d26dbf465a89b9a29)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) * merge metadata of new and old artifact when append=True ([#945](https://github.com/a2aproject/a2a-python/issues/945)) ([cc094aa](https://github.com/a2aproject/a2a-python/commit/cc094aa51caba8107b63982e9b79256f7c2d331a)) * **server:** add async context manager support to EventQueue ([#743](https://github.com/a2aproject/a2a-python/issues/743)) ([f68b22f](https://github.com/a2aproject/a2a-python/commit/f68b22f0323ed4ff9267fabcf09c9d873baecc39)) * **server:** validate presence according to `google.api.field_behavior` annotations ([#870](https://github.com/a2aproject/a2a-python/issues/870)) 
([4586c3e](https://github.com/a2aproject/a2a-python/commit/4586c3ec0b507d64caa3ced72d68a34ec5b37a11)) * Simplify ActiveTask.subscribe() ([#958](https://github.com/a2aproject/a2a-python/issues/958)) ([62e5e59](https://github.com/a2aproject/a2a-python/commit/62e5e59a30b11b9b493f7bf969aa13173ce51b9c)) * Support AgentExectuor enqueue of a Task object. ([#960](https://github.com/a2aproject/a2a-python/issues/960)) ([12ce017](https://github.com/a2aproject/a2a-python/commit/12ce0179056db9d9ba2abdd559cb5a4bb5a20ddf)) * Support Message-only simplified execution without creating Task ([#956](https://github.com/a2aproject/a2a-python/issues/956)) ([354fdfb](https://github.com/a2aproject/a2a-python/commit/354fdfb68dd0c7894daaac885a06dfed0ab839c8)) * Unhandled exception in AgentExecutor marks task as failed ([#943](https://github.com/a2aproject/a2a-python/issues/943)) ([4fc6b54](https://github.com/a2aproject/a2a-python/commit/4fc6b54fd26cc83d810d81f923579a1cd4853b39)) ### Bug Fixes * Add `packaging` to base dependencies ([#897](https://github.com/a2aproject/a2a-python/issues/897)) ([7a9aec7](https://github.com/a2aproject/a2a-python/commit/7a9aec7779448faa85a828d1076bcc47cda7bdbb)) * **client:** do not mutate SendMessageRequest in BaseClient.send_message ([#949](https://github.com/a2aproject/a2a-python/issues/949)) ([94537c3](https://github.com/a2aproject/a2a-python/commit/94537c382be4160332279a44d83254feeb0b8037)) * fix `athrow()` RuntimeError on streaming responses ([#912](https://github.com/a2aproject/a2a-python/issues/912)) ([ca7edc3](https://github.com/a2aproject/a2a-python/commit/ca7edc3b670538ce0f051c49f2224173f186d3f4)) * fix docstrings related to `CallContextBuilder` args in constructors and make ServerCallContext mandatory in `compat` folder ([#907](https://github.com/a2aproject/a2a-python/issues/907)) ([9cade9b](https://github.com/a2aproject/a2a-python/commit/9cade9bdadfb94f2f857ec2dc302a2c402e7f0ea)) * fix error handling for gRPC and SSE streaming 
([#879](https://github.com/a2aproject/a2a-python/issues/879)) ([2b323d0](https://github.com/a2aproject/a2a-python/commit/2b323d0b191279fb5f091199aa30865299d5fcf2)) * fix JSONRPC error handling ([#957](https://github.com/a2aproject/a2a-python/issues/957)) ([6c807d5](https://github.com/a2aproject/a2a-python/commit/6c807d51c49ac294a6e3cbec34be101d4f91870d)) * fix REST error handling ([#893](https://github.com/a2aproject/a2a-python/issues/893)) ([405be3f](https://github.com/a2aproject/a2a-python/commit/405be3fa3ef8c60f730452b956879beeaecc5957)) * handle SSE errors occurred after stream started ([#894](https://github.com/a2aproject/a2a-python/issues/894)) ([3a68d8f](https://github.com/a2aproject/a2a-python/commit/3a68d8f916d96ae135748ee2b9b907f8dace4fa7)) * remove the use of deprecated types from VertexTaskStore ([#889](https://github.com/a2aproject/a2a-python/issues/889)) ([6d49122](https://github.com/a2aproject/a2a-python/commit/6d49122238a5e7d497c5d002792732446071dcb2)) * Remove unconditional SQLAlchemy dependency from SDK core ([#898](https://github.com/a2aproject/a2a-python/issues/898)) ([ab762f0](https://github.com/a2aproject/a2a-python/commit/ab762f0448911a9ac05b6e3fec0104615e0ec557)), closes [#883](https://github.com/a2aproject/a2a-python/issues/883) * remove unused import and request for FastAPI in pyproject ([#934](https://github.com/a2aproject/a2a-python/issues/934)) ([fe5de77](https://github.com/a2aproject/a2a-python/commit/fe5de77a1d457958fe14fec61b0d8aa41c5ec300)) * replace stale entry in a2a.types.__all__ with actual import name ([#902](https://github.com/a2aproject/a2a-python/issues/902)) ([05cd5e9](https://github.com/a2aproject/a2a-python/commit/05cd5e9b73b55d2863c58c13be0c7dd21d8124bb)) * wrong method name for ExtendedAgentCard endpoint in JsonRpc compat version ([#931](https://github.com/a2aproject/a2a-python/issues/931)) ([5d22186](https://github.com/a2aproject/a2a-python/commit/5d22186b8ee0f64b744512cdbe7ab6176fa97c60)) ### Documentation * add 
Database Migration Documentation ([#864](https://github.com/a2aproject/a2a-python/issues/864)) ([fd12dff](https://github.com/a2aproject/a2a-python/commit/fd12dffa3a7aa93816c762a155ed9b505086b924)) ### Miscellaneous Chores * release 1.0.0-alpha.1 ([a61f6d4](https://github.com/a2aproject/a2a-python/commit/a61f6d4e2e7ce1616a35c3a2ede64a4c9067048a)) ### Code Refactoring * **client:** make ClientConfig.push_notification_config singular ([#955](https://github.com/a2aproject/a2a-python/issues/955)) ([be4c5ff](https://github.com/a2aproject/a2a-python/commit/be4c5ff17a2f58e20d5d333a5e8e7bfcaa58c6c0)) * **client:** remove `ClientTaskManager` and `Consumers` from client ([#916](https://github.com/a2aproject/a2a-python/issues/916)) ([97058bb](https://github.com/a2aproject/a2a-python/commit/97058bb444ea663d77c3b62abcf2fd0c30a1a526)), closes [#734](https://github.com/a2aproject/a2a-python/issues/734) * **client:** reorganize ClientFactory API ([#947](https://github.com/a2aproject/a2a-python/issues/947)) ([01b3b2c](https://github.com/a2aproject/a2a-python/commit/01b3b2c0e196b0aab4f1f0dc22a95c09c7ee914d)) * **server:** add build_user function to DefaultContextBuilder to allow A2A user creation customization ([#925](https://github.com/a2aproject/a2a-python/issues/925)) ([2648c5e](https://github.com/a2aproject/a2a-python/commit/2648c5e50281ceb9795b10a726bd23670b363ae1)) * **server:** migrate from Application wrappers to Starlette route-based endpoints for jsonrpc ([#873](https://github.com/a2aproject/a2a-python/issues/873)) ([734d062](https://github.com/a2aproject/a2a-python/commit/734d0621dc6170d10d0cdf9c074e5ae28531fc71)) * **server:** migrate from Application wrappers to Starlette route-based endpoints for rest ([#892](https://github.com/a2aproject/a2a-python/issues/892)) ([4be2064](https://github.com/a2aproject/a2a-python/commit/4be2064b5d511e0b4617507ed0c376662688ebeb)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). 
See [documentation](https://github.com/googleapis/release-please#release-please). --- .release-please-manifest.json | 2 +- CHANGELOG.md | 68 +++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 575c8ef05..6415ed078 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1 +1 @@ -{".":"1.0.0-alpha.0"} +{".":"1.0.0-alpha.1"} diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e6162523..7e4715609 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,73 @@ # Changelog +## [1.0.0-alpha.1](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.0...v1.0.0-alpha.1) (2026-04-10) + + +### ⚠ BREAKING CHANGES + +* **client:** make ClientConfig.push_notification_config singular ([#955](https://github.com/a2aproject/a2a-python/issues/955)) +* **client:** reorganize ClientFactory API ([#947](https://github.com/a2aproject/a2a-python/issues/947)) +* **server:** add build_user function to DefaultContextBuilder to allow A2A user creation customization ([#925](https://github.com/a2aproject/a2a-python/issues/925)) +* **client:** remove `ClientTaskManager` and `Consumers` from client ([#916](https://github.com/a2aproject/a2a-python/issues/916)) +* **server:** migrate from Application wrappers to Starlette route-based endpoints for rest ([#892](https://github.com/a2aproject/a2a-python/issues/892)) +* **server:** migrate from Application wrappers to Starlette route-based endpoints for jsonrpc ([#873](https://github.com/a2aproject/a2a-python/issues/873)) + +### Features + +* A2A Version Header validation on server side. 
([#865](https://github.com/a2aproject/a2a-python/issues/865)) ([b261ceb](https://github.com/a2aproject/a2a-python/commit/b261ceb98bf46cc1e479fcdace52fef8371c8e58)) +* Add GetExtendedAgentCard Support to RequestHandlers ([#919](https://github.com/a2aproject/a2a-python/issues/919)) ([2159140](https://github.com/a2aproject/a2a-python/commit/2159140b1c24fe556a41accf97a6af7f54ec6701)) +* Add support for more Task Message and Artifact fields in the Vertex Task Store ([#908](https://github.com/a2aproject/a2a-python/issues/908)) ([5e0dcd7](https://github.com/a2aproject/a2a-python/commit/5e0dcd798fcba16a8092b0b4c2d3d8026ca287de)) +* Add support for more Task Message and Artifact fields in the Vertex Task Store ([#936](https://github.com/a2aproject/a2a-python/issues/936)) ([605fa49](https://github.com/a2aproject/a2a-python/commit/605fa4913ad23539a51a3ee1f5b9ca07f24e1d2d)) +* Create EventQueue interface and make tap() async. ([#914](https://github.com/a2aproject/a2a-python/issues/914)) ([9ccf99c](https://github.com/a2aproject/a2a-python/commit/9ccf99c63d4e556eadea064de6afa0b4fc4e19d6)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) +* EventQueue - unify implementation between python versions ([#877](https://github.com/a2aproject/a2a-python/issues/877)) ([7437b88](https://github.com/a2aproject/a2a-python/commit/7437b88328fc71ed07e8e50f22a2eb0df4bf4201)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) +* EventQueue is now a simple interface with single enqueue_event method. 
([#944](https://github.com/a2aproject/a2a-python/issues/944)) ([f0e1d74](https://github.com/a2aproject/a2a-python/commit/f0e1d74802e78a4e9f4c22cbc85db104137e0cd2)) +* Implementation of DefaultRequestHandlerV2 ([#933](https://github.com/a2aproject/a2a-python/issues/933)) ([462eb3c](https://github.com/a2aproject/a2a-python/commit/462eb3cb7b6070c258f5672aa3b0aa59e913037c)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) +* InMemoryTaskStore creates a copy of Task by default to make it consistent with database task stores ([#887](https://github.com/a2aproject/a2a-python/issues/887)) ([8c65e84](https://github.com/a2aproject/a2a-python/commit/8c65e84fb844251ce1d8f04d26dbf465a89b9a29)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) +* merge metadata of new and old artifact when append=True ([#945](https://github.com/a2aproject/a2a-python/issues/945)) ([cc094aa](https://github.com/a2aproject/a2a-python/commit/cc094aa51caba8107b63982e9b79256f7c2d331a)) +* **server:** add async context manager support to EventQueue ([#743](https://github.com/a2aproject/a2a-python/issues/743)) ([f68b22f](https://github.com/a2aproject/a2a-python/commit/f68b22f0323ed4ff9267fabcf09c9d873baecc39)) +* **server:** validate presence according to `google.api.field_behavior` annotations ([#870](https://github.com/a2aproject/a2a-python/issues/870)) ([4586c3e](https://github.com/a2aproject/a2a-python/commit/4586c3ec0b507d64caa3ced72d68a34ec5b37a11)) +* Simplify ActiveTask.subscribe() ([#958](https://github.com/a2aproject/a2a-python/issues/958)) ([62e5e59](https://github.com/a2aproject/a2a-python/commit/62e5e59a30b11b9b493f7bf969aa13173ce51b9c)) +* Support AgentExecutor enqueue of a Task object. 
([#960](https://github.com/a2aproject/a2a-python/issues/960)) ([12ce017](https://github.com/a2aproject/a2a-python/commit/12ce0179056db9d9ba2abdd559cb5a4bb5a20ddf)) +* Support Message-only simplified execution without creating Task ([#956](https://github.com/a2aproject/a2a-python/issues/956)) ([354fdfb](https://github.com/a2aproject/a2a-python/commit/354fdfb68dd0c7894daaac885a06dfed0ab839c8)) +* Unhandled exception in AgentExecutor marks task as failed ([#943](https://github.com/a2aproject/a2a-python/issues/943)) ([4fc6b54](https://github.com/a2aproject/a2a-python/commit/4fc6b54fd26cc83d810d81f923579a1cd4853b39)) + + +### Bug Fixes + +* Add `packaging` to base dependencies ([#897](https://github.com/a2aproject/a2a-python/issues/897)) ([7a9aec7](https://github.com/a2aproject/a2a-python/commit/7a9aec7779448faa85a828d1076bcc47cda7bdbb)) +* **client:** do not mutate SendMessageRequest in BaseClient.send_message ([#949](https://github.com/a2aproject/a2a-python/issues/949)) ([94537c3](https://github.com/a2aproject/a2a-python/commit/94537c382be4160332279a44d83254feeb0b8037)) +* fix `athrow()` RuntimeError on streaming responses ([#912](https://github.com/a2aproject/a2a-python/issues/912)) ([ca7edc3](https://github.com/a2aproject/a2a-python/commit/ca7edc3b670538ce0f051c49f2224173f186d3f4)) +* fix docstrings related to `CallContextBuilder` args in constructors and make ServerCallContext mandatory in `compat` folder ([#907](https://github.com/a2aproject/a2a-python/issues/907)) ([9cade9b](https://github.com/a2aproject/a2a-python/commit/9cade9bdadfb94f2f857ec2dc302a2c402e7f0ea)) +* fix error handling for gRPC and SSE streaming ([#879](https://github.com/a2aproject/a2a-python/issues/879)) ([2b323d0](https://github.com/a2aproject/a2a-python/commit/2b323d0b191279fb5f091199aa30865299d5fcf2)) +* fix JSONRPC error handling ([#957](https://github.com/a2aproject/a2a-python/issues/957)) ([6c807d5](https://github.com/a2aproject/a2a-python/commit/6c807d51c49ac294a6e3cbec34be101d4f91870d)) 
+* fix REST error handling ([#893](https://github.com/a2aproject/a2a-python/issues/893)) ([405be3f](https://github.com/a2aproject/a2a-python/commit/405be3fa3ef8c60f730452b956879beeaecc5957)) +* handle SSE errors occurred after stream started ([#894](https://github.com/a2aproject/a2a-python/issues/894)) ([3a68d8f](https://github.com/a2aproject/a2a-python/commit/3a68d8f916d96ae135748ee2b9b907f8dace4fa7)) +* remove the use of deprecated types from VertexTaskStore ([#889](https://github.com/a2aproject/a2a-python/issues/889)) ([6d49122](https://github.com/a2aproject/a2a-python/commit/6d49122238a5e7d497c5d002792732446071dcb2)) +* Remove unconditional SQLAlchemy dependency from SDK core ([#898](https://github.com/a2aproject/a2a-python/issues/898)) ([ab762f0](https://github.com/a2aproject/a2a-python/commit/ab762f0448911a9ac05b6e3fec0104615e0ec557)), closes [#883](https://github.com/a2aproject/a2a-python/issues/883) +* remove unused import and request for FastAPI in pyproject ([#934](https://github.com/a2aproject/a2a-python/issues/934)) ([fe5de77](https://github.com/a2aproject/a2a-python/commit/fe5de77a1d457958fe14fec61b0d8aa41c5ec300)) +* replace stale entry in a2a.types.__all__ with actual import name ([#902](https://github.com/a2aproject/a2a-python/issues/902)) ([05cd5e9](https://github.com/a2aproject/a2a-python/commit/05cd5e9b73b55d2863c58c13be0c7dd21d8124bb)) +* wrong method name for ExtendedAgentCard endpoint in JsonRpc compat version ([#931](https://github.com/a2aproject/a2a-python/issues/931)) ([5d22186](https://github.com/a2aproject/a2a-python/commit/5d22186b8ee0f64b744512cdbe7ab6176fa97c60)) + + +### Documentation + +* add Database Migration Documentation ([#864](https://github.com/a2aproject/a2a-python/issues/864)) ([fd12dff](https://github.com/a2aproject/a2a-python/commit/fd12dffa3a7aa93816c762a155ed9b505086b924)) + + +### Miscellaneous Chores + +* release 1.0.0-alpha.1 
([a61f6d4](https://github.com/a2aproject/a2a-python/commit/a61f6d4e2e7ce1616a35c3a2ede64a4c9067048a)) + + +### Code Refactoring + +* **client:** make ClientConfig.push_notification_config singular ([#955](https://github.com/a2aproject/a2a-python/issues/955)) ([be4c5ff](https://github.com/a2aproject/a2a-python/commit/be4c5ff17a2f58e20d5d333a5e8e7bfcaa58c6c0)) +* **client:** remove `ClientTaskManager` and `Consumers` from client ([#916](https://github.com/a2aproject/a2a-python/issues/916)) ([97058bb](https://github.com/a2aproject/a2a-python/commit/97058bb444ea663d77c3b62abcf2fd0c30a1a526)), closes [#734](https://github.com/a2aproject/a2a-python/issues/734) +* **client:** reorganize ClientFactory API ([#947](https://github.com/a2aproject/a2a-python/issues/947)) ([01b3b2c](https://github.com/a2aproject/a2a-python/commit/01b3b2c0e196b0aab4f1f0dc22a95c09c7ee914d)) +* **server:** add build_user function to DefaultContextBuilder to allow A2A user creation customization ([#925](https://github.com/a2aproject/a2a-python/issues/925)) ([2648c5e](https://github.com/a2aproject/a2a-python/commit/2648c5e50281ceb9795b10a726bd23670b363ae1)) +* **server:** migrate from Application wrappers to Starlette route-based endpoints for jsonrpc ([#873](https://github.com/a2aproject/a2a-python/issues/873)) ([734d062](https://github.com/a2aproject/a2a-python/commit/734d0621dc6170d10d0cdf9c074e5ae28531fc71)) +* **server:** migrate from Application wrappers to Starlette route-based endpoints for rest ([#892](https://github.com/a2aproject/a2a-python/issues/892)) ([4be2064](https://github.com/a2aproject/a2a-python/commit/4be2064b5d511e0b4617507ed0c376662688ebeb)) + ## 1.0.0-alpha.0 (2026-03-17) From 57a6624d94b104ec2064a82f2334ea41caeff1ae Mon Sep 17 00:00:00 2001 From: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Date: Tue, 14 Apr 2026 16:42:25 +0200 Subject: [PATCH 25/67] fix(samples): emit `Task(TASK_STATE_SUBMITTED)` as first streaming event (#970) MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Updates the sample agent and CLI to correctly follow the A2A streaming event contract, where the first event in a stream must be a `Task` or a `Message` object in `TASK_STATE_SUBMITTED` state. # Changes **hello_world_agent.py** `SampleAgentExecutor.execute()` now enqueues a `Task(TASK_STATE_SUBMITTED)` object as its very first event, before any TaskUpdater calls. The initial user message is included in the Task's history field, since the consumer sets message_to_save = None upon receiving a Task event (expecting the task to carry the message itself). **cli.py** Updates `_handle_stream` to match the new event contract: The first event is now expected to be a `Message` or a `Task` (not an (event, task) tuple), and its id is used to initialize `current_task_id`. **README.md** Adds a `README.md` for the samples. # Tested ``` uv run samples/cli.py Connecting to http://127.0.0.1:41241 (preferred transport: Any) ✓ Agent Card Found: Name: Sample Agent Picked Transport: JsonRpcTransport Connected! Send a message or type /quit to exit. You: hi Task [state=TASK_STATE_SUBMITTED] TaskStatusUpdate [state=TASK_STATE_WORKING]: Processing your question... TaskArtifactUpdate [name=response]: Hello World! Nice to meet you! TaskStatusUpdate [state=TASK_STATE_COMPLETED]: --- Task Finished --- You: /quit ``` Related issue #965 🦕 --- samples/README.md | 58 +++++++++++++++++++++++++++++ samples/cli.py | 71 ++++++++++++++++++++---------------- samples/hello_world_agent.py | 12 ++++++ 3 files changed, 110 insertions(+), 31 deletions(-) create mode 100644 samples/README.md diff --git a/samples/README.md b/samples/README.md new file mode 100644 index 000000000..e61264955 --- /dev/null +++ b/samples/README.md @@ -0,0 +1,58 @@ +# A2A Python SDK — Samples + +This directory contains runnable examples demonstrating how to build and interact with an A2A-compliant agent using the Python SDK. 
+ +## Contents + +| File | Role | Description | +|---|---|---| +| `hello_world_agent.py` | **Server** | A2A agent server | +| `cli.py` | **Client** | Interactive terminal client | + +The samples are designed to work together out of the box: the agent listens on `http://127.0.0.1:41241`, which is the default URL used by the client. +--- + +## `hello_world_agent.py` — Agent Server + +Implements an A2A agent that responds to simple greeting messages (e.g., "hello", "how are you", "bye") with text replies, simulating a 1-second processing delay. + +Demonstrates: +- Subclassing `AgentExecutor` and implementing `execute()` / `cancel()` +- Publishing streaming status updates and artifacts via `TaskUpdater` +- Exposing all three transports in both protocol versions (v1.0 and v0.3 compat) simultaneously: + - **JSON-RPC** (v1.0 and v0.3) at `http://127.0.0.1:41241/a2a/jsonrpc` + - **HTTP+JSON (REST)** (v1.0 and v0.3) at `http://127.0.0.1:41241/a2a/rest` + - **gRPC v1.0** on port `50051` + - **gRPC v0.3 (compat)** on port `50052` +- Serving the agent card at `http://127.0.0.1:41241/.well-known/agent-card.json` + +**Run:** + +```bash +uv run python samples/hello_world_agent.py +``` + +--- + +## `cli.py` — Client + +An interactive terminal client with full visibility into the streaming event flow. Each `TaskStatusUpdate` and `TaskArtifactUpdate` event is printed as it arrives. + +Features: +- Transport selection via `--transport` flag (`JSONRPC`, `HTTP+JSON`, `GRPC`) +- Session management (`context_id` persisted across messages, `task_id` per task) +- Graceful error handling for HTTP and gRPC failures + +**Run:** + +```bash +# Connect to the local hello_world_agent (default): +uv run python samples/cli.py + +# Connect to a different URL, using gRPC: +uv run python samples/cli.py --url http://192.168.1.10:41241 --transport GRPC +``` + +Then type a message like `hello` and press Enter. + +Type `/quit` or `/exit` to stop, or press `Ctrl+C`. 
diff --git a/samples/cli.py b/samples/cli.py index 8515fd5a9..7f72b5494 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -13,42 +13,51 @@ from a2a.types import Message, Part, Role, SendMessageRequest, TaskState -async def _handle_stream( +async def _handle_stream( # noqa: PLR0912 stream: Any, current_task_id: str | None ) -> str | None: - async for event, task in stream: - if not task: - continue + async for event in stream: + if event.HasField('message'): + print('Message:', end=' ') + for part in event.message.parts: + if part.text: + print(part.text, end=' ') + print() + return None + if not current_task_id: - current_task_id = task.id - - if event: - if event.HasField('status_update'): - state_name = TaskState.Name(event.status_update.status.state) - print(f'TaskStatusUpdate [state={state_name}]:', end=' ') - if event.status_update.status.HasField('message'): - for part in event.status_update.status.message.parts: - if part.text: - print(part.text, end=' ') - print() - - if ( - event.status_update.status.state - == TaskState.TASK_STATE_COMPLETED - ): - current_task_id = None - print('--- Task Completed ---') - - elif event.HasField('artifact_update'): - print( - f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', - end=' ', - ) - for part in event.artifact_update.artifact.parts: + if event.HasField('task'): + current_task_id = event.task.id + print('--- Task Started ---') + print(f'Task [state={TaskState.Name(event.task.status.state)}]') + else: + raise ValueError(f'Unexpected first event: {event}') + + if event.HasField('status_update'): + state_name = TaskState.Name(event.status_update.status.state) + print(f'TaskStatusUpdate [state={state_name}]:', end=' ') + if event.status_update.status.HasField('message'): + for part in event.status_update.status.message.parts: if part.text: print(part.text, end=' ') - print() - + print() + if state_name in ( + 'TASK_STATE_COMPLETED', + 'TASK_STATE_FAILED', + 'TASK_STATE_CANCELED', + 
'TASK_STATE_REJECTED', + ): + current_task_id = None + print('--- Task Finished ---') + elif event.HasField('artifact_update'): + print( + f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', + end=' ', + ) + for part in event.artifact_update.artifact.parts: + if part.text: + print(part.text, end=' ') + print() return current_task_id diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 8db34dc03..4c9e6f18a 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -27,6 +27,9 @@ AgentProvider, AgentSkill, Part, + Task, + TaskState, + TaskStatus, a2a_pb2_grpc, ) @@ -75,6 +78,15 @@ async def execute( context_id, ) + await event_queue.enqueue_event( + Task( + id=task_id, + context_id=context_id, + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + history=[user_message], + ) + ) + updater = TaskUpdater( event_queue=event_queue, task_id=task_id, From 0bfec889db2f500410b0214cb826a8872bd9bcec Mon Sep 17 00:00:00 2001 From: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Date: Tue, 14 Apr 2026 16:48:56 +0200 Subject: [PATCH 26/67] docs: update GEMINI setup (#968) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Updated gemini setup. It is applicable to other models as well. Example of `ai_learnings.md` entries: ``` ## 2026-04-13 — Using extend() on a str return value **Mistake**: Used `response_parts.extend(get_message_text(...))` and `response_parts.extend(get_artifact_text(...))` where both functions return `str`. `list.extend()` on a string iterates its characters, producing `['H', 'e', 'l', 'l', 'o']` instead of `['Hello']`. **Root cause**: Assumed the utility functions returned an iterable of strings rather than a single string, and did not check their signatures or run the tests before presenting the code. **Rule**: Before calling `extend()`, verify the return type of the expression. If it returns `str`, use `append()`. 
Run the tests after any change to aggregation logic. --- ## 2026-04-13 — Assuming streaming event order without verifying across transports **Mistake**: Added a strict check that the first streaming event must be a `Task` or `Message`, raising a `RuntimeError` otherwise. Also used `event.WhichOneof("event")` which fails because `StreamResponse` has no oneof named "event" — its fields are independent message fields. The REST transport sends a `status_update` as its first event, not a `Task`, so the guard rejected valid responses. **Root cause**: Assumed spec wording ("first event should be a Task") held across all transport implementations without testing it. Did not check the `StreamResponse` proto definition before calling `WhichOneof`. **Rule**: Before adding ordering assumptions about streaming events, verify the behaviour against every transport (JSONRPC, HTTP+JSON, GRPC). Before calling `WhichOneof`, confirm the oneof name exists in the proto. --- ## 2026-04-13 — Assuming extras without checking dev dependencies **Mistake**: Told the user that `http-server` and `grpc` extras needed to be specified explicitly in the samples README prerequisites. **Root cause**: Looked at the SDK's optional extras list and reasoned from imports in the sample files, without checking whether the dev dependency group already covered them. The dev group includes `a2a-sdk[all]`, so a plain `uv sync` installs everything. Checking the actual installed environment with one command would have revealed this immediately. **Rule**: Before writing installation instructions, verify what is already provided by the project's dev dependencies (`uv sync` with no flags). Do not recommend extra flags unless confirmed they are absent from the dev group. --- ## 2026-04-13 — Proposing unverified code **Mistake**: Proposed `_GRPC_ERROR = None` as a way to make `grpc` optional in an `except` clause. `None` is not a valid exception type in Python; the code would have crashed at runtime. 
**Root cause**: The fix was reasoned about at a high level ("set it to None when grpc is absent") without tracing through whether Python actually accepts `None` in an `except` tuple. No verification step was performed before presenting it to the user. **Rule**: Before presenting any code change, trace through its execution explicitly. For `except` clauses specifically: every element in the tuple must be an exception class, never `None` or any other non-exception value. --- ## 2026-04-14 — Race condition when reading state from DB in stream **Mistake**: Used `active_task.get_task()` in `on_message_send_stream` to fetch the task state for the initial response. This caused a race condition where `get_task()` returned a task state that was ahead of the stream events, leading to test failures. **Root cause**: Assumed `get_task()` would return the state corresponding to the event being processed, overlooking that the consumer loop runs independently and may have already processed subsequent events and updated the DB. **Rule**: When processing a stream of events, do not rely on reading the current state from a shared store (like DB) to represent the state at the time of a specific event. Use state snapshots passed with the event if available. ``` --- .gitignore | 1 + GEMINI.md | 17 +++++++++++++++++ docs/ai/ai_learnings.md | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+) create mode 100644 docs/ai/ai_learnings.md diff --git a/.gitignore b/.gitignore index a0903bd35..bc3689e5a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ coverage.xml spec.json docker-compose.yaml .geminiignore +docs/ai/ai_learnings.md # ITK Integration Test Artifacts itk/a2a-samples/ diff --git a/GEMINI.md b/GEMINI.md index 59ef64713..b801bd47d 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -23,3 +23,20 @@ 1. **Required Reading**: You MUST read the contents of @./docs/ai/coding_conventions.md and @./docs/ai/mandatory_checks.md at the very beginning of EVERY coding task. 2. 
**Initial Checklist**: Every `task.md` you create MUST include a section for **Mandatory Checks** from @./docs/ai/mandatory_checks.md. 3. **Verification Requirement**: You MUST run all mandatory checks before declaring any task finished. + +## 5. Mistake Reflection Protocol + +When you realise you have made a mistake — whether caught by the user, +by a tool, or by your own reasoning — you MUST: + +1. **Acknowledge the mistake explicitly** and explain what went wrong. +2. **Reflect on the root cause**: was it a missing check, a false + assumption, skipped verification, or a gap in the workflow? +3. **Immediately append a new entry to @./docs/ai/ai_learnings.md** + following the format defined in that file. This is not optional and + does not require user confirmation. Do it before continuing. Update user + about the changes to the workflow in the current chat. + +The goal is to treat every mistake as a signal that the workflow is +incomplete, and to improve it in place so the same mistake cannot +happen again. diff --git a/docs/ai/ai_learnings.md b/docs/ai/ai_learnings.md new file mode 100644 index 000000000..9e9a37a9f --- /dev/null +++ b/docs/ai/ai_learnings.md @@ -0,0 +1,19 @@ +> [!NOTE] for Users: +> This document is meant to be read by an AI assistant (Gemini) in order to +> learn from its mistakes and improve its behavior on this project. Use +> its findings to improve GEMINI.md setup. + +# AI Learnings + +A living record of mistakes made during this project and the rules +derived from them. Every entry must follow the format below. + +--- + +## Entry format + +**Mistake**: What went wrong. +**Root cause**: Why it happened. +**Rule**: The concrete rule added to prevent recurrence. 
+ +--- From 3468180ac7396d453d99ce3e74cdd7f5a0afb5ab Mon Sep 17 00:00:00 2001 From: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Date: Wed, 15 Apr 2026 15:08:24 +0200 Subject: [PATCH 27/67] feat(utils): add `display_agent_card()` utility for human-readable AgentCard inspection (#972) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Adds a `display_agent_card(card)` utility function to `a2a.utils` that prints a structured, human-readable summary of an `AgentCard` proto to stdout. ## Motivation The current proto text format is complete but difficult to read at a glance: name: "Sample Agent" supported_interfaces { url: "http://127.0.0.1:41241/a2a/jsonrpc" protocol_binding: "JSONRPC" protocol_version: "1.0" } ... At least four workarounds exist across `a2a-samples` for printing card contents. This provides a single, simple solution. ## Changes - `src/a2a/utils/agent_card.py` — new file with `display_agent_card(card: AgentCard) -> None` - `src/a2a/utils/__init__.py` — exports `display_agent_card` - `tests/utils/test_agent_card_display.py` — 5 unit tests including a full golden test - `samples/cli.py` — utilize the new display function ## Example output from `sample/cli.py` ``` uv run samples/cli.py Connecting to http://127.0.0.1:41241 (preferred transport: Any) ✓ Agent Card Found: ==================================================== AgentCard ==================================================== --- General --- Name : Sample Agent Description : A sample agent to test the stream functionality. 
Version : 1.0.0 Provider : A2A Samples (https://example.com) --- Interfaces --- [0] 127.0.0.1:50051 (GRPC 1.0) [1] 127.0.0.1:50052 (GRPC 0.3) [2] http://127.0.0.1:41241/a2a/jsonrpc (JSONRPC 1.0) [3] http://127.0.0.1:41241/a2a/jsonrpc (JSONRPC 0.3) [4] http://127.0.0.1:41241/a2a/rest (HTTP+JSON 1.0) [5] http://127.0.0.1:41241/a2a/rest (HTTP+JSON 0.3) --- Capabilities --- Streaming : True Push notifications : False Extended agent card : False --- I/O Modes --- Input : text Output : text, task-status --- Skills --- ---------------------------------------------------- ID : sample_agent Name : Sample Agent Description : Say hi. Tags : sample Example : hi ==================================================== Picked Transport: JsonRpcTransport ``` ## Notes - No breaking changes. Existing call sites are unaffected. - Optional fields (`documentation_url`, `icon_url`, `provider`) are shown only when set. - Closes #961 Fixes #961 🦕 --- samples/cli.py | 35 +++-- src/a2a/utils/__init__.py | 2 + src/a2a/utils/agent_card.py | 76 ++++++++++ tests/utils/test_agent_card_display.py | 194 +++++++++++++++++++++++++ 4 files changed, 289 insertions(+), 18 deletions(-) create mode 100644 src/a2a/utils/agent_card.py create mode 100644 tests/utils/test_agent_card_display.py diff --git a/samples/cli.py b/samples/cli.py index 7f72b5494..54b68388f 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -11,18 +11,16 @@ from a2a.client import A2ACardResolver, ClientConfig, create_client from a2a.types import Message, Part, Role, SendMessageRequest, TaskState +from a2a.utils import get_artifact_text, get_message_text +from a2a.utils.agent_card import display_agent_card -async def _handle_stream( # noqa: PLR0912 +async def _handle_stream( stream: Any, current_task_id: str | None ) -> str | None: async for event in stream: if event.HasField('message'): - print('Message:', end=' ') - for part in event.message.parts: - if part.text: - print(part.text, end=' ') - print() + print('Message:', 
get_message_text(event.message, delimiter=' ')) return None if not current_task_id: @@ -35,12 +33,15 @@ async def _handle_stream( # noqa: PLR0912 if event.HasField('status_update'): state_name = TaskState.Name(event.status_update.status.state) - print(f'TaskStatusUpdate [state={state_name}]:', end=' ') - if event.status_update.status.HasField('message'): - for part in event.status_update.status.message.parts: - if part.text: - print(part.text, end=' ') - print() + message_text = ( + ': ' + + get_message_text( + event.status_update.status.message, delimiter=' ' + ) + if event.status_update.status.HasField('message') + else '' + ) + print(f'TaskStatusUpdate [state={state_name}]{message_text}') if state_name in ( 'TASK_STATE_COMPLETED', 'TASK_STATE_FAILED', @@ -52,12 +53,10 @@ async def _handle_stream( # noqa: PLR0912 elif event.HasField('artifact_update'): print( f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', - end=' ', + get_artifact_text( + event.artifact_update.artifact, delimiter=' ' + ), ) - for part in event.artifact_update.artifact.parts: - if part.text: - print(part.text, end=' ') - print() return current_task_id @@ -86,7 +85,7 @@ async def main() -> None: resolver = A2ACardResolver(httpx_client, args.url) card = await resolver.get_agent_card() print('\n✓ Agent Card Found:') - print(f' Name: {card.name}') + display_agent_card(card) client = await create_client(card, client_config=config) diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index a502bfb62..1efed5794 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -1,6 +1,7 @@ """Utility functions for the A2A Python SDK.""" from a2a.utils import proto_utils +from a2a.utils.agent_card import display_agent_card from a2a.utils.artifact import ( get_artifact_text, new_artifact, @@ -44,6 +45,7 @@ 'build_text_artifact', 'completed_task', 'create_task_obj', + 'display_agent_card', 'get_artifact_text', 'get_data_parts', 'get_file_parts', diff --git 
a/src/a2a/utils/agent_card.py b/src/a2a/utils/agent_card.py new file mode 100644 index 000000000..0962e67fb --- /dev/null +++ b/src/a2a/utils/agent_card.py @@ -0,0 +1,76 @@ +"""Utility functions for inspecting AgentCard instances.""" + +from a2a.types.a2a_pb2 import AgentCard + + +def display_agent_card(card: AgentCard) -> None: + """Print a human-readable summary of an AgentCard to stdout. + + Args: + card: The AgentCard proto message to display. + """ + width = 52 + sep = '=' * width + thin = '-' * width + + lines: list[str] = [sep, 'AgentCard'.center(width), sep] + + lines += [ + '--- General ---', + f'Name : {card.name}', + f'Description : {card.description}', + f'Version : {card.version}', + ] + if card.documentation_url: + lines.append(f'Docs URL : {card.documentation_url}') + if card.icon_url: + lines.append(f'Icon URL : {card.icon_url}') + if card.HasField('provider'): + url_suffix = f' ({card.provider.url})' if card.provider.url else '' + lines.append(f'Provider : {card.provider.organization}{url_suffix}') + + lines += ['', '--- Interfaces ---'] + for i, iface in enumerate(card.supported_interfaces): + binding = f'{iface.protocol_binding} {iface.protocol_version}'.strip() + parts = [ + p + for p in [binding, f'tenant={iface.tenant}' if iface.tenant else ''] + if p + ] + suffix = f' ({", ".join(parts)})' if parts else '' + line = f' [{i}] {iface.url}{suffix}' + lines.append(line) + + lines += [ + '', + '--- Capabilities ---', + f'Streaming : {card.capabilities.streaming}', + f'Push notifications : {card.capabilities.push_notifications}', + f'Extended agent card : {card.capabilities.extended_agent_card}', + ] + + lines += [ + '', + '--- I/O Modes ---', + f'Input : {", ".join(card.default_input_modes) or "(none)"}', + f'Output : {", ".join(card.default_output_modes) or "(none)"}', + ] + + lines += ['', '--- Skills ---'] + if card.skills: + for skill in card.skills: + lines += [ + thin, + f' ID : {skill.id}', + f' Name : {skill.name}', + f' Description : 
{skill.description}', + f' Tags : {", ".join(skill.tags) or "(none)"}', + ] + if skill.examples: + for ex in skill.examples: + lines.append(f' Example : {ex}') + else: + lines.append(' (none)') + + lines.append(sep) + print('\n'.join(lines)) diff --git a/tests/utils/test_agent_card_display.py b/tests/utils/test_agent_card_display.py new file mode 100644 index 000000000..93dc1aad4 --- /dev/null +++ b/tests/utils/test_agent_card_display.py @@ -0,0 +1,194 @@ +"""Tests for display_agent_card utility.""" + +import pytest + +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + AgentInterface, + AgentProvider, + AgentSkill, +) +from a2a.utils.agent_card import display_agent_card + + +@pytest.fixture +def full_agent_card() -> AgentCard: + return AgentCard( + name='Sample Agent', + description='A sample agent.', + version='1.0.0', + documentation_url='https://docs.example.com', + icon_url='https://example.com/icon.png', + provider=AgentProvider( + organization='Example Org', url='https://example.com' + ), + supported_interfaces=[ + AgentInterface( + url='http://localhost:9999/a2a/jsonrpc', + protocol_binding='JSONRPC', + protocol_version='1.0', + ), + AgentInterface( + url='http://localhost:9999/a2a/rest', + protocol_binding='HTTP+JSON', + protocol_version='1.0', + tenant='tenant-a', + ), + ], + capabilities=AgentCapabilities( + streaming=True, + push_notifications=False, + extended_agent_card=True, + ), + default_input_modes=['text'], + default_output_modes=['text', 'task-status'], + skills=[ + AgentSkill( + id='skill-1', + name='My Skill', + description='Does something useful.', + tags=['foo', 'bar'], + examples=['Do the thing', 'Another example'], + ), + AgentSkill( + id='skill-2', + name='Other Skill', + description='Does something else.', + tags=['baz'], + ), + ], + ) + + +class TestDisplayAgentCard: + def test_full_card_output( + self, full_agent_card: AgentCard, capsys: pytest.CaptureFixture[str] + ) -> None: + """Golden test: exact output for a 
fully-populated card.""" + display_agent_card(full_agent_card) + assert capsys.readouterr().out == ( + '====================================================\n' + ' AgentCard \n' + '====================================================\n' + '--- General ---\n' + 'Name : Sample Agent\n' + 'Description : A sample agent.\n' + 'Version : 1.0.0\n' + 'Docs URL : https://docs.example.com\n' + 'Icon URL : https://example.com/icon.png\n' + 'Provider : Example Org (https://example.com)\n' + '\n' + '--- Interfaces ---\n' + ' [0] http://localhost:9999/a2a/jsonrpc (JSONRPC 1.0)\n' + ' [1] http://localhost:9999/a2a/rest (HTTP+JSON 1.0, tenant=tenant-a)\n' + '\n' + '--- Capabilities ---\n' + 'Streaming : True\n' + 'Push notifications : False\n' + 'Extended agent card : True\n' + '\n' + '--- I/O Modes ---\n' + 'Input : text\n' + 'Output : text, task-status\n' + '\n' + '--- Skills ---\n' + '----------------------------------------------------\n' + ' ID : skill-1\n' + ' Name : My Skill\n' + ' Description : Does something useful.\n' + ' Tags : foo, bar\n' + ' Example : Do the thing\n' + ' Example : Another example\n' + '----------------------------------------------------\n' + ' ID : skill-2\n' + ' Name : Other Skill\n' + ' Description : Does something else.\n' + ' Tags : baz\n' + '====================================================\n' + ) + + def test_empty_card_output( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """Golden test: exact output for a card with only default/empty fields. + + An empty supported_interfaces section signals a malformed card — + the bare header with no entries is intentional and visible to the user. 
+ """ + display_agent_card(AgentCard()) + assert capsys.readouterr().out == ( + '====================================================\n' + ' AgentCard \n' + '====================================================\n' + '--- General ---\n' + 'Name : \n' + 'Description : \n' + 'Version : \n' + '\n' + '--- Interfaces ---\n' + '\n' + '--- Capabilities ---\n' + 'Streaming : False\n' + 'Push notifications : False\n' + 'Extended agent card : False\n' + '\n' + '--- I/O Modes ---\n' + 'Input : (none)\n' + 'Output : (none)\n' + '\n' + '--- Skills ---\n' + ' (none)\n' + '====================================================\n' + ) + + def test_interface_without_protocol_version_has_no_trailing_space( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """No trailing space in the binding field when protocol_version is not set.""" + card = AgentCard( + supported_interfaces=[ + AgentInterface( + url='127.0.0.1:50051', + protocol_binding='GRPC', + ) + ] + ) + display_agent_card(card) + assert ' [0] 127.0.0.1:50051 (GRPC)' in capsys.readouterr().out + + def test_interface_without_binding_or_version_has_no_parentheses( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """No parentheses when neither protocol_binding nor protocol_version are set.""" + card = AgentCard( + supported_interfaces=[AgentInterface(url='127.0.0.1:50051')] + ) + display_agent_card(card) + assert ' [0] 127.0.0.1:50051\n' in capsys.readouterr().out + + def test_provider_with_url( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """Provider shows organization and URL in parentheses when both are set.""" + card = AgentCard( + provider=AgentProvider( + organization='Example Org', + url='https://example.com', + ), + ) + display_agent_card(card) + assert ( + 'Provider : Example Org (https://example.com)' + in capsys.readouterr().out + ) + + def test_provider_without_url_has_no_empty_parentheses( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """No empty parentheses when provider URL is 
not set.""" + card = AgentCard(provider=AgentProvider(organization='Example Org')) + display_agent_card(card) + out = capsys.readouterr().out + assert 'Provider : Example Org' in out + assert '()' not in out From b44fa8f9d2180d56d8c95eeff8a2e9260447db36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Feh=C3=A9r?= Date: Thu, 16 Apr 2026 09:57:51 +0200 Subject: [PATCH 28/67] fix: Don't generate empty metadata change events in VertexTaskStore (#962) For #751 --- src/a2a/contrib/tasks/vertex_task_store.py | 7 ++- tests/contrib/tasks/test_vertex_task_store.py | 47 +++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/a2a/contrib/tasks/vertex_task_store.py b/src/a2a/contrib/tasks/vertex_task_store.py index 5ba9147f5..91f514af8 100644 --- a/src/a2a/contrib/tasks/vertex_task_store.py +++ b/src/a2a/contrib/tasks/vertex_task_store.py @@ -109,7 +109,12 @@ def _get_status_details_change_event( def _get_metadata_change_event( self, previous_task: Task, task: Task, event_sequence_number: int ) -> vertexai_types.TaskEvent | None: - if task.metadata != previous_task.metadata: + # We generate metadata change events if the metadata was changed. + # We don't generate events if the metadata was changed from + # one empty value to another, e.g. {} to None. 
+ if task.metadata != previous_task.metadata and ( + task.metadata or previous_task.metadata + ): return vertexai_types.TaskEvent( event_data=vertexai_types.TaskEventData( metadata_change=vertexai_types.TaskMetadataChange( diff --git a/tests/contrib/tasks/test_vertex_task_store.py b/tests/contrib/tasks/test_vertex_task_store.py index ed99c09bb..e7d31f435 100644 --- a/tests/contrib/tasks/test_vertex_task_store.py +++ b/tests/contrib/tasks/test_vertex_task_store.py @@ -508,6 +508,53 @@ async def test_metadata_field_mapping( assert retrieved_none.metadata == {} +@pytest.mark.asyncio +async def test_metadata_empty_transitions( + vertex_store: VertexTaskStore, +) -> None: + """Test that updating metadata between {} and None does not generate events.""" + task_id = 'task-metadata-empty-test' + + # Step 1: Create task with metadata={} + task = Task( + id=task_id, + context_id='session-meta-empty', + status=TaskStatus(state=TaskState.submitted), + kind='task', + metadata={}, + ) + await vertex_store.save(task) + + full_name = f'{vertex_store._agent_engine_resource_id}/a2aTasks/{task_id}' + + # Get initial event sequence number + stored_task_before = ( + await vertex_store._client.aio.agent_engines.a2a_tasks.get(full_name) + ) + initial_seq = stored_task_before.next_event_sequence_number + + # Step 2: Update metadata to None + updated_task = task.model_copy(deep=True) + updated_task.metadata = None + await vertex_store.save(updated_task) + + # Step 3: Update back to {} + task_back = updated_task.model_copy(deep=True) + task_back.metadata = {} + await vertex_store.save(task_back) + + # Verify that retrieved task still has {} (due to mapping) + retrieved = await vertex_store.get(task_id) + assert retrieved is not None + assert retrieved.metadata == {} + + # Verify that next_event_sequence_number did NOT increase (no events generated) + stored_task_after = ( + await vertex_store._client.aio.agent_engines.a2a_tasks.get(full_name) + ) + assert 
stored_task_after.next_event_sequence_number == initial_seq + + @pytest.mark.asyncio async def test_update_task_status_details( vertex_store: VertexTaskStore, From b58b03ef58bd806db3accbe6dca8fc444a43bc18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Feh=C3=A9r?= Date: Thu, 16 Apr 2026 14:48:23 +0200 Subject: [PATCH 29/67] fix: Don't generate empty metadata change events in VertexTaskStore (#974) For #802 --- src/a2a/contrib/tasks/vertex_task_store.py | 7 ++- tests/contrib/tasks/test_vertex_task_store.py | 52 +++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/a2a/contrib/tasks/vertex_task_store.py b/src/a2a/contrib/tasks/vertex_task_store.py index 0457694e4..602d5c6fd 100644 --- a/src/a2a/contrib/tasks/vertex_task_store.py +++ b/src/a2a/contrib/tasks/vertex_task_store.py @@ -116,7 +116,12 @@ def _get_metadata_change_event( task: CompatTask, event_sequence_number: int, ) -> vertexai_types.TaskEvent | None: - if task.metadata != previous_task.metadata: + # We generate metadata change events if the metadata was changed. + # We don't generate events if the metadata was changed from + # one empty value to another, e.g. {} to None. 
+ if task.metadata != previous_task.metadata and ( + task.metadata or previous_task.metadata + ): return vertexai_types.TaskEvent( event_data=vertexai_types.TaskEventData( metadata_change=vertexai_types.TaskMetadataChange( diff --git a/tests/contrib/tasks/test_vertex_task_store.py b/tests/contrib/tasks/test_vertex_task_store.py index 4be8cd4e6..c77493022 100644 --- a/tests/contrib/tasks/test_vertex_task_store.py +++ b/tests/contrib/tasks/test_vertex_task_store.py @@ -534,6 +534,58 @@ async def test_metadata_field_mapping( assert retrieved_none.metadata == {} +@pytest.mark.asyncio +async def test_metadata_empty_transitions( + vertex_store: VertexTaskStore, +) -> None: + """Test that updating metadata between {} and None does not generate events.""" + task_id = 'task-metadata-empty-test' + + # Step 1: Create task with metadata={} + task = Task( + id=task_id, + context_id='session-meta-empty', + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + metadata={}, + ) + await vertex_store.save(task, ServerCallContext()) + + full_name = f'{vertex_store._agent_engine_resource_id}/a2aTasks/{task_id}' + + # Get initial event sequence number + stored_task_before = ( + await vertex_store._client.aio.agent_engines.a2a_tasks.get( + name=full_name + ) + ) + initial_seq = stored_task_before.next_event_sequence_number + + # Step 2: Update metadata to None + updated_task = Task() + updated_task.CopyFrom(task) + updated_task.metadata.Clear() + await vertex_store.save(updated_task, ServerCallContext()) + + # Step 3: Update back to {} + task_back = Task() + task_back.CopyFrom(updated_task) + task_back.metadata = {} + await vertex_store.save(task_back, ServerCallContext()) + + # Verify that retrieved task still has {} (due to mapping) + retrieved = await vertex_store.get(task_id, ServerCallContext()) + assert retrieved is not None + assert retrieved.metadata == {} + + # Verify that next_event_sequence_number did NOT increase (no events generated) + stored_task_after = ( + await 
vertex_store._client.aio.agent_engines.a2a_tasks.get( + name=full_name + ) + ) + assert stored_task_after.next_event_sequence_number == initial_seq + + @pytest.mark.asyncio async def test_update_task_status_details( vertex_store: VertexTaskStore, From d667e4fa55e99225eb3c02e009b426a3bc2d449d Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Fri, 17 Apr 2026 08:14:50 +0200 Subject: [PATCH 30/67] docs: AgentExecutor interface documentation (#976) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #869 🦕 --- .../server/agent_execution/agent_executor.py | 55 +++++++--- tests/integration/test_scenarios.py | 103 ++++++++++++++++++ 2 files changed, 140 insertions(+), 18 deletions(-) diff --git a/src/a2a/server/agent_execution/agent_executor.py b/src/a2a/server/agent_execution/agent_executor.py index 2da8ddfd7..1c3866047 100644 --- a/src/a2a/server/agent_execution/agent_executor.py +++ b/src/a2a/server/agent_execution/agent_executor.py @@ -23,20 +23,43 @@ async def execute( return once the agent's execution for this request is complete or yields control (e.g., enters an input-required state). - TODO: Document request lifecycle and AgentExecutor responsibilities: - - Should not close the event_queue. - - Guarantee single execution per request (no concurrent execution). - - Throwing exception will result in TaskState.TASK_STATE_ERROR (CHECK!) - - Once call is completed it should not access context or event_queue - - Before completing the call it SHOULD update task status to terminal or interrupted state. - - Explain AUTH_REQUIRED workflow. - - Explain INPUT_REQUIRED workflow. - - Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc) - - Explain if execute can wait for cancel and if cancel can wait for execute. - - Explain behaviour of streaming / not-immediate when execute() returns in active state. 
- - Possible workflows: - - Enqueue a SINGLE Message object - - Enqueue TaskStatusUpdateEvent (TASK_STATE_SUBMITTED or TASK_STATE_REJECTED) and continue with TaskStatusUpdateEvent / TaskArtifactUpdateEvent. + Request Lifecycle & AgentExecutor Responsibilities: + - **Concurrency**: The framework guarantees single execution per request; + `execute()` will not be called concurrently for the same request context. + - **Exception Handling**: Unhandled exceptions raised by `execute()` will be + caught by the framework and result in the task transitioning to + `TaskState.TASK_STATE_ERROR`. + - **Post-Completion**: Once `execute()` completes (returns or raises), the + executor must not access the `context` or `event_queue` anymore. + - **Terminal States**: Before completing the call normally, the executor + SHOULD publish a `TaskStatusUpdateEvent` to transition the task to a + terminal state (e.g., `TASK_STATE_COMPLETED`) or an interrupted state + (`TASK_STATE_INPUT_REQUIRED` or `TASK_STATE_AUTH_REQUIRED`). + - **Interrupted Workflows**: + - `TASK_STATE_INPUT_REQUIRED`: The executor publishes a `TaskStatusUpdateEvent` with + `TaskState.TASK_STATE_INPUT_REQUIRED` and returns to yield control. + The request will resume once user input is provided. + - `TASK_STATE_AUTH_REQUIRED`: There are in-band and out-of-band auth models. + In both scenarios, the agent publishes a `TaskStatusUpdateEvent` with + `TaskState.TASK_STATE_AUTH_REQUIRED`. + - In-band: The agent should return from `execute()`. The framework will + call `execute()` again once the user response is received. + - Out-of-band: The agent should not return from `execute()`. It should wait + for the out-of-band auth provider to complete the authentication and then + continue execution. + + - **Cancellation Workflow**: When a cancellation request is received, the + async task running `execute()` is cancelled (raising an `asyncio.CancelledError`), + and `cancel()` is explicitly called by the framework. 
+ + Allowed Workflows: + - Immediate response: Enqueue a SINGLE `Message` object. + - Asynchronous/Long-running: Enqueue a `Task` object, perform work, and emit + multiple `TaskStatusUpdateEvent` / `TaskArtifactUpdateEvent` objects over time. + + Note that the framework waits with response to the send_message request with + `return_immediately=True` parameter until the first event (Message or Task) + is enqueued by AgentExecutor. Args: context: The request context containing the message, task ID, etc. @@ -53,10 +76,6 @@ async def cancel( in the context and publish a `TaskStatusUpdateEvent` with state `TaskState.TASK_STATE_CANCELED` to the `event_queue`. - TODO: Document cancelation workflow. - - What if TaskState.TASK_STATE_CANCELED is not set by cancel() ? - - How it can interact with execute() ? - Args: context: The request context containing the task ID to cancel. event_queue: The queue to publish the cancellation status update to. diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index cee15bfcb..c50622e5c 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -113,6 +113,22 @@ def agent_card(): ) +def get_task_id(event): + if event.HasField('task'): + return event.task.id + if event.HasField('status_update'): + return event.status_update.task_id + assert False, f'Event {event} has no task_id' + + +def get_task_context_id(event): + if event.HasField('task'): + return event.task.context_id + if event.HasField('status_update'): + return event.status_update.context_id + assert False, f'Event {event} has no context_id' + + def get_state(event): if event.HasField('task'): return event.task.status.state @@ -1265,6 +1281,93 @@ async def cancel( ) +# Scenario: Auth required and in channel unblocking +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], 
ids=['blocking', 'streaming'] +) +async def test_scenario_auth_required_in_channel(use_legacy, streaming): + class AuthAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + message = context.message + if message and message.parts and message.parts[0].text == 'start': + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_AUTH_REQUIRED + ), + ) + ) + elif ( + message + and message.parts + and message.parts[0].text == 'credentials' + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + else: + raise ValueError(f'Unexpected message {message}') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(AuthAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg1 = Message( + message_id='msg-start', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + events1 = [event async for event in it] + assert [get_state(event) for event in events1] == [ + TaskState.TASK_STATE_AUTH_REQUIRED, + ] + task_id = get_task_id(events1[0]) + context_id = get_task_context_id(events1[0]) + + # Now send another message with credentials + msg2 = Message( + task_id=task_id, + context_id=context_id, + message_id='msg-creds', + role=Role.ROLE_USER, + parts=[Part(text='credentials')], + ) + + it2 = client.send_message( + SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + assert [get_state(event) async for event in it2] == [ + TaskState.TASK_STATE_COMPLETED, + ] + + 
# Scenario: Parallel subscribe attach detach # Migrated from: test_parallel_subscribe_attach_detach in test_handler_comparison @pytest.mark.timeout(5.0) From 186335925f16c3430f72577cff78e40cfa151eda Mon Sep 17 00:00:00 2001 From: kdziedzic70 Date: Fri, 17 Apr 2026 10:44:51 +0200 Subject: [PATCH 31/67] test: improved itk logging (#977) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description New version of itk https://github.com/a2aproject/a2a-samples/releases/tag/itk-v.015-alpha improves log readability for debugging by splitting the logs of individual tested agents into separate files if the `ITK_LOG_LEVEL` environment variable is set to "DEBUG" This PR integrates the change into python's sdk CI and updates the instruction on how to set debugging mode for tests Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. 
- [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes # 🦕 Co-authored-by: Krzysztof Dziedzic Co-authored-by: Ivan Shymko --- .github/workflows/itk.yaml | 2 +- .gitignore | 1 + itk/README.md | 21 ++++++++++++++++++++- itk/main.py | 7 +++++-- itk/run_itk.sh | 15 +++++++++++++++ 5 files changed, 42 insertions(+), 4 deletions(-) diff --git a/.github/workflows/itk.yaml b/.github/workflows/itk.yaml index 3a2c58143..f846e2d7c 100644 --- a/.github/workflows/itk.yaml +++ b/.github/workflows/itk.yaml @@ -28,4 +28,4 @@ jobs: run: bash run_itk.sh working-directory: itk env: - A2A_SAMPLES_REVISION: itk-v.0.11-alpha + A2A_SAMPLES_REVISION: itk-v.015-alpha diff --git a/.gitignore b/.gitignore index bc3689e5a..14bccd39b 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ docs/ai/ai_learnings.md itk/a2a-samples/ itk/pyproto/ itk/instruction.proto +itk/logs/ diff --git a/itk/README.md b/itk/README.md index 63ec68fad..eaa5f254a 100644 --- a/itk/README.md +++ b/itk/README.md @@ -36,7 +36,7 @@ You must set the `A2A_SAMPLES_REVISION` environment variable to specify which re Example: ```bash -export A2A_SAMPLES_REVISION=itk-v.0.11-alpha +export A2A_SAMPLES_REVISION=itk-v.015-alpha ``` ### 2. Execute Tests @@ -52,3 +52,22 @@ The script will: - Checkout the specified revision. - Build the ITK service Docker image. - Run the tests and output results. + +## Debugging + +To enable debug logging and persist logs for inspection: + +1. Set the `ITK_LOG_LEVEL` environment variable to `DEBUG`: + ```bash + export ITK_LOG_LEVEL=DEBUG + ``` +2. Run the test script: + ```bash + ./run_itk.sh + ``` + +When run in `DEBUG` mode: +- The `logs/` directory will be created in this directory (if it doesn't exist). +- The `logs/` directory will be mounted to the container. +- The test execution will produce detailed logs in `logs/` (e.g., `agent_current.log`). 
+- The `logs/` directory will **not** be removed during cleanup. diff --git a/itk/main.py b/itk/main.py index 7be7a5a20..5ce062fac 100644 --- a/itk/main.py +++ b/itk/main.py @@ -2,6 +2,7 @@ import asyncio import base64 import logging +import os import uuid import grpc @@ -36,7 +37,8 @@ from a2a.utils import TransportProtocol -logging.basicConfig(level=logging.INFO) +log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper() +logging.basicConfig(level=log_level) logger = logging.getLogger(__name__) @@ -352,8 +354,9 @@ async def main_async(http_port: int, grpc_port: int) -> None: grpc_port, ) + uvicorn_log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').lower() config = uvicorn.Config( - app, host='127.0.0.1', port=http_port, log_level='info' + app, host='127.0.0.1', port=http_port, log_level=uvicorn_log_level ) uvicorn_server = uvicorn.Server(config) diff --git a/itk/run_itk.sh b/itk/run_itk.sh index 80e96f9c2..2d9371c14 100755 --- a/itk/run_itk.sh +++ b/itk/run_itk.sh @@ -1,6 +1,9 @@ #!/bin/bash set -ex +# Set default log level +export ITK_LOG_LEVEL="${ITK_LOG_LEVEL:-INFO}" + # Initialize default exit code RESULT=1 @@ -63,9 +66,21 @@ ITK_DIR=$(pwd) # Stop existing container if any docker rm -f itk-service || true +# Create logs directory if debug +if [ "${ITK_LOG_LEVEL^^}" = "DEBUG" ]; then + mkdir -p "$ITK_DIR/logs" +fi + +DOCKER_MOUNT_LOGS="" +if [ "${ITK_LOG_LEVEL^^}" = "DEBUG" ]; then + DOCKER_MOUNT_LOGS="-v $ITK_DIR/logs:/app/logs" +fi + docker run -d --name itk-service \ -v "$A2A_PYTHON_ROOT:/app/agents/repo" \ -v "$ITK_DIR:/app/agents/repo/itk" \ + $DOCKER_MOUNT_LOGS \ + -e ITK_LOG_LEVEL="$ITK_LOG_LEVEL" \ -p 8000:8000 \ itk_service From f922ff683bac8ff8e7a495c4b02e03e86125d467 Mon Sep 17 00:00:00 2001 From: kdziedzic70 Date: Fri, 17 Apr 2026 11:46:27 +0200 Subject: [PATCH 32/67] test: force itk agent to create task before updating the status (#980) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Thank 
you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [ ] Make your Pull Request title in the specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [ ] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [ ] Appropriate docs were updated (if necessary) Fixes # 🦕 Co-authored-by: Krzysztof Dziedzic --- .github/workflows/itk.yaml | 2 +- itk/README.md | 3 ++- itk/main.py | 13 ++++++++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/.github/workflows/itk.yaml b/.github/workflows/itk.yaml index f846e2d7c..ab272d0e3 100644 --- a/.github/workflows/itk.yaml +++ b/.github/workflows/itk.yaml @@ -28,4 +28,4 @@ jobs: run: bash run_itk.sh working-directory: itk env: - A2A_SAMPLES_REVISION: itk-v.015-alpha + A2A_SAMPLES_REVISION: itk-v.016-alpha diff --git a/itk/README.md b/itk/README.md index eaa5f254a..9a82d0469 100644 --- a/itk/README.md +++ b/itk/README.md @@ -35,7 +35,7 @@ podman system migrate You must set the `A2A_SAMPLES_REVISION` environment variable to specify which revision of the `a2a-samples` repository to use for testing. This can be a branch name, tag, or commit hash. Example: -```bash +``` export A2A_SAMPLES_REVISION=itk-v.015-alpha ``` @@ -58,6 +58,7 @@ The script will: To enable debug logging and persist logs for inspection: 1. 
Set the `ITK_LOG_LEVEL` environment variable to `DEBUG`: + ```bash export ITK_LOG_LEVEL=DEBUG ``` diff --git a/itk/main.py b/itk/main.py index 5ce062fac..6792c540a 100644 --- a/itk/main.py +++ b/itk/main.py @@ -32,7 +32,9 @@ Message, Part, SendMessageRequest, + Task, TaskState, + TaskStatus, ) from a2a.utils import TransportProtocol @@ -198,7 +200,16 @@ async def execute( context.context_id, ) - await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED) + # Explicitly create the task by sending it to the queue + task = Task( + id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + history=[context.message] if context.message else [], + ) + async with task_updater._lock: # noqa: SLF001 + await event_queue.enqueue_event(task) + await task_updater.update_status(TaskState.TASK_STATE_WORKING) instruction = extract_instruction(context.message) From 2846be68278004196a5bf658488a883a5c4d446c Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 17 Apr 2026 12:18:27 +0200 Subject: [PATCH 33/67] test: add extension propagation test in test_end_to_end.py (#981) --- tests/integration/test_end_to_end.py | 84 +++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 58dce528d..aea9784ad 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -5,15 +5,20 @@ import httpx import pytest import pytest_asyncio + from starlette.applications import Starlette from a2a.client.base_client import BaseClient -from a2a.client.client import ClientConfig +from a2a.client.client import ClientCallContext, ClientConfig from a2a.client.client_factory import ClientFactory +from a2a.client.service_parameters import ( + ServiceParametersFactory, + with_a2a_extensions, +) from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from 
a2a.server.events.in_memory_queue_manager import InMemoryQueueManager -from a2a.server.request_handlers import GrpcHandler, DefaultRequestHandler +from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import TaskUpdater @@ -21,6 +26,7 @@ from a2a.types import ( AgentCapabilities, AgentCard, + AgentExtension, AgentInterface, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, @@ -41,6 +47,12 @@ from a2a.utils.errors import InvalidParamsError +SUPPORTED_EXTENSION_URIS = [ + 'https://example.com/ext/v1', + 'https://example.com/ext/v2', +] + + def assert_message_matches(message, expected_role, expected_text): assert message.role == expected_role assert message.parts[0].text == expected_text @@ -87,6 +99,23 @@ class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): user_input = context.get_user_input() + # Extensions echo: activate all requested extensions and report them + # back via the Message.extensions field. + if user_input.startswith('Extensions:'): + for ext_uri in context.requested_extensions: + context.add_activated_extension(ext_uri) + await event_queue.enqueue_event( + Message( + role=Role.ROLE_AGENT, + message_id='ext-reply-1', + parts=[Part(text='extensions echoed')], + extensions=sorted( + context.call_context.activated_extensions + ), + ) + ) + return + # Direct message response (no task created). 
if user_input.startswith('Message:'): await event_queue.enqueue_event( @@ -142,7 +171,15 @@ def agent_card() -> AgentCard: description='Real in-memory integration testing.', version='1.0.0', capabilities=AgentCapabilities( - streaming=True, push_notifications=False + streaming=True, + push_notifications=False, + extensions=[ + AgentExtension( + uri=uri, + description=f'Test extension {uri}', + ) + for uri in SUPPORTED_EXTENSION_URIS + ], ), skills=[], default_input_modes=['text/plain'], @@ -757,3 +794,44 @@ async def test_end_to_end_direct_message_return_immediately(transport_setups): Role.ROLE_AGENT, 'Direct reply to: Message: Quick question', ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'streaming', + [ + pytest.param(False, id='blocking'), + pytest.param(True, id='streaming'), + ], +) +async def test_end_to_end_extensions_propagation(transport_setups, streaming): + """Test that extensions sent by the client reach the agent executor.""" + client = transport_setups.client + client._config.streaming = streaming + + service_params = ServiceParametersFactory.create( + [with_a2a_extensions(SUPPORTED_EXTENSION_URIS)] + ) + context = ClientCallContext(service_parameters=service_params) + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-ext-propagation', + parts=[Part(text='Extensions: echo')], + ) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest(message=message_to_send), + context=context, + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + assert_message_matches( + response.message, Role.ROLE_AGENT, 'extensions echoed' + ) + assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS) From 5f3ea292389cf72a25a7cf2792caceb4af45f6da Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Fri, 17 Apr 2026 12:27:04 +0200 Subject: [PATCH 34/67] refactor!: extract developer helpers in helpers folder (#978) MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Extracts developer-facing helper functions from a2a.utils into a dedicated a2a.helpers package. What changed - New a2a.helpers package with two modules: - proto_helpers.py — unified helpers for creating/inspecting Messages, Artifacts, Tasks, Events, and StreamResponses - agent_card.py — moved from utils/agent_card.py - Relocated internal functions to their actual consumers: - append_artifact_to_task -> server/tasks/task_manager.py - canonicalize_agent_card, _clean_empty -> utils/signing.py - Removed unused helpers Motivation This is the first in a series of PRs to simplify the a2a.utils structure. The goal is to stop mixing developer-facing convenience helpers with internal SDK machinery. --- samples/cli.py | 4 +- scripts/test_minimal_install.py | 5 +- src/a2a/client/__init__.py | 2 - src/a2a/client/helpers.py | 20 +- src/a2a/helpers/__init__.py | 34 +++ src/a2a/{utils => helpers}/agent_card.py | 0 src/a2a/helpers/proto_helpers.py | 214 ++++++++++++++++ src/a2a/server/agent_execution/context.py | 2 +- src/a2a/server/tasks/task_manager.py | 67 ++++- src/a2a/utils/__init__.py | 44 ---- src/a2a/utils/artifact.py | 92 ------- src/a2a/utils/helpers.py | 176 -------------- src/a2a/utils/message.py | 71 ------ src/a2a/utils/parts.py | 46 ---- src/a2a/utils/signing.py | 38 ++- src/a2a/utils/task.py | 76 +----- tests/client/test_client_helpers.py | 5 +- tests/client/transports/test_grpc_client.py | 2 +- tests/client/transports/test_rest_client.py | 24 +- tests/e2e/push_notifications/agent_app.py | 20 +- .../test_agent_card_display.py | 2 +- tests/helpers/test_proto_helpers.py | 230 ++++++++++++++++++ tests/integration/test_end_to_end.py | 5 +- .../test_default_request_handler.py | 11 +- .../test_default_request_handler_v2.py | 11 +- tests/utils/test_artifact.py | 161 ------------ tests/utils/test_helpers.py | 180 +------------- tests/utils/test_message.py | 209 ---------------- tests/utils/test_parts.py 
| 184 -------------- tests/utils/test_task.py | 186 +------------- 30 files changed, 638 insertions(+), 1483 deletions(-) create mode 100644 src/a2a/helpers/__init__.py rename src/a2a/{utils => helpers}/agent_card.py (100%) create mode 100644 src/a2a/helpers/proto_helpers.py delete mode 100644 src/a2a/utils/artifact.py delete mode 100644 src/a2a/utils/message.py delete mode 100644 src/a2a/utils/parts.py rename tests/{utils => helpers}/test_agent_card_display.py (99%) create mode 100644 tests/helpers/test_proto_helpers.py delete mode 100644 tests/utils/test_artifact.py delete mode 100644 tests/utils/test_message.py delete mode 100644 tests/utils/test_parts.py diff --git a/samples/cli.py b/samples/cli.py index 54b68388f..935834dd3 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -10,9 +10,9 @@ import httpx from a2a.client import A2ACardResolver, ClientConfig, create_client +from a2a.helpers import get_artifact_text, get_message_text +from a2a.helpers.agent_card import display_agent_card from a2a.types import Message, Part, Role, SendMessageRequest, TaskState -from a2a.utils import get_artifact_text, get_message_text -from a2a.utils.agent_card import display_agent_card async def _handle_stream( diff --git a/scripts/test_minimal_install.py b/scripts/test_minimal_install.py index 076df4c0f..0b29a48b6 100755 --- a/scripts/test_minimal_install.py +++ b/scripts/test_minimal_install.py @@ -50,14 +50,13 @@ 'a2a.server.tasks', 'a2a.types', 'a2a.utils', - 'a2a.utils.artifact', 'a2a.utils.constants', 'a2a.utils.error_handlers', 'a2a.utils.helpers', - 'a2a.utils.message', - 'a2a.utils.parts', 'a2a.utils.proto_utils', 'a2a.utils.task', + 'a2a.helpers.agent_card', + 'a2a.helpers.proto_helpers', ] diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index c23041f32..d33c09481 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -22,7 +22,6 @@ A2AClientTimeoutError, AgentCardResolutionError, ) -from a2a.client.helpers import 
create_text_message_object from a2a.client.interceptors import ClientCallInterceptor @@ -41,6 +40,5 @@ 'CredentialService', 'InMemoryContextCredentialStore', 'create_client', - 'create_text_message_object', 'minimal_agent_card', ] diff --git a/src/a2a/client/helpers.py b/src/a2a/client/helpers.py index fc7bfdbdf..f8207f03b 100644 --- a/src/a2a/client/helpers.py +++ b/src/a2a/client/helpers.py @@ -1,11 +1,10 @@ """Helper functions for the A2A client.""" from typing import Any -from uuid import uuid4 from google.protobuf.json_format import ParseDict -from a2a.types.a2a_pb2 import AgentCard, Message, Part, Role +from a2a.types.a2a_pb2 import AgentCard def parse_agent_card(agent_card_data: dict[str, Any]) -> AgentCard: @@ -111,20 +110,3 @@ def _handle_security_compatibility(agent_card_data: dict[str, Any]) -> None: new_scheme_wrapper = {mapped_name: scheme.copy()} scheme.clear() scheme.update(new_scheme_wrapper) - - -def create_text_message_object( - role: Role = Role.ROLE_USER, content: str = '' -) -> Message: - """Create a Message object containing a single text Part. - - Args: - role: The role of the message sender (user or agent). Defaults to Role.ROLE_USER. - content: The text content of the message. Defaults to an empty string. - - Returns: - A `Message` object with a new UUID message_id. 
- """ - return Message( - role=role, parts=[Part(text=content)], message_id=str(uuid4()) - ) diff --git a/src/a2a/helpers/__init__.py b/src/a2a/helpers/__init__.py new file mode 100644 index 000000000..c42429d43 --- /dev/null +++ b/src/a2a/helpers/__init__.py @@ -0,0 +1,34 @@ +"""Helper functions for the A2A Python SDK.""" + +from a2a.helpers.agent_card import display_agent_card +from a2a.helpers.proto_helpers import ( + get_artifact_text, + get_message_text, + get_stream_response_text, + get_text_parts, + new_artifact, + new_message, + new_task, + new_task_from_user_message, + new_text_artifact, + new_text_artifact_update_event, + new_text_message, + new_text_status_update_event, +) + + +__all__ = [ + 'display_agent_card', + 'get_artifact_text', + 'get_message_text', + 'get_stream_response_text', + 'get_text_parts', + 'new_artifact', + 'new_message', + 'new_task', + 'new_task_from_user_message', + 'new_text_artifact', + 'new_text_artifact_update_event', + 'new_text_message', + 'new_text_status_update_event', +] diff --git a/src/a2a/utils/agent_card.py b/src/a2a/helpers/agent_card.py similarity index 100% rename from src/a2a/utils/agent_card.py rename to src/a2a/helpers/agent_card.py diff --git a/src/a2a/helpers/proto_helpers.py b/src/a2a/helpers/proto_helpers.py new file mode 100644 index 000000000..79e1f739d --- /dev/null +++ b/src/a2a/helpers/proto_helpers.py @@ -0,0 +1,214 @@ +"""Unified helper functions for creating and handling A2A types.""" + +import uuid + +from collections.abc import Sequence + +from a2a.types.a2a_pb2 import ( + Artifact, + Message, + Part, + Role, + StreamResponse, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) + + +# --- Message Helpers --- + + +def new_message( + parts: list[Part], + role: Role = Role.ROLE_AGENT, + context_id: str | None = None, + task_id: str | None = None, +) -> Message: + """Creates a new message containing a list of Parts.""" + return Message( + role=role, + parts=parts, + 
message_id=str(uuid.uuid4()), + task_id=task_id, + context_id=context_id, + ) + + +def new_text_message( + text: str, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single text Part.""" + return new_message( + parts=[Part(text=text)], + role=role, + task_id=task_id, + context_id=context_id, + ) + + +def get_message_text(message: Message, delimiter: str = '\n') -> str: + """Extracts and joins all text content from a Message's parts.""" + return delimiter.join(get_text_parts(message.parts)) + + +# --- Artifact Helpers --- + + +def new_artifact( + parts: list[Part], + name: str, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object.""" + return Artifact( + artifact_id=artifact_id or str(uuid.uuid4()), + parts=parts, + name=name, + description=description, + ) + + +def new_text_artifact( + name: str, + text: str, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single text Part.""" + return new_artifact( + [Part(text=text)], + name, + description, + artifact_id=artifact_id, + ) + + +def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str: + """Extracts and joins all text content from an Artifact's parts.""" + return delimiter.join(get_text_parts(artifact.parts)) + + +# --- Task Helpers --- + + +def new_task_from_user_message(user_message: Message) -> Task: + """Creates a new Task object from an initial user message.""" + if user_message.role != Role.ROLE_USER: + raise ValueError('Message must be from a user') + if not user_message.parts: + raise ValueError('Message parts cannot be empty') + for part in user_message.parts: + if part.HasField('text') and not part.text: + raise ValueError('Message.text cannot be empty') + + return Task( + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + 
id=user_message.task_id or str(uuid.uuid4()), + context_id=user_message.context_id or str(uuid.uuid4()), + history=[user_message], + ) + + +def new_task( + task_id: str, + context_id: str, + state: TaskState, + artifacts: list[Artifact] | None = None, + history: list[Message] | None = None, +) -> Task: + """Creates a Task object with a specified status.""" + if history is None: + history = [] + if artifacts is None: + artifacts = [] + + return Task( + status=TaskStatus(state=state), + id=task_id, + context_id=context_id, + artifacts=artifacts, + history=history, + ) + + +# --- Part Helpers --- + + +def get_text_parts(parts: Sequence[Part]) -> list[str]: + """Extracts text content from all text Parts.""" + return [part.text for part in parts if part.HasField('text')] + + +# --- Event & Stream Helpers --- + + +def new_text_status_update_event( + task_id: str, + context_id: str, + state: TaskState, + text: str, +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent with a single text message.""" + return TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=state, + message=new_text_message( + text=text, + role=Role.ROLE_AGENT, + context_id=context_id, + task_id=task_id, + ), + ), + ) + + +def new_text_artifact_update_event( # noqa: PLR0913 + task_id: str, + context_id: str, + name: str, + text: str, + append: bool = False, + last_chunk: bool = False, + artifact_id: str | None = None, +) -> TaskArtifactUpdateEvent: + """Creates a TaskArtifactUpdateEvent with a single text artifact.""" + return TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + artifact=new_text_artifact( + name=name, text=text, artifact_id=artifact_id + ), + append=append, + last_chunk=last_chunk, + ) + + +def get_stream_response_text( + response: StreamResponse, delimiter: str = '\n' +) -> str: + """Extracts text content from a StreamResponse.""" + if response.HasField('message'): + return get_message_text(response.message, 
delimiter) + if response.HasField('task'): + texts = [ + get_artifact_text(a, delimiter) for a in response.task.artifacts + ] + return delimiter.join(t for t in texts if t) + if response.HasField('status_update'): + if response.status_update.status.HasField('message'): + return get_message_text( + response.status_update.status.message, delimiter + ) + return '' + if response.HasField('artifact_update'): + return get_artifact_text(response.artifact_update.artifact, delimiter) + return '' diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 1feefb1df..8b78c1045 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -1,5 +1,6 @@ from typing import Any +from a2a.helpers.proto_helpers import get_message_text from a2a.server.context import ServerCallContext from a2a.server.id_generator import ( IDGenerator, @@ -12,7 +13,6 @@ SendMessageRequest, Task, ) -from a2a.utils import get_message_text from a2a.utils.errors import InvalidParamsError diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index 143413d5b..e5d899c1e 100644 --- a/src/a2a/server/tasks/task_manager.py +++ b/src/a2a/server/tasks/task_manager.py @@ -4,6 +4,7 @@ from a2a.server.events.event_queue import Event from a2a.server.tasks.task_store import TaskStore from a2a.types.a2a_pb2 import ( + Artifact, Message, Task, TaskArtifactUpdateEvent, @@ -11,13 +12,77 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import append_artifact_to_task from a2a.utils.errors import InvalidParamsError +from a2a.utils.telemetry import trace_function logger = logging.getLogger(__name__) +@trace_function() +def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None: + """Helper method for updating a Task object with new artifact data from an event. 
+ + Handles creating the artifacts list if it doesn't exist, adding new artifacts, + and appending parts to existing artifacts based on the `append` flag in the event. + + Args: + task: The `Task` object to modify. + event: The `TaskArtifactUpdateEvent` containing the artifact data. + """ + new_artifact_data: Artifact = event.artifact + artifact_id: str = new_artifact_data.artifact_id + append_parts: bool = event.append + + existing_artifact: Artifact | None = None + existing_artifact_list_index: int | None = None + + # Find existing artifact by its id + for i, art in enumerate(task.artifacts): + if art.artifact_id == artifact_id: + existing_artifact = art + existing_artifact_list_index = i + break + + if not append_parts: + # This represents the first chunk for this artifact index. + if existing_artifact_list_index is not None: + # Replace the existing artifact entirely with the new data + logger.debug( + 'Replacing artifact at id %s for task %s', artifact_id, task.id + ) + task.artifacts[existing_artifact_list_index].CopyFrom( + new_artifact_data + ) + else: + # Append the new artifact since no artifact with this index exists yet + logger.debug( + 'Adding new artifact with id %s for task %s', + artifact_id, + task.id, + ) + task.artifacts.append(new_artifact_data) + elif existing_artifact: + # Append new parts to the existing artifact's part list + logger.debug( + 'Appending parts to artifact id %s for task %s', + artifact_id, + task.id, + ) + existing_artifact.parts.extend(new_artifact_data.parts) + existing_artifact.metadata.update( + dict(new_artifact_data.metadata.items()) + ) + else: + # We received a chunk to append, but we don't have an existing artifact. + # we will ignore this chunk + logger.warning( + 'Received append=True for nonexistent artifact index %s in task %s. Ignoring chunk.', + artifact_id, + task.id, + ) + + class TaskManager: """Helps manage a task's lifecycle during execution of a request. 
diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index 1efed5794..04693dd0b 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -1,62 +1,18 @@ """Utility functions for the A2A Python SDK.""" from a2a.utils import proto_utils -from a2a.utils.agent_card import display_agent_card -from a2a.utils.artifact import ( - get_artifact_text, - new_artifact, - new_data_artifact, - new_text_artifact, -) from a2a.utils.constants import ( AGENT_CARD_WELL_KNOWN_PATH, DEFAULT_RPC_URL, TransportProtocol, ) -from a2a.utils.helpers import ( - append_artifact_to_task, - are_modalities_compatible, - build_text_artifact, - create_task_obj, -) -from a2a.utils.message import ( - get_message_text, - new_agent_parts_message, - new_agent_text_message, -) -from a2a.utils.parts import ( - get_data_parts, - get_file_parts, - get_text_parts, -) from a2a.utils.proto_utils import to_stream_response -from a2a.utils.task import ( - completed_task, - new_task, -) __all__ = [ 'AGENT_CARD_WELL_KNOWN_PATH', 'DEFAULT_RPC_URL', 'TransportProtocol', - 'append_artifact_to_task', - 'are_modalities_compatible', - 'build_text_artifact', - 'completed_task', - 'create_task_obj', - 'display_agent_card', - 'get_artifact_text', - 'get_data_parts', - 'get_file_parts', - 'get_message_text', - 'get_text_parts', - 'new_agent_parts_message', - 'new_agent_text_message', - 'new_artifact', - 'new_data_artifact', - 'new_task', - 'new_text_artifact', 'proto_utils', 'to_stream_response', ] diff --git a/src/a2a/utils/artifact.py b/src/a2a/utils/artifact.py deleted file mode 100644 index ac14087dc..000000000 --- a/src/a2a/utils/artifact.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Utility functions for creating A2A Artifact objects.""" - -import uuid - -from typing import Any - -from google.protobuf.struct_pb2 import Struct, Value - -from a2a.types.a2a_pb2 import Artifact, Part -from a2a.utils.parts import get_text_parts - - -def new_artifact( - parts: list[Part], - name: str, - description: str 
| None = None, -) -> Artifact: - """Creates a new Artifact object. - - Args: - parts: The list of `Part` objects forming the artifact's content. - name: The human-readable name of the artifact. - description: An optional description of the artifact. - - Returns: - A new `Artifact` object with a generated artifact_id. - """ - return Artifact( - artifact_id=str(uuid.uuid4()), - parts=parts, - name=name, - description=description, - ) - - -def new_text_artifact( - name: str, - text: str, - description: str | None = None, -) -> Artifact: - """Creates a new Artifact object containing only a single text Part. - - Args: - name: The human-readable name of the artifact. - text: The text content of the artifact. - description: An optional description of the artifact. - - Returns: - A new `Artifact` object with a generated artifact_id. - """ - return new_artifact( - [Part(text=text)], - name, - description, - ) - - -def new_data_artifact( - name: str, - data: dict[str, Any], - description: str | None = None, -) -> Artifact: - """Creates a new Artifact object containing only a single data Part. - - Args: - name: The human-readable name of the artifact. - data: The structured data content of the artifact. - description: An optional description of the artifact. - - Returns: - A new `Artifact` object with a generated artifact_id. - """ - struct_data = Struct() - struct_data.update(data) - return new_artifact( - [Part(data=Value(struct_value=struct_data))], - name, - description, - ) - - -def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str: - """Extracts and joins all text content from an Artifact's parts. - - Args: - artifact: The `Artifact` object. - delimiter: The string to use when joining text from multiple TextParts. - - Returns: - A single string containing all text content, or an empty string if no text parts are found. 
- """ - return delimiter.join(get_text_parts(artifact.parts)) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index fe69bf26d..9a974a4c2 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -2,30 +2,16 @@ import functools import inspect -import json import logging from collections.abc import AsyncIterator, Awaitable, Callable from typing import Any, TypeVar, cast -from uuid import uuid4 -from google.protobuf.json_format import MessageToDict from packaging.version import InvalidVersion, Version from a2a.server.context import ServerCallContext -from a2a.types.a2a_pb2 import ( - AgentCard, - Artifact, - Part, - SendMessageRequest, - Task, - TaskArtifactUpdateEvent, - TaskState, - TaskStatus, -) from a2a.utils import constants from a2a.utils.errors import VersionNotSupportedError -from a2a.utils.telemetry import trace_function T = TypeVar('T') @@ -35,168 +21,6 @@ logger = logging.getLogger(__name__) -@trace_function() -def create_task_obj(message_send_params: SendMessageRequest) -> Task: - """Create a new task object from message send params. - - Generates UUIDs for task and context IDs if they are not already present in the message. - - Args: - message_send_params: The `SendMessageRequest` object containing the initial message. - - Returns: - A new `Task` object initialized with 'submitted' status and the input message in history. - """ - if not message_send_params.message.context_id: - message_send_params.message.context_id = str(uuid4()) - - task = Task( - id=str(uuid4()), - context_id=message_send_params.message.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - task.history.append(message_send_params.message) - return task - - -@trace_function() -def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None: - """Helper method for updating a Task object with new artifact data from an event. 
- - Handles creating the artifacts list if it doesn't exist, adding new artifacts, - and appending parts to existing artifacts based on the `append` flag in the event. - - Args: - task: The `Task` object to modify. - event: The `TaskArtifactUpdateEvent` containing the artifact data. - """ - new_artifact_data: Artifact = event.artifact - artifact_id: str = new_artifact_data.artifact_id - append_parts: bool = event.append - - existing_artifact: Artifact | None = None - existing_artifact_list_index: int | None = None - - # Find existing artifact by its id - for i, art in enumerate(task.artifacts): - if art.artifact_id == artifact_id: - existing_artifact = art - existing_artifact_list_index = i - break - - if not append_parts: - # This represents the first chunk for this artifact index. - if existing_artifact_list_index is not None: - # Replace the existing artifact entirely with the new data - logger.debug( - 'Replacing artifact at id %s for task %s', artifact_id, task.id - ) - task.artifacts[existing_artifact_list_index].CopyFrom( - new_artifact_data - ) - else: - # Append the new artifact since no artifact with this index exists yet - logger.debug( - 'Adding new artifact with id %s for task %s', - artifact_id, - task.id, - ) - task.artifacts.append(new_artifact_data) - elif existing_artifact: - # Append new parts to the existing artifact's part list - logger.debug( - 'Appending parts to artifact id %s for task %s', - artifact_id, - task.id, - ) - existing_artifact.parts.extend(new_artifact_data.parts) - existing_artifact.metadata.update( - dict(new_artifact_data.metadata.items()) - ) - else: - # We received a chunk to append, but we don't have an existing artifact. - # we will ignore this chunk - logger.warning( - 'Received append=True for nonexistent artifact index %s in task %s. Ignoring chunk.', - artifact_id, - task.id, - ) - - -def build_text_artifact(text: str, artifact_id: str) -> Artifact: - """Helper to create a text artifact. 
- - Args: - text: The text content for the artifact. - artifact_id: The ID for the artifact. - - Returns: - An `Artifact` object containing a single text Part. - """ - part = Part(text=text) - return Artifact(parts=[part], artifact_id=artifact_id) - - -def are_modalities_compatible( - server_output_modes: list[str] | None, client_output_modes: list[str] | None -) -> bool: - """Checks if server and client output modalities (MIME types) are compatible. - - Modalities are compatible if: - 1. The client specifies no preferred output modes (client_output_modes is None or empty). - 2. The server specifies no supported output modes (server_output_modes is None or empty). - 3. There is at least one common modality between the server's supported list and the client's preferred list. - - Args: - server_output_modes: A list of MIME types supported by the server/agent for output. - Can be None or empty if the server doesn't specify. - client_output_modes: A list of MIME types preferred by the client for output. - Can be None or empty if the client accepts any. - - Returns: - True if the modalities are compatible, False otherwise. 
- """ - if client_output_modes is None or len(client_output_modes) == 0: - return True - - if server_output_modes is None or len(server_output_modes) == 0: - return True - - return any(x in server_output_modes for x in client_output_modes) - - -def _clean_empty(d: Any) -> Any: - """Recursively remove empty strings, lists and dicts from a dictionary.""" - if isinstance(d, dict): - cleaned_dict = { - k: cleaned_v - for k, v in d.items() - if (cleaned_v := _clean_empty(v)) is not None - } - return cleaned_dict or None - if isinstance(d, list): - cleaned_list = [ - cleaned_v for v in d if (cleaned_v := _clean_empty(v)) is not None - ] - return cleaned_list or None - if isinstance(d, str) and not d: - return None - return d - - -def canonicalize_agent_card(agent_card: AgentCard) -> str: - """Canonicalizes the Agent Card JSON according to RFC 8785 (JCS).""" - card_dict = MessageToDict( - agent_card, - ) - # Remove signatures field if present - card_dict.pop('signatures', None) - - # Recursively remove empty values - cleaned_dict = _clean_empty(card_dict) - return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) - - async def maybe_await(value: T | Awaitable[T]) -> T: """Awaits a value if it's awaitable, otherwise simply provides it back.""" if inspect.isawaitable(value): diff --git a/src/a2a/utils/message.py b/src/a2a/utils/message.py deleted file mode 100644 index 528d952f4..000000000 --- a/src/a2a/utils/message.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Utility functions for creating and handling A2A Message objects.""" - -import uuid - -from a2a.types.a2a_pb2 import ( - Message, - Part, - Role, -) -from a2a.utils.parts import get_text_parts - - -def new_agent_text_message( - text: str, - context_id: str | None = None, - task_id: str | None = None, -) -> Message: - """Creates a new agent message containing a single text Part. - - Args: - text: The text content of the message. - context_id: The context ID for the message. 
- task_id: The task ID for the message. - - Returns: - A new `Message` object with role 'agent'. - """ - return Message( - role=Role.ROLE_AGENT, - parts=[Part(text=text)], - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) - - -def new_agent_parts_message( - parts: list[Part], - context_id: str | None = None, - task_id: str | None = None, -) -> Message: - """Creates a new agent message containing a list of Parts. - - Args: - parts: The list of `Part` objects for the message content. - context_id: The context ID for the message. - task_id: The task ID for the message. - - Returns: - A new `Message` object with role 'agent'. - """ - return Message( - role=Role.ROLE_AGENT, - parts=parts, - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) - - -def get_message_text(message: Message, delimiter: str = '\n') -> str: - """Extracts and joins all text content from a Message's parts. - - Args: - message: The `Message` object. - delimiter: The string to use when joining text from multiple text Parts. - - Returns: - A single string containing all text content, or an empty string if no text parts are found. - """ - return delimiter.join(get_text_parts(message.parts)) diff --git a/src/a2a/utils/parts.py b/src/a2a/utils/parts.py deleted file mode 100644 index c9b964540..000000000 --- a/src/a2a/utils/parts.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Utility functions for creating and handling A2A Parts objects.""" - -from collections.abc import Sequence -from typing import Any - -from google.protobuf.json_format import MessageToDict - -from a2a.types.a2a_pb2 import ( - Part, -) - - -def get_text_parts(parts: Sequence[Part]) -> list[str]: - """Extracts text content from all text Parts. - - Args: - parts: A sequence of `Part` objects. - - Returns: - A list of strings containing the text content from any text Parts found. 
- """ - return [part.text for part in parts if part.HasField('text')] - - -def get_data_parts(parts: Sequence[Part]) -> list[Any]: - """Extracts data from all data Parts in a list of Parts. - - Args: - parts: A sequence of `Part` objects. - - Returns: - A list of values containing the data from any data Parts found. - """ - return [MessageToDict(part.data) for part in parts if part.HasField('data')] - - -def get_file_parts(parts: Sequence[Part]) -> list[Part]: - """Extracts file parts from a list of Parts. - - Args: - parts: A sequence of `Part` objects. - - Returns: - A list of `Part` objects containing file data (raw or url). - """ - return [part for part in parts if part.raw or part.url] diff --git a/src/a2a/utils/signing.py b/src/a2a/utils/signing.py index 68924c8a0..aa720d159 100644 --- a/src/a2a/utils/signing.py +++ b/src/a2a/utils/signing.py @@ -3,7 +3,7 @@ from collections.abc import Callable from typing import Any, TypedDict -from a2a.utils.helpers import canonicalize_agent_card +from google.protobuf.json_format import MessageToDict try: @@ -68,7 +68,7 @@ def create_agent_card_signer( def agent_card_signer(agent_card: AgentCard) -> AgentCard: """Signs agent card.""" - canonical_payload = canonicalize_agent_card(agent_card) + canonical_payload = _canonicalize_agent_card(agent_card) payload_dict = json.loads(canonical_payload) jws_string = jwt.encode( @@ -128,7 +128,7 @@ def signature_verifier( jku = protected_header.get('jku') verification_key = key_provider(kid, jku) - canonical_payload = canonicalize_agent_card(agent_card) + canonical_payload = _canonicalize_agent_card(agent_card) encoded_payload = base64url_encode( canonical_payload.encode('utf-8') ).decode('utf-8') @@ -148,3 +148,35 @@ def signature_verifier( raise InvalidSignaturesError('No valid signature found') return signature_verifier + + +def _clean_empty(d: Any) -> Any: + """Recursively remove empty strings, lists and dicts from a dictionary.""" + if isinstance(d, dict): + cleaned_dict = { + k: 
cleaned_v + for k, v in d.items() + if (cleaned_v := _clean_empty(v)) is not None + } + return cleaned_dict or None + if isinstance(d, list): + cleaned_list = [ + cleaned_v for v in d if (cleaned_v := _clean_empty(v)) is not None + ] + return cleaned_list or None + if isinstance(d, str) and not d: + return None + return d + + +def _canonicalize_agent_card(agent_card: AgentCard) -> str: + """Canonicalizes the Agent Card JSON according to RFC 8785 (JCS).""" + card_dict = MessageToDict( + agent_card, + ) + # Remove signatures field if present + card_dict.pop('signatures', None) + + # Recursively remove empty values + cleaned_dict = _clean_empty(card_dict) + return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index 6ff716a30..4acf54e46 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -1,89 +1,15 @@ """Utility functions for creating A2A Task objects.""" import binascii -import uuid from base64 import b64decode, b64encode from typing import Literal, Protocol, runtime_checkable -from a2a.types.a2a_pb2 import ( - Artifact, - Message, - Task, - TaskState, - TaskStatus, -) +from a2a.types.a2a_pb2 import Task from a2a.utils.constants import MAX_LIST_TASKS_PAGE_SIZE from a2a.utils.errors import InvalidParamsError -def new_task(request: Message) -> Task: - """Creates a new Task object from an initial user message. - - Generates task and context IDs if not provided in the message. - - Args: - request: The initial `Message` object from the user. - - Returns: - A new `Task` object initialized with 'submitted' status and the input message in history. - - Raises: - TypeError: If the message role is None. - ValueError: If the message parts are empty, if any part has empty content, or if the provided context_id is invalid. 
- """ - if not request.role: - raise TypeError('Message role cannot be None') - if not request.parts: - raise ValueError('Message parts cannot be empty') - for part in request.parts: - if part.HasField('text') and not part.text: - raise ValueError('Message.text cannot be empty') - - return Task( - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - id=request.task_id or str(uuid.uuid4()), - context_id=request.context_id or str(uuid.uuid4()), - history=[request], - ) - - -def completed_task( - task_id: str, - context_id: str, - artifacts: list[Artifact], - history: list[Message] | None = None, -) -> Task: - """Creates a Task object in the 'completed' state. - - Useful for constructing a final Task representation when the agent - finishes and produces artifacts. - - Args: - task_id: The ID of the task. - context_id: The context ID of the task. - artifacts: A list of `Artifact` objects produced by the task. - history: An optional list of `Message` objects representing the task history. - - Returns: - A `Task` object with status set to 'completed'. 
- """ - if not artifacts or not all(isinstance(a, Artifact) for a in artifacts): - raise ValueError( - 'artifacts must be a non-empty list of Artifact objects' - ) - - if history is None: - history = [] - return Task( - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - id=task_id, - context_id=context_id, - artifacts=artifacts, - history=history, - ) - - @runtime_checkable class HistoryLengthConfig(Protocol): """Protocol for configuration arguments containing history_length field.""" diff --git a/tests/client/test_client_helpers.py b/tests/client/test_client_helpers.py index 8963eefce..0eb394f43 100644 --- a/tests/client/test_client_helpers.py +++ b/tests/client/test_client_helpers.py @@ -3,7 +3,8 @@ import json from google.protobuf.json_format import MessageToDict -from a2a.client.helpers import create_text_message_object, parse_agent_card +from a2a.client.helpers import parse_agent_card +from a2a.helpers.proto_helpers import new_text_message from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.types.a2a_pb2 import ( APIKeySecurityScheme, @@ -263,7 +264,7 @@ def test_parse_agent_card_security_scheme_unknown_type() -> None: def test_create_text_message_object() -> None: - msg = create_text_message_object(role=Role.ROLE_AGENT, content='Hello') + msg = new_text_message(text='Hello', role=Role.ROLE_AGENT) assert msg.role == Role.ROLE_AGENT assert len(msg.parts) == 1 assert msg.parts[0].text == 'Hello' diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index 9e81bd71e..95cca9189 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -35,7 +35,7 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import get_text_parts +from a2a.helpers.proto_helpers import get_text_parts @pytest.fixture diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index e7912566e..0c9f7c30a 100644 
--- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -8,7 +8,7 @@ from google.protobuf.timestamp_pb2 import Timestamp from httpx_sse import EventSource, ServerSentEvent -from a2a.client import create_text_message_object +from a2a.helpers.proto_helpers import new_text_message from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError from a2a.client.transports.rest import RestTransport @@ -83,7 +83,7 @@ async def test_send_message_streaming_timeout( url='http://agent.example.com/api', ) params = SendMessageRequest( - message=create_text_message_object(content='Hello stream') + message=new_text_message(text='Hello stream') ) mock_event_source = AsyncMock(spec=EventSource) mock_event_source.response = MagicMock(spec=httpx.Response) @@ -120,9 +120,7 @@ async def test_rest_mapped_errors( agent_card=mock_agent_card, url='http://agent.example.com/api', ) - params = SendMessageRequest( - message=create_text_message_object(content='Hello') - ) + params = SendMessageRequest(message=new_text_message(text='Hello')) mock_build_request = MagicMock( return_value=AsyncMock(spec=httpx.Request) @@ -172,9 +170,7 @@ async def test_send_message_with_timeout_context( agent_card=mock_agent_card, url='http://agent.example.com/api', ) - params = SendMessageRequest( - message=create_text_message_object(content='Hello') - ) + params = SendMessageRequest(message=new_text_message(text='Hello')) context = ClientCallContext(timeout=10.0) mock_build_request = MagicMock( @@ -246,9 +242,7 @@ async def test_send_message_with_default_extensions( agent_card=mock_agent_card, url='http://agent.example.com/api', ) - params = SendMessageRequest( - message=create_text_message_object(content='Hello') - ) + params = SendMessageRequest(message=new_text_message(text='Hello')) # Mock the build_request method to capture its inputs mock_build_request = MagicMock( @@ -294,7 +288,7 @@ async def 
test_send_message_streaming_with_new_extensions( url='http://agent.example.com/api', ) params = SendMessageRequest( - message=create_text_message_object(content='Hello stream') + message=new_text_message(text='Hello stream') ) mock_event_source = AsyncMock(spec=EventSource) @@ -343,7 +337,7 @@ async def test_send_message_streaming_server_error_propagates( url='http://agent.example.com/api', ) request = SendMessageRequest( - message=create_text_message_object(content='Error stream') + message=new_text_message(text='Error stream') ) mock_event_source = AsyncMock(spec=EventSource) @@ -524,7 +518,7 @@ class TestRestTransportTenant: 'send_message', SendMessageRequest( tenant='my-tenant', - message=create_text_message_object(content='hi'), + message=new_text_message(text='hi'), ), '/my-tenant/message:send', ), @@ -686,7 +680,7 @@ async def test_rest_get_task_prepend_empty_tenant( 'send_message_streaming', SendMessageRequest( tenant='my-tenant', - message=create_text_message_object(content='hi'), + message=new_text_message(text='hi'), ), '/my-tenant/message:stream', ), diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index 106a97cea..bc95f6c37 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -24,9 +24,9 @@ Message, Task, ) -from a2a.utils import ( - new_agent_text_message, - new_task, +from a2a.helpers.proto_helpers import ( + new_text_message, + new_task_from_user_message, ) @@ -74,7 +74,7 @@ async def invoke( or not msg.parts[0].HasField('text') ): await updater.failed( - new_agent_text_message( + new_text_message( 'Unsupported message.', task.context_id, task.id ) ) @@ -84,25 +84,23 @@ async def invoke( # Simple request-response flow. 
if text_message == 'Hello Agent!': await updater.complete( - new_agent_text_message('Hello User!', task.context_id, task.id) + new_text_message('Hello User!', task.context_id, task.id) ) # Flow with user input required: "How are you?" -> "Good! How are you?" -> "Good" -> "Amazing". elif text_message == 'How are you?': await updater.requires_input( - new_agent_text_message( - 'Good! How are you?', task.context_id, task.id - ) + new_text_message('Good! How are you?', task.context_id, task.id) ) elif text_message == 'Good': await updater.complete( - new_agent_text_message('Amazing', task.context_id, task.id) + new_text_message('Amazing', task.context_id, task.id) ) # Fail for unsupported messages. else: await updater.failed( - new_agent_text_message( + new_text_message( 'Unsupported message.', task.context_id, task.id ) ) @@ -124,7 +122,7 @@ async def execute( task = context.current_task if not task: - task = new_task(context.message) + task = new_task_from_user_message(context.message) await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) diff --git a/tests/utils/test_agent_card_display.py b/tests/helpers/test_agent_card_display.py similarity index 99% rename from tests/utils/test_agent_card_display.py rename to tests/helpers/test_agent_card_display.py index 93dc1aad4..e252a52fe 100644 --- a/tests/utils/test_agent_card_display.py +++ b/tests/helpers/test_agent_card_display.py @@ -9,7 +9,7 @@ AgentProvider, AgentSkill, ) -from a2a.utils.agent_card import display_agent_card +from a2a.helpers.agent_card import display_agent_card @pytest.fixture diff --git a/tests/helpers/test_proto_helpers.py b/tests/helpers/test_proto_helpers.py new file mode 100644 index 000000000..a4f6498ab --- /dev/null +++ b/tests/helpers/test_proto_helpers.py @@ -0,0 +1,230 @@ +"""Tests for proto helpers.""" + +import pytest +from a2a.helpers.proto_helpers import ( + new_message, + new_text_message, + get_message_text, + new_artifact, + 
new_text_artifact, + get_artifact_text, + new_task_from_user_message, + new_task, + get_text_parts, + new_text_status_update_event, + new_text_artifact_update_event, + get_stream_response_text, +) +from a2a.types.a2a_pb2 import ( + Part, + Role, + Message, + Artifact, + Task, + TaskState, + StreamResponse, +) + +# --- Message Helpers Tests --- + + +def test_new_message() -> None: + parts = [Part(text='hello')] + msg = new_message( + parts=parts, role=Role.ROLE_USER, context_id='ctx1', task_id='task1' + ) + assert msg.role == Role.ROLE_USER + assert msg.parts == parts + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_text_message() -> None: + msg = new_text_message( + text='hello', context_id='ctx1', task_id='task1', role=Role.ROLE_USER + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].text == 'hello' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_get_message_text() -> None: + msg = Message(parts=[Part(text='hello'), Part(text='world')]) + assert get_message_text(msg) == 'hello\nworld' + assert get_message_text(msg, delimiter=' ') == 'hello world' + + +# --- Artifact Helpers Tests --- + + +def test_new_artifact() -> None: + parts = [Part(text='content')] + art = new_artifact(parts=parts, name='test', description='desc') + assert art.name == 'test' + assert art.description == 'desc' + assert art.parts == parts + assert art.artifact_id != '' + + +def test_new_text_artifact() -> None: + art = new_text_artifact(name='test', text='content', description='desc') + assert art.name == 'test' + assert art.description == 'desc' + assert len(art.parts) == 1 + assert art.parts[0].text == 'content' + assert art.artifact_id != '' + + +def test_new_text_artifact_with_id() -> None: + art = new_text_artifact( + name='test', text='content', description='desc', artifact_id='art1' + ) + assert art.name == 'test' + assert 
art.description == 'desc' + assert len(art.parts) == 1 + assert art.parts[0].text == 'content' + assert art.artifact_id == 'art1' + + +def test_get_artifact_text() -> None: + art = Artifact(parts=[Part(text='hello'), Part(text='world')]) + assert get_artifact_text(art) == 'hello\nworld' + assert get_artifact_text(art, delimiter=' ') == 'hello world' + + +# --- Task Helpers Tests --- + + +def test_new_task_from_user_message() -> None: + msg = Message( + role=Role.ROLE_USER, + parts=[Part(text='hello')], + task_id='task1', + context_id='ctx1', + ) + task = new_task_from_user_message(msg) + assert task.id == 'task1' + assert task.context_id == 'ctx1' + assert task.status.state == TaskState.TASK_STATE_SUBMITTED + assert len(task.history) == 1 + assert task.history[0] == msg + + +def test_new_task_from_user_message_empty_parts() -> None: + msg = Message(role=Role.ROLE_USER, parts=[]) + with pytest.raises(ValueError, match='Message parts cannot be empty'): + new_task_from_user_message(msg) + + +def test_new_task_from_user_message_empty_text() -> None: + msg = Message(role=Role.ROLE_USER, parts=[Part(text='')]) + with pytest.raises(ValueError, match='Message.text cannot be empty'): + new_task_from_user_message(msg) + + +def test_new_task() -> None: + task = new_task( + task_id='task1', context_id='ctx1', state=TaskState.TASK_STATE_WORKING + ) + assert task.id == 'task1' + assert task.context_id == 'ctx1' + assert task.status.state == TaskState.TASK_STATE_WORKING + assert len(task.history) == 0 + assert len(task.artifacts) == 0 + + +# --- Part Helpers Tests --- + + +def test_get_text_parts() -> None: + parts = [ + Part(text='hello'), + Part(url='http://example.com'), + Part(text='world'), + ] + assert get_text_parts(parts) == ['hello', 'world'] + + +# --- Event & Stream Helpers Tests --- + + +def test_new_text_status_update_event() -> None: + event = new_text_status_update_event( + task_id='task1', + context_id='ctx1', + state=TaskState.TASK_STATE_WORKING, + 
text='progress', + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.status.state == TaskState.TASK_STATE_WORKING + assert event.status.message.parts[0].text == 'progress' + + +def test_new_text_artifact_update_event() -> None: + event = new_text_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='test', + text='content', + append=True, + last_chunk=True, + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.artifact.name == 'test' + assert event.artifact.parts[0].text == 'content' + assert event.append is True + assert event.last_chunk is True + + +def test_new_text_artifact_update_event_with_id() -> None: + event = new_text_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='test', + text='content', + artifact_id='art1', + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.artifact.name == 'test' + assert event.artifact.parts[0].text == 'content' + assert event.artifact.artifact_id == 'art1' + + +def test_get_stream_response_text_message() -> None: + resp = StreamResponse(message=Message(parts=[Part(text='hello')])) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_task() -> None: + resp = StreamResponse( + task=Task(artifacts=[Artifact(parts=[Part(text='hello')])]) + ) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_status_update() -> None: + resp = StreamResponse( + status_update=new_text_status_update_event( + 't', 'c', TaskState.TASK_STATE_WORKING, 'hello' + ) + ) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_artifact_update() -> None: + resp = StreamResponse( + artifact_update=new_text_artifact_update_event('t', 'c', 'n', 'hello') + ) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_empty() -> None: + resp = StreamResponse() + assert 
get_stream_response_text(resp) == '' diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index aea9784ad..b6cddbe4d 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -43,7 +43,8 @@ TaskState, a2a_pb2_grpc, ) -from a2a.utils import TransportProtocol, new_task +from a2a.utils import TransportProtocol +from a2a.helpers.proto_helpers import new_task_from_user_message from a2a.utils.errors import InvalidParamsError @@ -130,7 +131,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): # Task-based response. task = context.current_task if not task: - task = new_task(context.message) + task = new_task_from_user_message(context.message) await event_queue.enqueue_event(task) task_updater = TaskUpdater( diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 294f5aefe..5a2bf0446 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -73,7 +73,10 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import new_agent_text_message, new_task +from a2a.helpers.proto_helpers import ( + new_text_message, + new_task_from_user_message, +) class MockAgentExecutor(AgentExecutor): @@ -254,8 +257,8 @@ async def test_on_list_tasks_applies_history_length(agent_card): """Test on_list_tasks applies history length filter.""" mock_task_store = AsyncMock(spec=TaskStore) history = [ - new_agent_text_message('Hello 1!'), - new_agent_text_message('Hello 2!'), + new_text_message('Hello 1!'), + new_text_message('Hello 2!'), ] task2 = create_sample_task(task_id='task2') task2.history.extend(history) @@ -957,7 +960,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): assert context.message is not None, ( 'A message is required to create a new task' ) - task = new_task(context.message) # 
type: ignore + task = new_task_from_user_message(context.message) # type: ignore await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index d48b82461..3e1568b2e 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -54,7 +54,10 @@ TaskState, TaskStatus, ) -from a2a.utils import new_agent_text_message, new_task +from a2a.helpers.proto_helpers import ( + new_text_message, + new_task_from_user_message, +) def create_default_agent_card(): @@ -211,8 +214,8 @@ async def test_on_list_tasks_applies_history_length(): """Test on_list_tasks applies history length filter.""" mock_task_store = AsyncMock(spec=TaskStore) history = [ - new_agent_text_message('Hello 1!'), - new_agent_text_message('Hello 2!'), + new_text_message('Hello 1!'), + new_text_message('Hello 2!'), ] task2 = create_sample_task(task_id='task2') task2.history.extend(history) @@ -274,7 +277,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): assert context.message is not None, ( 'A message is required to create a new task' ) - task = new_task(context.message) + task = new_task_from_user_message(context.message) await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) try: diff --git a/tests/utils/test_artifact.py b/tests/utils/test_artifact.py deleted file mode 100644 index cbe8e9c91..000000000 --- a/tests/utils/test_artifact.py +++ /dev/null @@ -1,161 +0,0 @@ -import unittest -import uuid - -from unittest.mock import patch - -from google.protobuf.struct_pb2 import Struct - -from a2a.types.a2a_pb2 import ( - Artifact, - Part, -) -from a2a.utils.artifact import ( - get_artifact_text, - new_artifact, - new_data_artifact, - new_text_artifact, -) - - -class 
TestArtifact(unittest.TestCase): - @patch('uuid.uuid4') - def test_new_artifact_generates_id(self, mock_uuid4): - mock_uuid = uuid.UUID('abcdef12-1234-5678-1234-567812345678') - mock_uuid4.return_value = mock_uuid - artifact = new_artifact(parts=[], name='test_artifact') - self.assertEqual(artifact.artifact_id, str(mock_uuid)) - - def test_new_artifact_assigns_parts_name_description(self): - parts = [Part(text='Sample text')] - name = 'My Artifact' - description = 'This is a test artifact.' - artifact = new_artifact(parts=parts, name=name, description=description) - assert len(artifact.parts) == len(parts) - self.assertEqual(artifact.name, name) - self.assertEqual(artifact.description, description) - - def test_new_artifact_empty_description_if_not_provided(self): - parts = [Part(text='Another sample')] - name = 'Artifact_No_Desc' - artifact = new_artifact(parts=parts, name=name) - self.assertEqual(artifact.description, '') - - def test_new_text_artifact_creates_single_text_part(self): - text = 'This is a text artifact.' - name = 'Text_Artifact' - artifact = new_text_artifact(text=text, name=name) - self.assertEqual(len(artifact.parts), 1) - self.assertTrue(artifact.parts[0].HasField('text')) - - def test_new_text_artifact_part_contains_provided_text(self): - text = 'Hello, world!' - name = 'Greeting_Artifact' - artifact = new_text_artifact(text=text, name=name) - self.assertEqual(artifact.parts[0].text, text) - - def test_new_text_artifact_assigns_name_description(self): - text = 'Some content.' - name = 'Named_Text_Artifact' - description = 'Description for text artifact.' 
- artifact = new_text_artifact( - text=text, name=name, description=description - ) - self.assertEqual(artifact.name, name) - self.assertEqual(artifact.description, description) - - def test_new_data_artifact_creates_single_data_part(self): - sample_data = {'key': 'value', 'number': 123} - name = 'Data_Artifact' - artifact = new_data_artifact(data=sample_data, name=name) - self.assertEqual(len(artifact.parts), 1) - self.assertTrue(artifact.parts[0].HasField('data')) - - def test_new_data_artifact_part_contains_provided_data(self): - sample_data = {'content': 'test_data', 'is_valid': True} - name = 'Structured_Data_Artifact' - artifact = new_data_artifact(data=sample_data, name=name) - self.assertTrue(artifact.parts[0].HasField('data')) - # Compare via MessageToDict for proto Struct - from google.protobuf.json_format import MessageToDict - - self.assertEqual(MessageToDict(artifact.parts[0].data), sample_data) - - def test_new_data_artifact_assigns_name_description(self): - sample_data = {'info': 'some details'} - name = 'Named_Data_Artifact' - description = 'Description for data artifact.' 
- artifact = new_data_artifact( - data=sample_data, name=name, description=description - ) - self.assertEqual(artifact.name, name) - self.assertEqual(artifact.description, description) - - -class TestGetArtifactText(unittest.TestCase): - def test_get_artifact_text_single_part(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[Part(text='Hello world')], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact) - - # Verify - assert result == 'Hello world' - - def test_get_artifact_text_multiple_parts(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[ - Part(text='First line'), - Part(text='Second line'), - Part(text='Third line'), - ], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact) - - # Verify - default delimiter is newline - assert result == 'First line\nSecond line\nThird line' - - def test_get_artifact_text_custom_delimiter(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[ - Part(text='First part'), - Part(text='Second part'), - Part(text='Third part'), - ], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact, delimiter=' | ') - - # Verify - assert result == 'First part | Second part | Third part' - - def test_get_artifact_text_empty_parts(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact) - - # Verify - assert result == '' - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index d8a85fcd9..c2c990c0d 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -22,14 +22,9 @@ TaskStatus, ) from a2a.utils.errors import UnsupportedOperationError -from a2a.utils.helpers import ( - _clean_empty, - append_artifact_to_task, - are_modalities_compatible, - build_text_artifact, - 
canonicalize_agent_card, - create_task_obj, -) + +from a2a.utils.signing import _clean_empty, _canonicalize_agent_card +from a2a.server.tasks.task_manager import append_artifact_to_task # --- Helper Functions --- @@ -90,62 +85,6 @@ def create_test_task( } -# Test create_task_obj -def test_create_task_obj(): - message = create_test_message() - message.context_id = 'test-context' # Set context_id to test it's preserved - send_params = SendMessageRequest(message=message) - - task = create_task_obj(send_params) - assert task.id is not None - assert task.context_id == message.context_id - assert task.status.state == TaskState.TASK_STATE_SUBMITTED - assert len(task.history) == 1 - assert task.history[0] == message - - -def test_create_task_obj_generates_context_id(): - """Test that create_task_obj generates context_id if not present and uses it for the task.""" - # Message without context_id - message_no_context_id = Message( - role=Role.ROLE_USER, - parts=[Part(text='test')], - message_id='msg-no-ctx', - task_id='task-from-msg', # Provide a task_id to differentiate from generated task.id - ) - send_params = SendMessageRequest(message=message_no_context_id) - - # Ensure message.context_id is empty initially (proto default is empty string) - assert send_params.message.context_id == '' - - known_task_uuid = uuid.UUID('11111111-1111-1111-1111-111111111111') - known_context_uuid = uuid.UUID('22222222-2222-2222-2222-222222222222') - - # Patch uuid.uuid4 to return specific UUIDs in sequence - # The first call will be for message.context_id (if empty), the second for task.id. 
- with patch( - 'a2a.utils.helpers.uuid4', - side_effect=[known_context_uuid, known_task_uuid], - ) as mock_uuid4: - task = create_task_obj(send_params) - - # Assert that uuid4 was called twice (once for context_id, once for task.id) - assert mock_uuid4.call_count == 2 - - # Assert that message.context_id was set to the first generated UUID - assert send_params.message.context_id == str(known_context_uuid) - - # Assert that task.context_id is the same generated UUID - assert task.context_id == str(known_context_uuid) - - # Assert that task.id is the second generated UUID - assert task.id == str(known_task_uuid) - - # Ensure the original message in history also has the updated context_id - assert len(task.history) == 1 - assert task.history[0].context_id == str(known_context_uuid) - - # Test append_artifact_to_task def test_append_artifact_to_task(): # Prepare base task @@ -243,6 +182,10 @@ def test_append_artifact_to_task(): assert len(task.artifacts[1].parts) == 1 +def build_text_artifact(text: str, artifact_id: str) -> Artifact: + return Artifact(artifact_id=artifact_id, parts=[Part(text=text)]) + + # Test build_text_artifact def test_build_text_artifact(): artifact_id = 'text_artifact' @@ -254,111 +197,6 @@ def test_build_text_artifact(): assert artifact.parts[0].text == text -# Tests for are_modalities_compatible -def test_are_modalities_compatible_client_none(): - assert ( - are_modalities_compatible( - client_output_modes=None, server_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_client_empty(): - assert ( - are_modalities_compatible( - client_output_modes=[], server_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_server_none(): - assert ( - are_modalities_compatible( - server_output_modes=None, client_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_server_empty(): - assert ( - are_modalities_compatible( - server_output_modes=[], 
client_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_common_mode(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain', 'application/json'], - client_output_modes=['application/json', 'image/png'], - ) - is True - ) - - -def test_are_modalities_compatible_no_common_modes(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain'], - client_output_modes=['application/json'], - ) - is False - ) - - -def test_are_modalities_compatible_exact_match(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain'], - client_output_modes=['text/plain'], - ) - is True - ) - - -def test_are_modalities_compatible_server_more_but_common(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain', 'image/jpeg'], - client_output_modes=['text/plain'], - ) - is True - ) - - -def test_are_modalities_compatible_client_more_but_common(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain'], - client_output_modes=['text/plain', 'image/jpeg'], - ) - is True - ) - - -def test_are_modalities_compatible_both_none(): - assert ( - are_modalities_compatible( - server_output_modes=None, client_output_modes=None - ) - is True - ) - - -def test_are_modalities_compatible_both_empty(): - assert ( - are_modalities_compatible( - server_output_modes=[], client_output_modes=[] - ) - is True - ) - - def test_canonicalize_agent_card(): """Test canonicalize_agent_card with defaults, optionals, and exceptions. 
@@ -375,7 +213,7 @@ def test_canonicalize_agent_card(): '"supportedInterfaces":[{"protocolBinding":"HTTP+JSON","url":"http://localhost"}],' '"version":"1.0.0"}' ) - result = canonicalize_agent_card(agent_card) + result = _canonicalize_agent_card(agent_card) assert result == expected_jcs @@ -390,7 +228,7 @@ def test_canonicalize_agent_card_preserves_false_capability(): ), } ) - result = canonicalize_agent_card(card) + result = _canonicalize_agent_card(card) assert '"streaming":false' in result diff --git a/tests/utils/test_message.py b/tests/utils/test_message.py deleted file mode 100644 index c90d422aa..000000000 --- a/tests/utils/test_message.py +++ /dev/null @@ -1,209 +0,0 @@ -import uuid - -from unittest.mock import patch - -from google.protobuf.struct_pb2 import Struct, Value - -from a2a.types.a2a_pb2 import ( - Message, - Part, - Role, -) -from a2a.utils.message import ( - get_message_text, - new_agent_parts_message, - new_agent_text_message, -) - - -class TestNewAgentTextMessage: - def test_new_agent_text_message_basic(self): - # Setup - text = "Hello, I'm an agent" - - # Exercise - with a fixed uuid for testing - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text) - - # Verify - assert message.role == Role.ROLE_AGENT - assert len(message.parts) == 1 - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.task_id == '' - assert message.context_id == '' - - def test_new_agent_text_message_with_context_id(self): - # Setup - text = 'Message with context' - context_id = 'test-context-id' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text, context_id=context_id) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == text - assert message.message_id == 
'12345678-1234-5678-1234-567812345678' - assert message.context_id == context_id - assert message.task_id == '' - - def test_new_agent_text_message_with_task_id(self): - # Setup - text = 'Message with task id' - task_id = 'test-task-id' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text, task_id=task_id) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.task_id == task_id - assert message.context_id == '' - - def test_new_agent_text_message_with_both_ids(self): - # Setup - text = 'Message with both ids' - context_id = 'test-context-id' - task_id = 'test-task-id' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message( - text, context_id=context_id, task_id=task_id - ) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.context_id == context_id - assert message.task_id == task_id - - def test_new_agent_text_message_empty_text(self): - # Setup - text = '' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == '' - assert message.message_id == '12345678-1234-5678-1234-567812345678' - - -class TestNewAgentPartsMessage: - def test_new_agent_parts_message(self): - """Test creating an agent message with multiple, mixed parts.""" - # Setup - data = Struct() - data.update({'product_id': 123, 'quantity': 2}) - parts = [ - Part(text='Here is some text.'), - Part(data=Value(struct_value=data)), - ] - context_id = 'ctx-multi-part' - task_id = 
'task-multi-part' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('abcdefab-cdef-abcd-efab-cdefabcdefab'), - ): - message = new_agent_parts_message( - parts, context_id=context_id, task_id=task_id - ) - - # Verify - assert message.role == Role.ROLE_AGENT - assert len(message.parts) == len(parts) - assert message.context_id == context_id - assert message.task_id == task_id - assert message.message_id == 'abcdefab-cdef-abcd-efab-cdefabcdefab' - - -class TestGetMessageText: - def test_get_message_text_single_part(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[Part(text='Hello world')], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message) - - # Verify - assert result == 'Hello world' - - def test_get_message_text_multiple_parts(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[ - Part(text='First line'), - Part(text='Second line'), - Part(text='Third line'), - ], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message) - - # Verify - default delimiter is newline - assert result == 'First line\nSecond line\nThird line' - - def test_get_message_text_custom_delimiter(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[ - Part(text='First part'), - Part(text='Second part'), - Part(text='Third part'), - ], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message, delimiter=' | ') - - # Verify - assert result == 'First part | Second part | Third part' - - def test_get_message_text_empty_parts(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message) - - # Verify - assert result == '' diff --git a/tests/utils/test_parts.py b/tests/utils/test_parts.py deleted file mode 100644 index a7a24e225..000000000 --- a/tests/utils/test_parts.py +++ /dev/null @@ -1,184 +0,0 @@ -from 
google.protobuf.struct_pb2 import Struct, Value -from a2a.types.a2a_pb2 import ( - Part, -) -from a2a.utils.parts import ( - get_data_parts, - get_file_parts, - get_text_parts, -) - - -class TestGetTextParts: - def test_get_text_parts_single_text_part(self): - # Setup - parts = [Part(text='Hello world')] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == ['Hello world'] - - def test_get_text_parts_multiple_text_parts(self): - # Setup - parts = [ - Part(text='First part'), - Part(text='Second part'), - Part(text='Third part'), - ] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == ['First part', 'Second part', 'Third part'] - - def test_get_text_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == [] - - -class TestGetDataParts: - def test_get_data_parts_single_data_part(self): - # Setup - data = Struct() - data.update({'key': 'value'}) - parts = [Part(data=Value(struct_value=data))] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key': 'value'}] - - def test_get_data_parts_multiple_data_parts(self): - # Setup - data1 = Struct() - data1.update({'key1': 'value1'}) - data2 = Struct() - data2.update({'key2': 'value2'}) - parts = [ - Part(data=Value(struct_value=data1)), - Part(data=Value(struct_value=data2)), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key1': 'value1'}, {'key2': 'value2'}] - - def test_get_data_parts_mixed_parts(self): - # Setup - data1 = Struct() - data1.update({'key1': 'value1'}) - data2 = Struct() - data2.update({'key2': 'value2'}) - parts = [ - Part(text='some text'), - Part(data=Value(struct_value=data1)), - Part(data=Value(struct_value=data2)), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key1': 'value1'}, {'key2': 'value2'}] - - def test_get_data_parts_no_data_parts(self): - # Setup - parts = [ - 
Part(text='some text'), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [] - - def test_get_data_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [] - - -class TestGetFileParts: - def test_get_file_parts_single_file_part(self): - # Setup - parts = [Part(url='file://path/to/file', media_type='text/plain')] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert len(result) == 1 - assert result[0].url == 'file://path/to/file' - assert result[0].media_type == 'text/plain' - - def test_get_file_parts_multiple_file_parts(self): - # Setup - parts = [ - Part(url='file://path/to/file1', media_type='text/plain'), - Part(raw=b'file content', media_type='application/octet-stream'), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert len(result) == 2 - assert result[0].url == 'file://path/to/file1' - assert result[1].raw == b'file content' - - def test_get_file_parts_mixed_parts(self): - # Setup - parts = [ - Part(text='some text'), - Part(url='file://path/to/file', media_type='text/plain'), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert len(result) == 1 - assert result[0].url == 'file://path/to/file' - - def test_get_file_parts_no_file_parts(self): - # Setup - data = Struct() - data.update({'key': 'value'}) - parts = [ - Part(text='some text'), - Part(data=Value(struct_value=data)), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [] - - def test_get_file_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [] diff --git a/tests/utils/test_task.py b/tests/utils/test_task.py index 3e1f3c058..55dc8ed4f 100644 --- a/tests/utils/test_task.py +++ b/tests/utils/test_task.py @@ -14,197 +14,16 @@ GetTaskRequest, SendMessageConfiguration, ) +from a2a.helpers.proto_helpers import new_task from 
a2a.utils.task import ( apply_history_length, - completed_task, decode_page_token, encode_page_token, - new_task, ) from a2a.utils.errors import InvalidParamsError class TestTask(unittest.TestCase): - def test_new_task_status(self): - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - task = new_task(message) - self.assertEqual(task.status.state, TaskState.TASK_STATE_SUBMITTED) - - @patch('uuid.uuid4') - def test_new_task_generates_ids(self, mock_uuid4): - mock_uuid = uuid.UUID('12345678-1234-5678-1234-567812345678') - mock_uuid4.return_value = mock_uuid - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - task = new_task(message) - self.assertEqual(task.id, str(mock_uuid)) - self.assertEqual(task.context_id, str(mock_uuid)) - - def test_new_task_uses_provided_ids(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) - task = new_task(message) - self.assertEqual(task.id, task_id) - self.assertEqual(task.context_id, context_id) - - def test_new_task_initial_message_in_history(self): - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - task = new_task(message) - self.assertEqual(len(task.history), 1) - self.assertEqual(task.history[0], message) - - def test_completed_task_status(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - task = completed_task( - task_id=task_id, - context_id=context_id, - artifacts=artifacts, - history=[], - ) - self.assertEqual(task.status.state, TaskState.TASK_STATE_COMPLETED) - - def test_completed_task_assigns_ids_and_artifacts(self): - task_id = 
str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - task = completed_task( - task_id=task_id, - context_id=context_id, - artifacts=artifacts, - history=[], - ) - self.assertEqual(task.id, task_id) - self.assertEqual(task.context_id, context_id) - self.assertEqual(len(task.artifacts), len(artifacts)) - - def test_completed_task_empty_history_if_not_provided(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - task = completed_task( - task_id=task_id, context_id=context_id, artifacts=artifacts - ) - self.assertEqual(len(task.history), 0) - - def test_completed_task_uses_provided_history(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - history = [ - Message( - role=Role.ROLE_USER, - parts=[Part(text='Hello')], - message_id=str(uuid.uuid4()), - ), - Message( - role=Role.ROLE_AGENT, - parts=[Part(text='Hi there')], - message_id=str(uuid.uuid4()), - ), - ] - task = completed_task( - task_id=task_id, - context_id=context_id, - artifacts=artifacts, - history=history, - ) - self.assertEqual(len(task.history), len(history)) - - def test_new_task_invalid_message_empty_parts(self): - with self.assertRaises(ValueError): - new_task( - Message( - role=Role.ROLE_USER, - parts=[], - message_id=str(uuid.uuid4()), - ) - ) - - def test_new_task_invalid_message_empty_content(self): - with self.assertRaises(ValueError): - new_task( - Message( - role=Role.ROLE_USER, - parts=[Part(text='')], - message_id=str(uuid.uuid4()), - ) - ) - - def test_new_task_invalid_message_none_role(self): - # Proto messages always have a default role (ROLE_UNSPECIFIED = 0) - # Testing with unspecified role - msg = Message( - role=Role.ROLE_UNSPECIFIED, - 
parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - with self.assertRaises((TypeError, ValueError)): - new_task(msg) - - def test_completed_task_empty_artifacts(self): - with pytest.raises( - ValueError, - match='artifacts must be a non-empty list of Artifact objects', - ): - completed_task( - task_id='task-123', - context_id='ctx-456', - artifacts=[], - history=[], - ) - - def test_completed_task_invalid_artifact_type(self): - with pytest.raises( - ValueError, - match='artifacts must be a non-empty list of Artifact objects', - ): - completed_task( - task_id='task-123', - context_id='ctx-456', - artifacts=['not an artifact'], # type: ignore[arg-type] - history=[], - ) - page_token = 'd47a95ba-0f39-4459-965b-3923cdd2ff58' encoded_page_token = 'ZDQ3YTk1YmEtMGYzOS00NDU5LTk2NWItMzkyM2NkZDJmZjU4' # base64 for 'd47a95ba-0f39-4459-965b-3923cdd2ff58' @@ -234,9 +53,10 @@ def setUp(self): for i in range(5) ] artifacts = [Artifact(artifact_id='a1', parts=[Part(text='a')])] - self.task = completed_task( + self.task = new_task( task_id='t1', context_id='c1', + state=TaskState.TASK_STATE_COMPLETED, artifacts=artifacts, history=self.history, ) From f6610fa35e1f5fbc3e7e6cd9e29a5177a538eb4e Mon Sep 17 00:00:00 2001 From: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Date: Fri, 17 Apr 2026 15:05:29 +0200 Subject: [PATCH 35/67] docs: move `ai_learnings.md` to local-only and update `GEMINI.md` (#982) # Description: `docs/ai/ai_learnings.md` is a personal AI workflow log that should not be shared in the repository. 
This PR: - Removes `docs/ai/ai_learnings.md` from git tracking (file remains local, already listed in `.gitignore`) - Updates `GEMINI.md` section 5 (Mistake Reflection Protocol) to include the file description, its local-only nature, and the entry format that was previously defined in the file itself --- GEMINI.md | 18 ++++++++++++------ docs/ai/ai_learnings.md | 19 ------------------- 2 files changed, 12 insertions(+), 25 deletions(-) delete mode 100644 docs/ai/ai_learnings.md diff --git a/GEMINI.md b/GEMINI.md index b801bd47d..e6bf43b65 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -26,16 +26,22 @@ ## 5. Mistake Reflection Protocol +> [!NOTE] for Users: +> `docs/ai/ai_learnings.md` is a local-only file (excluded from git) meant to be +> read by the developer to improve AI assistant behavior on this project. Use its +> findings to improve the GEMINI.md setup. + When you realise you have made a mistake — whether caught by the user, by a tool, or by your own reasoning — you MUST: 1. **Acknowledge the mistake explicitly** and explain what went wrong. -2. **Reflect on the root cause**: was it a missing check, a false - assumption, skipped verification, or a gap in the workflow? -3. **Immediately append a new entry to @./docs/ai/ai_learnings.md** - following the format defined in that file. This is not optional and - does not require user confirmation. Do it before continuing. Update user - about the changes to the workflow in the current chat. +2. **Reflect on the root cause**: was it a missing check, a false assumption, skipped verification, or a gap in the workflow? +3. **Immediately append a new entry to `docs/ai/ai_learnings.md`** — this is not optional and does not require user confirmation. Do it before continuing, then update the user about the workflow change. + + **Entry format:** + - **Mistake**: What went wrong. + - **Root cause**: Why it happened. + - **Rule**: The concrete rule added to prevent recurrence. 
The goal is to treat every mistake as a signal that the workflow is incomplete, and to improve it in place so the same mistake cannot diff --git a/docs/ai/ai_learnings.md b/docs/ai/ai_learnings.md deleted file mode 100644 index 9e9a37a9f..000000000 --- a/docs/ai/ai_learnings.md +++ /dev/null @@ -1,19 +0,0 @@ -> [!NOTE] for Users: -> This document is meant to be read by an AI assistant (Gemini) in order to -> learn from its mistakes and improve its behavior on this project. Use -> its findings to improve GEMINI.md setup. - -# AI Learnings - -A living record of mistakes made during this project and the rules -derived from them. Every entry must follow the format below. - ---- - -## Entry format - -**Mistake**: What went wrong. -**Root cause**: Why it happened. -**Rule**: The concrete rule added to prevent recurrence. - ---- From b8df210b00d0f249ca68f0d814191c4205e18b35 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 17 Apr 2026 15:22:50 +0200 Subject: [PATCH 36/67] fix(extensions): support both header names and remove "activation" concept (#984) 1.0 spec uses `A2A-Extensions` instead of `X-A2A-Extensions` header/metadata name. It also doesn't have "activation" concept - used extensions are not propagated back via headers and should be put into message.extensions or artifact.extensions instead. 1. Support both in for compat server. 2. Send `X-A2A-Extensions` in compat transports. 3. Remove "activation". 
--- src/a2a/compat/v0_3/context_builders.py | 80 +++++++++ src/a2a/compat/v0_3/extension_headers.py | 27 +++ src/a2a/compat/v0_3/grpc_handler.py | 19 +-- src/a2a/compat/v0_3/grpc_transport.py | 5 +- src/a2a/compat/v0_3/jsonrpc_adapter.py | 3 +- src/a2a/compat/v0_3/jsonrpc_transport.py | 3 + src/a2a/compat/v0_3/rest_adapter.py | 3 +- src/a2a/compat/v0_3/rest_transport.py | 3 + src/a2a/extensions/common.py | 2 +- src/a2a/server/agent_execution/context.py | 10 +- src/a2a/server/context.py | 1 - .../server/request_handlers/grpc_handler.py | 15 -- src/a2a/server/routes/jsonrpc_dispatcher.py | 12 +- .../client/transports/test_jsonrpc_client.py | 4 +- tests/client/transports/test_rest_client.py | 6 +- tests/compat/v0_3/test_context_builders.py | 159 ++++++++++++++++++ tests/compat/v0_3/test_extension_headers.py | 39 +++++ tests/compat/v0_3/test_grpc_handler.py | 20 --- tests/compat/v0_3/test_grpc_transport.py | 28 +++ tests/compat/v0_3/test_jsonrpc_transport.py | 26 +++ .../test_client_server_integration.py | 4 +- tests/integration/test_end_to_end.py | 10 +- tests/server/agent_execution/test_context.py | 8 +- .../request_handlers/test_grpc_handler.py | 38 +---- .../server/routes/test_jsonrpc_dispatcher.py | 25 --- 25 files changed, 395 insertions(+), 155 deletions(-) create mode 100644 src/a2a/compat/v0_3/context_builders.py create mode 100644 src/a2a/compat/v0_3/extension_headers.py create mode 100644 tests/compat/v0_3/test_context_builders.py create mode 100644 tests/compat/v0_3/test_extension_headers.py diff --git a/src/a2a/compat/v0_3/context_builders.py b/src/a2a/compat/v0_3/context_builders.py new file mode 100644 index 000000000..2f2eec362 --- /dev/null +++ b/src/a2a/compat/v0_3/context_builders.py @@ -0,0 +1,80 @@ +"""Context builders that add v0.3 backwards-compatibility for extensions. + +The current spec uses ``A2A-Extensions`` (RFC 6648, no ``X-`` prefix). 
v0.3 +clients still send the old ``X-A2A-Extensions`` name, so the v0.3 compat +adapters wrap the default builders with these classes to recognize both names. +""" + +from typing import TYPE_CHECKING, Any + +import grpc + +from a2a.compat.v0_3.extension_headers import LEGACY_HTTP_EXTENSION_HEADER +from a2a.extensions.common import get_requested_extensions +from a2a.server.context import ServerCallContext + + +if TYPE_CHECKING: + from starlette.requests import Request + + from a2a.server.request_handlers.grpc_handler import ( + GrpcServerCallContextBuilder, + ) + from a2a.server.routes.common import ServerCallContextBuilder +else: + try: + from starlette.requests import Request + except ImportError: + Request = Any + + +def _get_legacy_grpc_extensions( + context: grpc.aio.ServicerContext, +) -> list[str]: + md = context.invocation_metadata() + if md is None: + return [] + lower_key = LEGACY_HTTP_EXTENSION_HEADER.lower() + return [ + e if isinstance(e, str) else e.decode('utf-8') + for k, e in md + if k.lower() == lower_key + ] + + +class V03ServerCallContextBuilder: + """Wraps a ServerCallContextBuilder to also accept the legacy header. + + Recognizes the v0.3 ``X-A2A-Extensions`` HTTP header in addition to the + spec ``A2A-Extensions``. + """ + + def __init__(self, inner: 'ServerCallContextBuilder') -> None: + self._inner = inner + + def build(self, request: 'Request') -> ServerCallContext: + """Builds a ServerCallContext, merging legacy extension headers.""" + context = self._inner.build(request) + context.requested_extensions |= get_requested_extensions( + request.headers.getlist(LEGACY_HTTP_EXTENSION_HEADER) + ) + return context + + +class V03GrpcServerCallContextBuilder: + """Wraps a GrpcServerCallContextBuilder to also accept the legacy metadata. + + Recognizes the v0.3 ``X-A2A-Extensions`` gRPC metadata key in addition to + the spec ``A2A-Extensions``. 
+ """ + + def __init__(self, inner: 'GrpcServerCallContextBuilder') -> None: + self._inner = inner + + def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: + """Builds a ServerCallContext, merging legacy extension metadata.""" + server_context = self._inner.build(context) + server_context.requested_extensions |= get_requested_extensions( + _get_legacy_grpc_extensions(context) + ) + return server_context diff --git a/src/a2a/compat/v0_3/extension_headers.py b/src/a2a/compat/v0_3/extension_headers.py new file mode 100644 index 000000000..e1421a0b0 --- /dev/null +++ b/src/a2a/compat/v0_3/extension_headers.py @@ -0,0 +1,27 @@ +"""Shared header name constants for v0.3 extension compatibility. + +The current spec uses ``A2A-Extensions``. v0.3 used the ``X-`` prefixed +``X-A2A-Extensions`` form. v0.3 compat servers and clients accept/emit both +names so they can interoperate with peers that only know the legacy one. +""" + +from a2a.client.service_parameters import ServiceParameters +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +LEGACY_HTTP_EXTENSION_HEADER = f'X-{HTTP_EXTENSION_HEADER}' + + +def add_legacy_extension_header(parameters: ServiceParameters) -> None: + """Mirrors the ``A2A-Extensions`` parameter under its legacy name in-place. + + Used by v0.3 compat client transports so that requests can be understood + by older v0.3 servers that only recognize ``X-A2A-Extensions``. 
+ """ + if ( + HTTP_EXTENSION_HEADER in parameters + and LEGACY_HTTP_EXTENSION_HEADER not in parameters + ): + parameters[LEGACY_HTTP_EXTENSION_HEADER] = parameters[ + HTTP_EXTENSION_HEADER + ] diff --git a/src/a2a/compat/v0_3/grpc_handler.py b/src/a2a/compat/v0_3/grpc_handler.py index 23d1f831d..b7bec26ea 100644 --- a/src/a2a/compat/v0_3/grpc_handler.py +++ b/src/a2a/compat/v0_3/grpc_handler.py @@ -17,8 +17,8 @@ from a2a.compat.v0_3 import ( types as types_v03, ) +from a2a.compat.v0_3.context_builders import V03GrpcServerCallContextBuilder from a2a.compat.v0_3.request_handler import RequestHandler03 -from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.context import ServerCallContext from a2a.server.request_handlers.grpc_handler import ( _ERROR_CODE_MAP, @@ -51,7 +51,7 @@ def __init__( DefaultCallContextBuilder is used. """ self.handler03 = RequestHandler03(request_handler=request_handler) - self._context_builder = ( + self._context_builder = V03GrpcServerCallContextBuilder( context_builder or DefaultGrpcServerCallContextBuilder() ) @@ -65,7 +65,6 @@ async def _handle_unary( try: server_context = self._context_builder.build(context) result = await handler_func(server_context) - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) else: @@ -82,7 +81,6 @@ async def _handle_stream( server_context = self._context_builder.build(context) async for item in handler_func(server_context): yield item - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) @@ -120,19 +118,6 @@ async def abort_context( f'Unknown error type: {error}', ) - def _set_extension_metadata( - self, - context: grpc.aio.ServicerContext, - server_context: ServerCallContext, - ) -> None: - if server_context.activated_extensions: - context.set_trailing_metadata( - [ - (HTTP_EXTENSION_HEADER.lower(), e) - for e in sorted(server_context.activated_extensions) - ] - ) - 
async def SendMessage( self, request: a2a_v0_3_pb2.SendMessageRequest, diff --git a/src/a2a/compat/v0_3/grpc_transport.py b/src/a2a/compat/v0_3/grpc_transport.py index 32ce7f27b..95314e3f1 100644 --- a/src/a2a/compat/v0_3/grpc_transport.py +++ b/src/a2a/compat/v0_3/grpc_transport.py @@ -30,6 +30,7 @@ from a2a.compat.v0_3 import ( types as types_v03, ) +from a2a.compat.v0_3.extension_headers import add_legacy_extension_header from a2a.types import a2a_pb2 from a2a.utils.constants import PROTOCOL_VERSION_0_3, VERSION_HEADER from a2a.utils.telemetry import SpanKind, trace_class @@ -361,7 +362,9 @@ def _get_grpc_metadata( metadata = [(VERSION_HEADER.lower(), PROTOCOL_VERSION_0_3)] if context and context.service_parameters: - for key, value in context.service_parameters.items(): + params = dict(context.service_parameters) + add_legacy_extension_header(params) + for key, value in params.items(): metadata.append((key.lower(), value)) return metadata diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index baa2bcda8..8b4b26a79 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -24,6 +24,7 @@ _package_starlette_installed = False from a2a.compat.v0_3 import types as types_v03 +from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder from a2a.compat.v0_3.request_handler import RequestHandler03 from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( @@ -70,7 +71,7 @@ def __init__( self.handler = RequestHandler03( request_handler=http_handler, ) - self._context_builder = ( + self._context_builder = V03ServerCallContextBuilder( context_builder or DefaultServerCallContextBuilder() ) diff --git a/src/a2a/compat/v0_3/jsonrpc_transport.py b/src/a2a/compat/v0_3/jsonrpc_transport.py index 557a63a16..caccd2811 100644 --- a/src/a2a/compat/v0_3/jsonrpc_transport.py +++ b/src/a2a/compat/v0_3/jsonrpc_transport.py @@ -19,6 +19,7 @@ ) from 
a2a.compat.v0_3 import conversions from a2a.compat.v0_3 import types as types_v03 +from a2a.compat.v0_3.extension_headers import add_legacy_extension_header from a2a.types.a2a_pb2 import ( AgentCard, CancelTaskRequest, @@ -424,6 +425,7 @@ async def _send_stream_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) async for sse_data in send_http_stream_request( self.httpx_client, @@ -485,6 +487,7 @@ async def _send_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) request = self.httpx_client.build_request( 'POST', diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index a2a9b56ee..38687054f 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -31,6 +31,7 @@ _package_starlette_installed = False +from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder from a2a.compat.v0_3.rest_handler import REST03Handler from a2a.server.routes.common import ( DefaultServerCallContextBuilder, @@ -60,7 +61,7 @@ def __init__( context_builder: 'ServerCallContextBuilder | None' = None, ): self.handler = REST03Handler(request_handler=http_handler) - self._context_builder = ( + self._context_builder = V03ServerCallContextBuilder( context_builder or DefaultServerCallContextBuilder() ) diff --git a/src/a2a/compat/v0_3/rest_transport.py b/src/a2a/compat/v0_3/rest_transport.py index 0ba38538d..bcaed2949 100644 --- a/src/a2a/compat/v0_3/rest_transport.py +++ b/src/a2a/compat/v0_3/rest_transport.py @@ -25,6 +25,7 @@ from a2a.compat.v0_3 import ( types as types_v03, ) +from a2a.compat.v0_3.extension_headers import add_legacy_extension_header from a2a.types.a2a_pb2 import ( AgentCard, 
CancelTaskRequest, @@ -380,6 +381,7 @@ async def _send_stream_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) async for sse_data in send_http_stream_request( self.httpx_client, @@ -414,6 +416,7 @@ async def _execute_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) request = self.httpx_client.build_request( method, diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 0595216ed..06ccf8f40 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -1,7 +1,7 @@ from a2a.types.a2a_pb2 import AgentCard, AgentExtension -HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' +HTTP_EXTENSION_HEADER = 'A2A-Extensions' def get_requested_extensions(values: list[str]) -> set[str]: diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 8b78c1045..5fcdf8697 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -151,14 +151,6 @@ def metadata(self) -> dict[str, Any]: return dict(self._params.metadata) return {} - def add_activated_extension(self, uri: str) -> None: - """Add an extension to the set of activated extensions for this request. - - This causes the extension to be indicated back to the client in the - response. 
- """ - self._call_context.activated_extensions.add(uri) - @property def tenant(self) -> str: """The tenant associated with this request.""" @@ -166,7 +158,7 @@ def tenant(self) -> str: @property def requested_extensions(self) -> set[str]: - """Extensions that the client requested to activate.""" + """Extensions that the client requested for this interaction.""" return self._call_context.requested_extensions def _check_or_generate_task_id(self) -> None: diff --git a/src/a2a/server/context.py b/src/a2a/server/context.py index 6196a69d6..833ca44c4 100644 --- a/src/a2a/server/context.py +++ b/src/a2a/server/context.py @@ -23,4 +23,3 @@ class ServerCallContext(BaseModel): user: User = Field(default_factory=UnauthenticatedUser) tenant: str = Field(default='') requested_extensions: set[str] = Field(default_factory=set) - activated_extensions: set[str] = Field(default_factory=set) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 2ccfa9bdd..8cd421e93 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -135,7 +135,6 @@ async def _handle_unary( try: server_context = self._build_call_context(context, request) result = await handler_func(server_context) - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) else: @@ -153,7 +152,6 @@ async def _handle_stream( server_context = self._build_call_context(context, request) async for item in handler_func(server_context): yield item - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) @@ -422,19 +420,6 @@ async def abort_context( f'Unknown error type: {error}', ) - def _set_extension_metadata( - self, - context: grpc.aio.ServicerContext, - server_context: ServerCallContext, - ) -> None: - if server_context.activated_extensions: - context.set_trailing_metadata( - [ - 
(HTTP_EXTENSION_HEADER.lower(), e) - for e in sorted(server_context.activated_extensions) - ] - ) - def _build_call_context( self, context: grpc.aio.ServicerContext, diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 60620081a..3dc94488a 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -11,9 +11,6 @@ from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response from a2a.compat.v0_3.jsonrpc_adapter import JSONRPC03Adapter -from a2a.extensions.common import ( - HTTP_EXTENSION_HEADER, -) from a2a.server.context import ServerCallContext from a2a.server.events import Event from a2a.server.jsonrpc_models import ( @@ -570,9 +567,6 @@ def _create_response( Returns: A Starlette JSONResponse or EventSourceResponse. """ - headers = {} - if exts := context.activated_extensions: - headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts)) if isinstance(handler_result, AsyncGenerator): # Result is a stream of dict objects async def event_generator( @@ -603,9 +597,7 @@ async def event_generator( 'data': json.dumps(error_response), } - return EventSourceResponse( - event_generator(handler_result), headers=headers - ) + return EventSourceResponse(event_generator(handler_result)) # handler_result is a dict (JSON-RPC response) - return JSONResponse(handler_result, headers=headers) + return JSONResponse(handler_result) diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index 1339bb8af..b005c2e05 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -545,7 +545,7 @@ async def test_extensions_added_to_request( from a2a.client.client import ClientCallContext context = ClientCallContext( - service_parameters={'X-A2A-Extensions': 'https://example.com/ext1'} + service_parameters={'A2A-Extensions': 'https://example.com/ext1'} ) await 
transport.send_message(request, context=context) @@ -555,7 +555,7 @@ async def test_extensions_added_to_request( call_args = mock_httpx_client.build_request.call_args # Extensions should be in the kwargs assert ( - call_args[1].get('headers', {}).get('X-A2A-Extensions') + call_args[1].get('headers', {}).get('A2A-Extensions') == 'https://example.com/ext1' ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 0c9f7c30a..1e9398181 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -257,7 +257,7 @@ async def test_send_message_with_default_extensions( context = ClientCallContext( service_parameters={ - 'X-A2A-Extensions': 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' + 'A2A-Extensions': 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' } ) await client.send_message(request=params, context=context) @@ -281,7 +281,7 @@ async def test_send_message_streaming_with_new_extensions( mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): - """Test X-A2A-Extensions header in send_message_streaming.""" + """Test A2A-Extensions header in send_message_streaming.""" client = RestTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, @@ -303,7 +303,7 @@ async def test_send_message_streaming_with_new_extensions( context = ClientCallContext( service_parameters={ - 'X-A2A-Extensions': 'https://example.com/test-ext/v2' + 'A2A-Extensions': 'https://example.com/test-ext/v2' } ) diff --git a/tests/compat/v0_3/test_context_builders.py b/tests/compat/v0_3/test_context_builders.py new file mode 100644 index 000000000..1b711f52f --- /dev/null +++ b/tests/compat/v0_3/test_context_builders.py @@ -0,0 +1,159 @@ +from unittest.mock import AsyncMock, MagicMock + +import grpc + +from starlette.datastructures import Headers + +from a2a.compat.v0_3.context_builders import ( + V03GrpcServerCallContextBuilder, + 
V03ServerCallContextBuilder, +) +from a2a.compat.v0_3.extension_headers import LEGACY_HTTP_EXTENSION_HEADER +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.grpc_handler import ( + DefaultGrpcServerCallContextBuilder, +) +from a2a.server.routes.common import DefaultServerCallContextBuilder + + +def _make_mock_request(headers=None): + request = MagicMock() + request.scope = {} + request.headers = Headers(headers or {}) + return request + + +def _make_mock_grpc_context(metadata: list[tuple[str, str]]) -> AsyncMock: + context = AsyncMock(spec=grpc.aio.ServicerContext) + context.invocation_metadata.return_value = grpc.aio.Metadata(*metadata) + return context + + +class TestV03ServerCallContextBuilder: + def test_legacy_header_only(self): + request = _make_mock_request( + headers={LEGACY_HTTP_EXTENSION_HEADER: 'legacy-ext'} + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert isinstance(ctx, ServerCallContext) + assert ctx.requested_extensions == {'legacy-ext'} + + def test_spec_header_only(self): + request = _make_mock_request( + headers={HTTP_EXTENSION_HEADER: 'spec-ext'} + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == {'spec-ext'} + + def test_both_headers_merged(self): + request = _make_mock_request( + headers={ + HTTP_EXTENSION_HEADER: 'spec-ext', + LEGACY_HTTP_EXTENSION_HEADER: 'legacy-ext', + } + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == {'spec-ext', 'legacy-ext'} + + def test_legacy_header_comma_separated(self): + request = _make_mock_request( + headers={LEGACY_HTTP_EXTENSION_HEADER: 'foo, bar'} + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = 
builder.build(request) + + assert ctx.requested_extensions == {'foo', 'bar'} + + def test_no_extensions(self): + request = _make_mock_request() + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == set() + + +class TestV03GrpcServerCallContextBuilder: + def test_legacy_metadata_only(self): + context = _make_mock_grpc_context( + [(LEGACY_HTTP_EXTENSION_HEADER.lower(), 'legacy-ext')] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert isinstance(ctx, ServerCallContext) + assert ctx.requested_extensions == {'legacy-ext'} + + def test_spec_metadata_only(self): + context = _make_mock_grpc_context( + [(HTTP_EXTENSION_HEADER.lower(), 'spec-ext')] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == {'spec-ext'} + + def test_both_metadata_merged(self): + context = _make_mock_grpc_context( + [ + (HTTP_EXTENSION_HEADER.lower(), 'spec-ext'), + (LEGACY_HTTP_EXTENSION_HEADER.lower(), 'legacy-ext'), + ] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == {'spec-ext', 'legacy-ext'} + + def test_legacy_metadata_comma_separated(self): + context = _make_mock_grpc_context( + [(LEGACY_HTTP_EXTENSION_HEADER.lower(), 'foo, bar')] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == {'foo', 'bar'} + + def test_no_extensions(self): + context = _make_mock_grpc_context([]) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == set() + + def test_no_metadata(self): + context = 
AsyncMock(spec=grpc.aio.ServicerContext) + context.invocation_metadata.return_value = None + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == set() diff --git a/tests/compat/v0_3/test_extension_headers.py b/tests/compat/v0_3/test_extension_headers.py new file mode 100644 index 000000000..d5abbdfcc --- /dev/null +++ b/tests/compat/v0_3/test_extension_headers.py @@ -0,0 +1,39 @@ +from a2a.compat.v0_3.extension_headers import ( + LEGACY_HTTP_EXTENSION_HEADER, + add_legacy_extension_header, +) +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +def test_legacy_header_constant_value(): + assert LEGACY_HTTP_EXTENSION_HEADER == 'X-A2A-Extensions' + + +def test_mirrors_spec_header_under_legacy_name(): + params = {HTTP_EXTENSION_HEADER: 'foo,bar'} + + add_legacy_extension_header(params) + + assert params == { + HTTP_EXTENSION_HEADER: 'foo,bar', + LEGACY_HTTP_EXTENSION_HEADER: 'foo,bar', + } + + +def test_no_op_when_spec_header_absent(): + params = {'Other': 'value'} + + add_legacy_extension_header(params) + + assert params == {'Other': 'value'} + + +def test_does_not_overwrite_existing_legacy_header(): + params = { + HTTP_EXTENSION_HEADER: 'spec', + LEGACY_HTTP_EXTENSION_HEADER: 'legacy-original', + } + + add_legacy_extension_header(params) + + assert params[LEGACY_HTTP_EXTENSION_HEADER] == 'legacy-original' diff --git a/tests/compat/v0_3/test_grpc_handler.py b/tests/compat/v0_3/test_grpc_handler.py index 75c6421e8..fbd74f29f 100644 --- a/tests/compat/v0_3/test_grpc_handler.py +++ b/tests/compat/v0_3/test_grpc_handler.py @@ -7,8 +7,6 @@ a2a_v0_3_pb2, grpc_handler as compat_grpc_handler, ) -from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.server.context import ServerCallContext from a2a.server.request_handlers import RequestHandler from a2a.types import a2a_pb2 from a2a.utils.errors import TaskNotFoundError, InvalidParamsError @@ -506,21 
+504,3 @@ async def test_extract_task_and_config_id_invalid( ): with pytest.raises(InvalidParamsError): handler._extract_task_and_config_id('invalid-name') - - -@pytest.mark.asyncio -async def test_handle_unary_extension_metadata( - handler: compat_grpc_handler.CompatGrpcHandler, - mock_request_handler: AsyncMock, - mock_grpc_context: AsyncMock, -) -> None: - async def mock_func(server_context: ServerCallContext): - server_context.activated_extensions.add('ext-1') - return a2a_pb2.Task() - - await handler._handle_unary(mock_grpc_context, mock_func, a2a_pb2.Task()) - - expected_metadata = [(HTTP_EXTENSION_HEADER.lower(), 'ext-1')] - mock_grpc_context.set_trailing_metadata.assert_called_once_with( - expected_metadata - ) diff --git a/tests/compat/v0_3/test_grpc_transport.py b/tests/compat/v0_3/test_grpc_transport.py index ba1e6af3d..402a57000 100644 --- a/tests/compat/v0_3/test_grpc_transport.py +++ b/tests/compat/v0_3/test_grpc_transport.py @@ -2,6 +2,7 @@ import pytest +from a2a.client.client import ClientCallContext from a2a.client.optionals import Channel from a2a.compat.v0_3 import a2a_v0_3_pb2 from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport @@ -38,3 +39,30 @@ async def test_compat_grpc_transport_send_message_response_msg_parsing(): assert isinstance(response, SendMessageResponse) assert response.HasField('message') assert response.message.message_id == 'msg-123' + + +def test_compat_grpc_transport_mirrors_extension_metadata(): + """Compat gRPC client must also emit the legacy x-a2a-extensions metadata + so that v0.3 servers (which only know that name) understand the request.""" + transport = CompatGrpcTransport( + channel=AsyncMock(spec=Channel), agent_card=None + ) + context = ClientCallContext( + service_parameters={'A2A-Extensions': 'foo,bar'} + ) + + metadata = dict(transport._get_grpc_metadata(context)) + + assert metadata['a2a-extensions'] == 'foo,bar' + assert metadata['x-a2a-extensions'] == 'foo,bar' + + +def 
test_compat_grpc_transport_no_extension_metadata(): + transport = CompatGrpcTransport( + channel=AsyncMock(spec=Channel), agent_card=None + ) + + metadata = dict(transport._get_grpc_metadata(None)) + + assert 'a2a-extensions' not in metadata + assert 'x-a2a-extensions' not in metadata diff --git a/tests/compat/v0_3/test_jsonrpc_transport.py b/tests/compat/v0_3/test_jsonrpc_transport.py index 50b33e162..70291f005 100644 --- a/tests/compat/v0_3/test_jsonrpc_transport.py +++ b/tests/compat/v0_3/test_jsonrpc_transport.py @@ -539,3 +539,29 @@ async def test_compat_jsonrpc_transport_send_request( mock_send_http_request.assert_called_once_with( transport.httpx_client, mock_request, transport._handle_http_error ) + + +@pytest.mark.asyncio +@patch('a2a.compat.v0_3.jsonrpc_transport.send_http_request') +async def test_compat_jsonrpc_transport_mirrors_extension_header( + mock_send_http_request, transport +): + """Compat client must also emit the legacy X-A2A-Extensions header so + that v0.3 servers (which only know that name) understand the request.""" + from a2a.client.client import ClientCallContext + + mock_send_http_request.return_value = {'result': {'ok': True}} + transport.httpx_client.build_request.return_value = httpx.Request( + 'POST', 'http://example.com' + ) + + context = ClientCallContext( + service_parameters={'A2A-Extensions': 'foo,bar'} + ) + + await transport._send_request({'some': 'data'}, context=context) + + _, kwargs = transport.httpx_client.build_request.call_args + headers = kwargs['headers'] + assert headers['A2A-Extensions'] == 'foo,bar' + assert headers['X-A2A-Extensions'] == 'foo,bar' diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 1ac8a7162..76da2e20f 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -675,9 +675,9 @@ async def test_json_transport_base_client_send_message_with_extensions( 
call_args[1] if len(call_args) > 1 else call_kwargs.get('context') ) service_params = getattr(called_context, 'service_parameters', {}) - assert 'X-A2A-Extensions' in service_params + assert 'A2A-Extensions' in service_params assert ( - service_params['X-A2A-Extensions'] + service_params['A2A-Extensions'] == 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' ) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index b6cddbe4d..dcd016b48 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -100,19 +100,15 @@ class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): user_input = context.get_user_input() - # Extensions echo: activate all requested extensions and report them - # back via the Message.extensions field. + # Extensions echo: report the requested extensions back to the client + # via the Message.extensions field. if user_input.startswith('Extensions:'): - for ext_uri in context.requested_extensions: - context.add_activated_extension(ext_uri) await event_queue.enqueue_event( Message( role=Role.ROLE_AGENT, message_id='ext-reply-1', parts=[Part(text='extensions echoed')], - extensions=sorted( - context.call_context.activated_extensions - ), + extensions=sorted(context.requested_extensions), ) ) return diff --git a/tests/server/agent_execution/test_context.py b/tests/server/agent_execution/test_context.py index 7ec612986..dce780f58 100644 --- a/tests/server/agent_execution/test_context.py +++ b/tests/server/agent_execution/test_context.py @@ -322,14 +322,8 @@ def test_init_with_context_id_and_existing_context_id_match( assert context.current_task == mock_task def test_extension_handling(self) -> None: - """Test extension handling in RequestContext.""" + """Test that requested_extensions is exposed via RequestContext.""" call_context = ServerCallContext(requested_extensions={'foo', 'bar'}) context = 
RequestContext(call_context=call_context) assert context.requested_extensions == {'foo', 'bar'} - - context.add_activated_extension('foo') - assert call_context.activated_extensions == {'foo'} - - context.add_activated_extension('baz') - assert call_context.activated_extensions == {'foo', 'baz'} diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 2b1a37385..d140d3d7b 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -421,19 +421,11 @@ async def test_send_message_with_extensions( (HTTP_EXTENSION_HEADER.lower(), 'foo'), (HTTP_EXTENSION_HEADER.lower(), 'bar'), ) - - def side_effect(request, context: ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') - return types.Task( - id='task-1', - context_id='ctx-1', - status=types.TaskStatus( - state=types.TaskState.TASK_STATE_COMPLETED - ), - ) - - mock_request_handler.on_message_send.side_effect = side_effect + mock_request_handler.on_message_send.return_value = types.Task( + id='task-1', + context_id='ctx-1', + status=types.TaskStatus(state=types.TaskState.TASK_STATE_COMPLETED), + ) await grpc_handler.SendMessage( a2a_pb2.SendMessageRequest(), mock_grpc_context @@ -444,15 +436,6 @@ def side_effect(request, context: ServerCallContext): assert isinstance(call_context, ServerCallContext) assert call_context.requested_extensions == {'foo', 'bar'} - mock_grpc_context.set_trailing_metadata.assert_called_once() - called_metadata = ( - mock_grpc_context.set_trailing_metadata.call_args.args[0] - ) - assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'baz'), - } - async def test_send_message_with_comma_separated_extensions( self, grpc_handler: GrpcHandler, @@ -490,8 +473,6 @@ async def test_send_streaming_message_with_extensions( ) async def side_effect(request, context: 
ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') yield types.Task( id='task-1', context_id='ctx-1', @@ -517,15 +498,6 @@ async def side_effect(request, context: ServerCallContext): assert isinstance(call_context, ServerCallContext) assert call_context.requested_extensions == {'foo', 'bar'} - mock_grpc_context.set_trailing_metadata.assert_called_once() - called_metadata = ( - mock_grpc_context.set_trailing_metadata.call_args.args[0] - ) - assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'baz'), - } - @pytest.mark.asyncio class TestTenantExtraction: diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py index 15d3349cd..7ce73eb2e 100644 --- a/tests/server/routes/test_jsonrpc_dispatcher.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -169,31 +169,6 @@ def test_method_added_to_call_context_state(self, client, mock_handler): call_context = mock_handler.on_message_send.call_args[0][1] assert call_context.state['method'] == 'SendMessage' - def test_response_with_activated_extensions(self, client, mock_handler): - def side_effect(request, context: ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') - return Message( - message_id='test', - role=Role.ROLE_AGENT, - parts=[Part(text='response message')], - ) - - mock_handler.on_message_send.side_effect = side_effect - - response = client.post( - '/', - json=_make_send_message_request(), - ) - response.raise_for_status() - - assert response.status_code == 200 - assert HTTP_EXTENSION_HEADER in response.headers - assert set(response.headers[HTTP_EXTENSION_HEADER].split(', ')) == { - 'foo', - 'baz', - } - class TestJsonRpcDispatcherTenant: def test_tenant_extraction_from_params(self, client, mock_handler): From f4a0bcdf68107c95e6c0a5e6696e4a7d6e01a03f Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: 
Fri, 17 Apr 2026 15:38:49 +0200 Subject: [PATCH 37/67] feat!: Raise errors on invalid AgentExecutor behavior. (#979) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #869 🦕 --------- Co-authored-by: Ivan Shymko --- src/a2a/server/agent_execution/active_task.py | 34 +- .../cross_version/client_server/server_0_3.py | 10 +- .../cross_version/client_server/server_1_0.py | 8 +- .../integration/test_copying_observability.py | 9 +- tests/integration/test_scenarios.py | 615 ++++++++++++++---- .../agent_execution/test_active_task.py | 200 +----- .../test_default_request_handler_v2.py | 22 +- 7 files changed, 566 insertions(+), 332 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index db7bb5146..5479a38c1 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -36,6 +36,7 @@ TaskStatusUpdateEvent, ) from a2a.utils.errors import ( + InvalidAgentResponseError, InvalidParamsError, TaskNotFoundError, ) @@ -370,13 +371,12 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 elif isinstance(event, Message): if task_mode is not None: if task_mode: - logger.error( - 'Received Message() object in task mode.' - ) - else: - logger.error( - 'Multiple Message() objects received.' + raise InvalidAgentResponseError( + 'Received Message object in task mode. Use TaskStatusUpdateEvent or TaskArtifactUpdateEvent instead.' ) + raise InvalidAgentResponseError( + 'Multiple Message objects received.' + ) task_mode = False logger.debug( 'Consumer[%s]: Setting result to Message: %s', @@ -385,9 +385,8 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 ) else: if task_mode is False: - logger.error( - 'Received %s in message mode.', - type(event).__name__, + raise InvalidAgentResponseError( + f'Received {type(event).__name__} in message mode. 
Use Task with TaskStatusUpdateEvent and TaskArtifactUpdateEvent instead.' ) if isinstance(event, Task): @@ -408,6 +407,18 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 # Initial task should already contain the message. message_to_save = None else: + if ( + isinstance(event, TaskStatusUpdateEvent) + and not self._task_created.is_set() + ): + task = ( + await self._task_manager.get_task() + ) + if task is None: + raise InvalidAgentResponseError( + f'Agent should enqueue Task before {type(event).__name__} event' + ) + new_task = ( await self._task_manager.ensure_task_id( self._task_id, @@ -434,8 +445,6 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 if not isinstance(event, Task): await self._task_manager.process(event) - self._task_created.set() - # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states new_task = await self._task_manager.get_task() if new_task is None: @@ -496,6 +505,9 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 await self._push_sender.send_notification( self._task_id, event ) + + self._task_created.set() + finally: if new_task is not None: new_task_copy = Task() diff --git a/tests/integration/cross_version/client_server/server_0_3.py b/tests/integration/cross_version/client_server/server_0_3.py index 7bd5f7e75..875cbb1ca 100644 --- a/tests/integration/cross_version/client_server/server_0_3.py +++ b/tests/integration/cross_version/client_server/server_0_3.py @@ -38,7 +38,7 @@ from starlette.requests import Request from starlette.concurrency import iterate_in_threadpool import time - +from a2a.utils.task import new_task from server_common import CustomLoggingMiddleware @@ -48,12 +48,18 @@ def __init__(self): async def execute(self, context: RequestContext, event_queue: EventQueue): print(f'SERVER: execute called for task {context.task_id}') + + task = new_task(context.message) + task.id = context.task_id + task.context_id = context.context_id + task.status.state = TaskState.working 
+ await event_queue.enqueue_event(task) + task_updater = TaskUpdater( event_queue, context.task_id, context.context_id, ) - await task_updater.update_status(TaskState.submitted) await task_updater.update_status(TaskState.working) text = '' diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index e11b1d69d..06f7e5e97 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -28,6 +28,7 @@ from a2a.utils import TransportProtocol from server_common import CustomLoggingMiddleware from google.protobuf.struct_pb2 import Struct, Value +from a2a.helpers.proto_helpers import new_task_from_user_message class MockAgentExecutor(AgentExecutor): @@ -36,12 +37,17 @@ def __init__(self): async def execute(self, context: RequestContext, event_queue: EventQueue): print(f'SERVER: execute called for task {context.task_id}') + task = new_task_from_user_message(context.message) + task.id = context.task_id + task.context_id = context.context_id + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + task_updater = TaskUpdater( event_queue, context.task_id, context.context_id, ) - await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED) await task_updater.update_status(TaskState.TASK_STATE_WORKING) text = '' diff --git a/tests/integration/test_copying_observability.py b/tests/integration/test_copying_observability.py index d5171097a..bc23b4696 100644 --- a/tests/integration/test_copying_observability.py +++ b/tests/integration/test_copying_observability.py @@ -25,6 +25,7 @@ SendMessageRequest, TaskState, ) +from a2a.helpers.proto_helpers import new_task_from_user_message from a2a.utils import TransportProtocol @@ -42,6 +43,12 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): if user_input == 'Init task': # Explicitly save status change to ensure task 
exists with some state + task = new_task_from_user_message(context.message) + task.id = context.task_id + task.context_id = context.context_id + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + await task_updater.update_status( TaskState.TASK_STATE_WORKING, message=task_updater.new_agent_message( @@ -153,6 +160,7 @@ async def test_mutation_observability(agent_card: AgentCard, use_copying: bool): ] event = events[-1] + assert event.HasField('status_update') task_id = event.status_update.task_id # 2. Second message to mutate it @@ -162,7 +170,6 @@ async def test_mutation_observability(agent_card: AgentCard, use_copying: bool): task_id=task_id, parts=[Part(text='Update task without saving it')], ) - _ = [ event async for event in client.send_message( diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index c50622e5c..6070a672f 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -1,5 +1,6 @@ import asyncio import collections +import contextlib import logging from typing import Any @@ -46,11 +47,13 @@ TaskStatus, TaskStatusUpdateEvent, ) +from a2a.helpers.proto_helpers import new_task_from_user_message from a2a.utils import TransportProtocol from a2a.utils.errors import ( InvalidParamsError, TaskNotCancelableError, TaskNotFoundError, + InvalidAgentResponseError, ) @@ -246,13 +249,9 @@ class DummyAgentExecutor(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskStatusUpdateEvent( task_id=context.task_id, @@ -277,7 +276,11 @@ async def cancel( event async 
for event in client.send_message(SendMessageRequest(message=msg)) ] - assert [event.status_update.status.state for event in events] == [ + task, status_update = events + assert task.HasField('task') + assert status_update.HasField('status_update') + + assert [get_state(event) for event in events] == [ TaskState.TASK_STATE_WORKING, TaskState.TASK_STATE_COMPLETED, ] @@ -291,13 +294,9 @@ class DummyAgentExecutor(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskStatusUpdateEvent( task_id=context.task_id, @@ -350,13 +349,9 @@ class DummyAgentExecutor(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) async def cancel( self, context: RequestContext, event_queue: EventQueue @@ -393,11 +388,9 @@ async def cancel( (event,) = [event async for event in it] if streaming: - assert event.HasField('status_update') - task_id = event.status_update.task_id - assert ( - event.status_update.status.state == TaskState.TASK_STATE_COMPLETED - ) + assert event.HasField('task') + task_id = event.task.id + validate_state(event, TaskState.TASK_STATE_COMPLETED) else: assert event.HasField('task') task_id = event.task.id @@ -485,13 +478,9 @@ class ErrorAfterAgent(AgentExecutor): async def execute( self, context: 
RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await continue_event.wait() raise ValueError('TEST_ERROR_IN_EXECUTE') @@ -515,7 +504,7 @@ async def cancel( if streaming: res = await it.__anext__() - assert res.status_update.status.state == TaskState.TASK_STATE_WORKING + validate_state(res, TaskState.TASK_STATE_WORKING) continue_event.set() else: @@ -554,13 +543,9 @@ class ErrorCancelAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await hang_event.wait() @@ -614,13 +599,9 @@ class ErrorAfterAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await continue_event.wait() raise ValueError('TEST_ERROR_IN_EXECUTE') @@ -744,13 +725,9 @@ class DummyCancelAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - 
context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await hang_event.wait() @@ -812,13 +789,9 @@ class ComplexAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await working_event.wait() @@ -931,13 +904,9 @@ async def execute( ) return - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await continue_event.wait() await event_queue.enqueue_event( @@ -1059,13 +1028,9 @@ class ImmediateAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskStatusUpdateEvent( task_id=context.task_id, @@ -1120,27 +1085,17 @@ async def execute( ): message = context.message if message and message.parts and message.parts[0].text == 'start': - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - 
task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus( - state=TaskState.TASK_STATE_INPUT_REQUIRED - ), - ) - ) + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED + await event_queue.enqueue_event(task) elif ( message and message.parts and message.parts[0].text == 'here is input' ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) else: raise ValueError('Unexpected message') @@ -1209,13 +1164,9 @@ class AuthAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskStatusUpdateEvent( task_id=context.task_id, @@ -1295,15 +1246,9 @@ async def execute( ): message = context.message if message and message.parts and message.parts[0].text == 'start': - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus( - state=TaskState.TASK_STATE_AUTH_REQUIRED - ), - ) - ) + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_AUTH_REQUIRED + await event_queue.enqueue_event(task) elif ( message and message.parts @@ -1316,6 +1261,7 @@ async def execute( status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), ) ) + else: raise ValueError(f'Unexpected message {message}') @@ -1380,13 +1326,9 @@ class EmitAgent(AgentExecutor): 
async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) phases = [ ('trigger_phase_1', 'artifact_1'), @@ -1602,6 +1544,9 @@ class ArtifactAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskArtifactUpdateEvent( task_id=context.task_id, @@ -1724,7 +1669,7 @@ async def cancel( configuration=SendMessageConfiguration(return_immediately=False), ) ) - events = [event async for event in it] + _ = [event async for event in it] (final_task,) = (await client.list_tasks(ListTasksRequest())).tasks @@ -1744,4 +1689,440 @@ async def cancel( if record.levelname == 'ERROR' and 'Ignoring task replacement' in record.message ] + assert len(error_logs) == 1 + + +# Scenario: Task restoration - terminal state +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize( + 'subscribe_first', + [False, True], + ids=['no_subscribe_first', 'subscribe_first'], +) +async def test_restore_task_terminal_state( + use_legacy, streaming, subscribe_first +): + class TerminalAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) + + async def cancel( + self, 
context: RequestContext, event_queue: EventQueue + ): + pass + + task_store = InMemoryTaskStore() + handler1 = create_handler( + TerminalAgent(), use_legacy, task_store=task_store + ) + client1 = await create_client( + handler1, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg-1', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + it1 = client1.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events1 = [event async for event in it1] + task_id = get_task_id(events1[-1]) + + await wait_for_state( + client1, task_id, expected_states={TaskState.TASK_STATE_COMPLETED} + ) + + # Restore task in a new handler (simulating server restart) + handler2 = create_handler( + TerminalAgent(), use_legacy, task_store=task_store + ) + client2 = await create_client( + handler2, agent_card=agent_card(), streaming=streaming + ) + + restored_task = await client2.get_task(GetTaskRequest(id=task_id)) + assert restored_task.status.state == TaskState.TASK_STATE_COMPLETED + + if subscribe_first and streaming: + with pytest.raises( + Exception, + match=r'terminal state', + ): + async for _ in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + pass + + msg2 = Message( + task_id=task_id, + message_id='test-msg-2', + role=Role.ROLE_USER, + parts=[Part(text='message to completed task')], + ) + + with pytest.raises(Exception, match=r'terminal state'): + async for _ in client2.send_message(SendMessageRequest(message=msg2)): + pass + + if streaming: + with pytest.raises( + Exception, + match=r'terminal state', + ): + async for _ in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + pass + + +# Scenario: Task restoration - user input required state +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 
'streaming'] +) +@pytest.mark.parametrize( + 'subscribe_mode', + ['none', 'drop', 'listen'], + ids=['no_sub', 'sub_drop', 'sub_listen'], +) +async def test_restore_task_input_required_state( + use_legacy, streaming, subscribe_mode +): + class InputAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + message = context.message + if message and message.parts and message.parts[0].text == 'start': + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED + await event_queue.enqueue_event(task) + elif message and message.parts and message.parts[0].text == 'input': + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + task_store = InMemoryTaskStore() + handler1 = create_handler(InputAgent(), use_legacy, task_store=task_store) + client1 = await create_client( + handler1, agent_card=agent_card(), streaming=streaming + ) + + msg1 = Message( + message_id='test-msg-1', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + it1 = client1.send_message( + SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events1 = [event async for event in it1] + + task_id = get_task_id(events1[-1]) + context_id = get_task_context_id(events1[-1]) + + await wait_for_state( + client1, task_id, expected_states={TaskState.TASK_STATE_INPUT_REQUIRED} + ) + + # Restore task in a new handler (simulating server restart) + handler2 = create_handler(InputAgent(), use_legacy, task_store=task_store) + client2 = await create_client( + handler2, agent_card=agent_card(), streaming=streaming + ) + + restored_task = await client2.get_task(GetTaskRequest(id=task_id)) + assert restored_task.status.state == 
TaskState.TASK_STATE_INPUT_REQUIRED + + # Subscription logic based on mode + listen_task = None + if streaming: + if subscribe_mode == 'drop': + # Subscribing and dropping immediately (cancelling the generator) + async for _ in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + break + elif subscribe_mode == 'listen': + sub_started_event = asyncio.Event() + + async def listen_to_end(): + res = [] + async for ev in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + res.append(ev) + sub_started_event.set() + return res + + listen_task = asyncio.create_task(listen_to_end()) + # Wait for subscription to establish and yield the initial task event + await asyncio.wait_for(sub_started_event.wait(), timeout=1.0) + + msg2 = Message( + task_id=task_id, + context_id=context_id, + message_id='test-msg-2', + role=Role.ROLE_USER, + parts=[Part(text='input')], + ) + + it2 = client2.send_message( + SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events2 = [event async for event in it2] + + if streaming: + assert ( + events2[-1].status_update.status.state + == TaskState.TASK_STATE_COMPLETED + ) + else: + assert events2[-1].task.status.state == TaskState.TASK_STATE_COMPLETED + + if listen_task: + if use_legacy and streaming: + # Error: Legacy handler does not properly manage subscriptions for restored tasks + with pytest.raises(TaskNotFoundError): + await listen_task + else: + listen_events = await listen_task + # The first event is the initial task state (INPUT_REQUIRED), the last should be COMPLETED + assert ( + get_state(listen_events[-1]) == TaskState.TASK_STATE_COMPLETED + ) + + final_task = await client2.get_task(GetTaskRequest(id=task_id)) + assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + + +# Scenario 20: Create initial task with new_task +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) 
+@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize('initial_task_type', ['new_task', 'status_update']) +async def test_scenario_initial_task_types( + use_legacy, streaming, initial_task_type +): + started_event = asyncio.Event() + continue_event = asyncio.Event() + + class InitialTaskAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + if initial_task_type == 'new_task': + # Create with new_task + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + else: + # Create with status update (illegal in v2) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + + started_event.set() + await continue_event.wait() + + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact( + artifact_id='art-1', parts=[Part(text='artifact data')] + ), + ) + ) + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(InitialTaskAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration( + return_immediately=streaming + ), + ) + ) + + if streaming: + if initial_task_type == 'status_update' and not use_legacy: + with pytest.raises( + InvalidAgentResponseError, + match='Agent 
should enqueue Task before TaskStatusUpdateEvent event', + ): + await it.__anext__() + + # End of the test. + return + + res = await it.__anext__() + if initial_task_type == 'status_update' and use_legacy: + # First message has to be a Task. + assert res.HasField('status_update') + + # End of the test. + return + + assert res.HasField('task') + task_id = get_task_id(res) + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + + # Start subscription + sub = client.subscribe(SubscribeToTaskRequest(id=task_id)) + + # first subscriber receives current task state (WORKING) + first_event = await sub.__anext__() + assert first_event.HasField('task') + + continue_event.set() + + events = [first_event] + [event async for event in sub] + else: + # blocking + async def release_agent(): + await started_event.wait() + continue_event.set() + + release_task = asyncio.create_task(release_agent()) + if initial_task_type == 'status_update' and not use_legacy: + with pytest.raises( + InvalidAgentResponseError, + match='Agent should enqueue Task before TaskStatusUpdateEvent event', + ): + events = [event async for event in it] + # End of the test + return + else: + events = [event async for event in it] + await release_task + + if streaming: + task, artifact_update, status_update = events + assert task.HasField('task') + validate_state(task, TaskState.TASK_STATE_WORKING) + assert artifact_update.artifact_update.artifact.artifact_id == 'art-1' + assert status_update.HasField('status_update') + validate_state(status_update, TaskState.TASK_STATE_COMPLETED) + else: + (task,) = events + assert task.HasField('task') + validate_state(task, TaskState.TASK_STATE_COMPLETED) + (artifact,) = task.task.artifacts + assert artifact.artifact_id == 'art-1' + task_id = task.task.id + + (final_task_from_list,) = ( + await client.list_tasks(ListTasksRequest(include_artifacts=True)) + ).tasks + assert len(final_task_from_list.artifacts) > 0 + assert final_task_from_list.artifacts[0].artifact_id == 
'art-1' + + final_task = await client.get_task(GetTaskRequest(id=task_id)) + assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + assert len(final_task.artifacts) > 0 + assert final_task.artifacts[0].artifact_id == 'art-1' + + +# Scenario 23: Invalid Agent Response - Task followed by Message +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_23_invalid_response_task_message(use_legacy, streaming): + class TaskMessageAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + new_task_from_user_message(context.message) + ) + await event_queue.enqueue_event( + Message(message_id='m1', parts=[Part(text='m1')]) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(TaskMessageAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message(SendMessageRequest(message=msg)) + + if use_legacy: + # Legacy: no error. 
+ async for _ in it: + pass + else: + with pytest.raises( + InvalidAgentResponseError, + match='Received Message object in task mode', + ): + async for _ in it: + pass diff --git a/tests/server/agent_execution/test_active_task.py b/tests/server/agent_execution/test_active_task.py index 3a4a24ff6..6e477186b 100644 --- a/tests/server/agent_execution/test_active_task.py +++ b/tests/server/agent_execution/test_active_task.py @@ -19,6 +19,8 @@ TaskState, TaskStatus, TaskStatusUpdateEvent, + Role, + Part, ) from a2a.utils.errors import InvalidParamsError @@ -71,51 +73,6 @@ async def active_task( push_sender=push_sender, ) - @pytest.mark.asyncio - async def test_active_task_lifecycle( - self, - active_task: ActiveTask, - agent_executor: Mock, - request_context: Mock, - task_manager: Mock, - ) -> None: - """Test the basic lifecycle of an ActiveTask.""" - - async def execute_mock(req, q): - await q.enqueue_event(Message(message_id='m1')) - await q.enqueue_event( - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) - - agent_executor.execute = AsyncMock(side_effect=execute_mock) - task_manager.get_task.side_effect = [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ] + [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ] * 10 - - await active_task.enqueue_request(request_context) - await active_task.start( - call_context=ServerCallContext(), create_task_if_missing=True - ) - - # Wait for the task to finish - events = [e async for e in active_task.subscribe()] - result = next(e for e in events if isinstance(e, Message)) - - assert isinstance(result, Message) - assert result.message_id == 'm1' - assert active_task.task_id == 'test-task-id' - @pytest.mark.asyncio async def test_active_task_already_started( self, active_task: ActiveTask, request_context: Mock @@ -132,36 +89,6 @@ async def test_active_task_already_started( ) assert 
active_task._producer_task is not None - @pytest.mark.asyncio - async def test_active_task_subscribe( - self, - active_task: ActiveTask, - agent_executor: Mock, - request_context: Mock, - ) -> None: - """Test subscribing to events from an ActiveTask.""" - - async def execute_mock(req, q): - await q.enqueue_event(Message(message_id='m1')) - await q.enqueue_event(Message(message_id='m2')) - - agent_executor.execute = AsyncMock(side_effect=execute_mock) - - await active_task.enqueue_request(request_context) - await active_task.start( - call_context=ServerCallContext(), create_task_if_missing=True - ) - - events = [] - async for event in active_task.subscribe(): - events.append(event) - if len(events) == 2: - break - - assert len(events) == 2 - assert events[0].message_id == 'm1' - assert events[1].message_id == 'm2' - @pytest.mark.asyncio async def test_active_task_cancel( self, @@ -355,59 +282,6 @@ async def execute_mock(req, q): push_sender.send_notification.assert_called() - @pytest.mark.asyncio - async def test_active_task_cleanup( - self, - agent_executor: Mock, - task_manager: Mock, - request_context: Mock, - ) -> None: - """Test that the cleanup callback is called.""" - on_cleanup = Mock() - active_task = ActiveTask( - agent_executor=agent_executor, - task_id='test-task-id', - task_manager=task_manager, - on_cleanup=on_cleanup, - ) - - async def execute_mock(req, q): - await q.enqueue_event(Message(message_id='m1')) - await q.enqueue_event( - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) - - agent_executor.execute = AsyncMock(side_effect=execute_mock) - task_manager.get_task.side_effect = [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ] + [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ] * 10 - - await active_task.start( - call_context=ServerCallContext(), create_task_if_missing=True - ) - - async for _ in 
active_task.subscribe(request=request_context): - pass - - # Wait for consumer thread to finish and call cleanup - for _ in range(20): - if on_cleanup.called: - break - await asyncio.sleep(0.05) - - on_cleanup.assert_called_once_with(active_task) - @pytest.mark.asyncio async def test_active_task_consumer_failure( self, @@ -894,76 +768,6 @@ async def test_active_task_maybe_cleanup_not_finished( await active_task._maybe_cleanup() on_cleanup.assert_not_called() - @pytest.mark.asyncio - async def test_active_task_maybe_cleanup_with_subscribers( - self, - agent_executor: Mock, - task_manager: Mock, - push_sender: Mock, - request_context: Mock, - ) -> None: - """Test that cleanup is not called if there are subscribers.""" - on_cleanup = Mock() - active_task = ActiveTask( - agent_executor=agent_executor, - task_id='test-task-id', - task_manager=task_manager, - push_sender=push_sender, - on_cleanup=on_cleanup, - ) - - # Mock execute to finish immediately - async def execute_mock(req, q): - await q.enqueue_event(Message(message_id='m1')) - await q.enqueue_event( - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) - - agent_executor.execute = AsyncMock(side_effect=execute_mock) - task_manager.get_task.side_effect = [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ] + [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ] * 10 - - # 1. Start a subscriber before task finishes - gen = active_task.subscribe() - # Start the generator to increment reference count - msg_task = asyncio.create_task(gen.__anext__()) - - # 2. 
Start the task and wait for it to finish - await active_task.start( - call_context=ServerCallContext(), create_task_if_missing=True - ) - - async for _ in active_task.subscribe(request=request_context): - pass - - # Give the consumer loop a moment to set _is_finished - await asyncio.sleep(0.1) - - # Ensure we got the message - assert (await msg_task).message_id == 'm1' - - # At this point, task is finished, but we still have a subscriber (gen). - # _maybe_cleanup was called by consumer loop, but should have done nothing. - on_cleanup.assert_not_called() - - # 3. Close the subscriber - await gen.aclose() - - # Now cleanup should be triggered - on_cleanup.assert_called_once_with(active_task) - @pytest.mark.asyncio async def test_active_task_subscribe_exception_already_set( self, active_task: ActiveTask diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index 3e1568b2e..fda1ab960 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -53,6 +53,7 @@ TaskPushNotificationConfig, TaskState, TaskStatus, + TaskStatusUpdateEvent, ) from a2a.helpers.proto_helpers import ( new_text_message, @@ -71,11 +72,17 @@ def create_default_agent_card(): class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): + if context.message: + await event_queue.enqueue_event( + new_task_from_user_message(context.message) + ) + task_updater = TaskUpdater( event_queue, str(context.task_id or ''), str(context.context_id or ''), ) + async for i in self._run(): parts = [Part(text=f'Event {i}')] try: @@ -572,8 +579,15 @@ async def consume_stream(): elapsed = time.perf_counter() - start assert len(events) == 3 assert elapsed < 0.5 - texts = [p.text for e in events for p in e.status.message.parts] - assert texts == ['Event 0', 'Event 1', 'Event 2'] + task, 
event0, event1 = events + assert isinstance(task, Task) + assert task.history[0].parts[0].text == 'How are you?' + + assert isinstance(event0, TaskStatusUpdateEvent) + assert event0.status.message.parts[0].text == 'Event 0' + + assert isinstance(event1, TaskStatusUpdateEvent) + assert event1.status.message.parts[0].text == 'Event 1' @pytest.mark.asyncio @@ -954,6 +968,10 @@ class HelloWorldAgentExecutor(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ) -> None: + if context.message: + await event_queue.enqueue_event( + new_task_from_user_message(context.message) + ) updater = TaskUpdater( event_queue, task_id=context.task_id or str(uuid.uuid4()), From c87e87c76c004c73c9d6b9bd8cacfd4e590598e6 Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Fri, 17 Apr 2026 15:42:24 +0200 Subject: [PATCH 38/67] refactor!: clean up of folder structure (#983) # Description Refactors internal helpers modules so that the helpers name is used exclusively for the a2a.helpers package (customer-facing convenience functions). 
- Move a2a.client.helpers into a2a.client.card_resolver -- parse_agent_card and its backward-compat shims are implementation details of card resolution - Rename a2a.utils.helpers to a2a.utils.version_validator to reflect its actual content --- scripts/test_minimal_install.py | 3 +- src/a2a/client/card_resolver.py | 108 ++- src/a2a/client/helpers.py | 112 --- src/a2a/compat/v0_3/jsonrpc_adapter.py | 2 +- src/a2a/compat/v0_3/rest_handler.py | 4 +- .../default_request_handler.py | 7 +- .../default_request_handler_v2.py | 7 +- src/a2a/server/routes/agent_card_routes.py | 6 +- src/a2a/server/routes/jsonrpc_dispatcher.py | 2 +- src/a2a/server/routes/rest_dispatcher.py | 2 +- .../{helpers.py => version_validator.py} | 10 +- tests/client/test_card_resolver.py | 701 +++++++++++++++++- tests/client/test_client_helpers.py | 696 ----------------- .../test_cross_version_card_validation.py | 2 +- .../test_client_server_integration.py | 11 +- tests/server/tasks/test_task_manager.py | 97 +++ tests/server/test_integration.py | 4 +- tests/utils/test_helpers.py | 312 -------- tests/utils/test_signing.py | 108 +++ ...lidation.py => test_version_validation.py} | 2 +- 20 files changed, 1038 insertions(+), 1158 deletions(-) delete mode 100644 src/a2a/client/helpers.py rename src/a2a/utils/{helpers.py => version_validator.py} (94%) delete mode 100644 tests/client/test_client_helpers.py delete mode 100644 tests/utils/test_helpers.py rename tests/utils/{test_helpers_validation.py => test_version_validation.py} (98%) diff --git a/scripts/test_minimal_install.py b/scripts/test_minimal_install.py index 0b29a48b6..84e3ee3fc 100755 --- a/scripts/test_minimal_install.py +++ b/scripts/test_minimal_install.py @@ -38,7 +38,6 @@ 'a2a.client.client', 'a2a.client.client_factory', 'a2a.client.errors', - 'a2a.client.helpers', 'a2a.client.interceptors', 'a2a.client.optionals', 'a2a.client.transports', @@ -52,7 +51,7 @@ 'a2a.utils', 'a2a.utils.constants', 'a2a.utils.error_handlers', - 'a2a.utils.helpers', 
+ 'a2a.utils.version_validator', 'a2a.utils.proto_utils', 'a2a.utils.task', 'a2a.helpers.agent_card', diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index 6d98a5361..815916014 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -6,10 +6,9 @@ import httpx -from google.protobuf.json_format import ParseError +from google.protobuf.json_format import ParseDict, ParseError from a2a.client.errors import AgentCardResolutionError -from a2a.client.helpers import parse_agent_card from a2a.types.a2a_pb2 import ( AgentCard, ) @@ -19,6 +18,111 @@ logger = logging.getLogger(__name__) +def parse_agent_card(agent_card_data: dict[str, Any]) -> AgentCard: + """Parse AgentCard JSON dictionary and handle backward compatibility.""" + _handle_extended_card_compatibility(agent_card_data) + _handle_connection_fields_compatibility(agent_card_data) + _handle_security_compatibility(agent_card_data) + + return ParseDict(agent_card_data, AgentCard(), ignore_unknown_fields=True) + + +def _handle_extended_card_compatibility( + agent_card_data: dict[str, Any], +) -> None: + """Map legacy supportsAuthenticatedExtendedCard to capabilities.""" + if agent_card_data.pop('supportsAuthenticatedExtendedCard', None): + capabilities = agent_card_data.setdefault('capabilities', {}) + if 'extendedAgentCard' not in capabilities: + capabilities['extendedAgentCard'] = True + + +def _handle_connection_fields_compatibility( + agent_card_data: dict[str, Any], +) -> None: + """Map legacy connection and transport fields to supportedInterfaces.""" + main_url = agent_card_data.pop('url', None) + main_transport = agent_card_data.pop('preferredTransport', 'JSONRPC') + version = agent_card_data.pop('protocolVersion', '0.3.0') + additional_interfaces = ( + agent_card_data.pop('additionalInterfaces', None) or [] + ) + + if 'supportedInterfaces' not in agent_card_data and main_url: + supported_interfaces = [] + supported_interfaces.append( + { + 'url': 
main_url, + 'protocolBinding': main_transport, + 'protocolVersion': version, + } + ) + supported_interfaces.extend( + { + 'url': iface.get('url'), + 'protocolBinding': iface.get('transport'), + 'protocolVersion': version, + } + for iface in additional_interfaces + ) + agent_card_data['supportedInterfaces'] = supported_interfaces + + +def _map_legacy_security( + sec_list: list[dict[str, list[str]]], +) -> list[dict[str, Any]]: + """Convert a legacy security requirement list into the 1.0.0 Protobuf format.""" + return [ + { + 'schemes': { + scheme_name: {'list': scopes} + for scheme_name, scopes in sec_dict.items() + } + } + for sec_dict in sec_list + ] + + +def _handle_security_compatibility(agent_card_data: dict[str, Any]) -> None: + """Map legacy security requirements and schemas to their 1.0.0 Protobuf equivalents.""" + legacy_security = agent_card_data.pop('security', None) + if ( + 'securityRequirements' not in agent_card_data + and legacy_security is not None + ): + agent_card_data['securityRequirements'] = _map_legacy_security( + legacy_security + ) + + for skill in agent_card_data.get('skills', []): + legacy_skill_sec = skill.pop('security', None) + if 'securityRequirements' not in skill and legacy_skill_sec is not None: + skill['securityRequirements'] = _map_legacy_security( + legacy_skill_sec + ) + + security_schemes = agent_card_data.get('securitySchemes', {}) + if security_schemes: + type_mapping = { + 'apiKey': 'apiKeySecurityScheme', + 'http': 'httpAuthSecurityScheme', + 'oauth2': 'oauth2SecurityScheme', + 'openIdConnect': 'openIdConnectSecurityScheme', + 'mutualTLS': 'mtlsSecurityScheme', + } + for scheme in security_schemes.values(): + scheme_type = scheme.pop('type', None) + if scheme_type in type_mapping: + # Map legacy 'in' to modern 'location' + if scheme_type == 'apiKey' and 'in' in scheme: + scheme['location'] = scheme.pop('in') + + mapped_name = type_mapping[scheme_type] + new_scheme_wrapper = {mapped_name: scheme.copy()} + scheme.clear() + 
scheme.update(new_scheme_wrapper) + + class A2ACardResolver: """Agent Card resolver.""" diff --git a/src/a2a/client/helpers.py b/src/a2a/client/helpers.py deleted file mode 100644 index f8207f03b..000000000 --- a/src/a2a/client/helpers.py +++ /dev/null @@ -1,112 +0,0 @@ -"""Helper functions for the A2A client.""" - -from typing import Any - -from google.protobuf.json_format import ParseDict - -from a2a.types.a2a_pb2 import AgentCard - - -def parse_agent_card(agent_card_data: dict[str, Any]) -> AgentCard: - """Parse AgentCard JSON dictionary and handle backward compatibility.""" - _handle_extended_card_compatibility(agent_card_data) - _handle_connection_fields_compatibility(agent_card_data) - _handle_security_compatibility(agent_card_data) - - return ParseDict(agent_card_data, AgentCard(), ignore_unknown_fields=True) - - -def _handle_extended_card_compatibility( - agent_card_data: dict[str, Any], -) -> None: - """Map legacy supportsAuthenticatedExtendedCard to capabilities.""" - if agent_card_data.pop('supportsAuthenticatedExtendedCard', None): - capabilities = agent_card_data.setdefault('capabilities', {}) - if 'extendedAgentCard' not in capabilities: - capabilities['extendedAgentCard'] = True - - -def _handle_connection_fields_compatibility( - agent_card_data: dict[str, Any], -) -> None: - """Map legacy connection and transport fields to supportedInterfaces.""" - main_url = agent_card_data.pop('url', None) - main_transport = agent_card_data.pop('preferredTransport', 'JSONRPC') - version = agent_card_data.pop('protocolVersion', '0.3.0') - additional_interfaces = ( - agent_card_data.pop('additionalInterfaces', None) or [] - ) - - if 'supportedInterfaces' not in agent_card_data and main_url: - supported_interfaces = [] - supported_interfaces.append( - { - 'url': main_url, - 'protocolBinding': main_transport, - 'protocolVersion': version, - } - ) - supported_interfaces.extend( - { - 'url': iface.get('url'), - 'protocolBinding': iface.get('transport'), - 
'protocolVersion': version, - } - for iface in additional_interfaces - ) - agent_card_data['supportedInterfaces'] = supported_interfaces - - -def _map_legacy_security( - sec_list: list[dict[str, list[str]]], -) -> list[dict[str, Any]]: - """Convert a legacy security requirement list into the 1.0.0 Protobuf format.""" - return [ - { - 'schemes': { - scheme_name: {'list': scopes} - for scheme_name, scopes in sec_dict.items() - } - } - for sec_dict in sec_list - ] - - -def _handle_security_compatibility(agent_card_data: dict[str, Any]) -> None: - """Map legacy security requirements and schemas to their 1.0.0 Protobuf equivalents.""" - legacy_security = agent_card_data.pop('security', None) - if ( - 'securityRequirements' not in agent_card_data - and legacy_security is not None - ): - agent_card_data['securityRequirements'] = _map_legacy_security( - legacy_security - ) - - for skill in agent_card_data.get('skills', []): - legacy_skill_sec = skill.pop('security', None) - if 'securityRequirements' not in skill and legacy_skill_sec is not None: - skill['securityRequirements'] = _map_legacy_security( - legacy_skill_sec - ) - - security_schemes = agent_card_data.get('securitySchemes', {}) - if security_schemes: - type_mapping = { - 'apiKey': 'apiKeySecurityScheme', - 'http': 'httpAuthSecurityScheme', - 'oauth2': 'oauth2SecurityScheme', - 'openIdConnect': 'openIdConnectSecurityScheme', - 'mutualTLS': 'mtlsSecurityScheme', - } - for scheme in security_schemes.values(): - scheme_type = scheme.pop('type', None) - if scheme_type in type_mapping: - # Map legacy 'in' to modern 'location' - if scheme_type == 'apiKey' and 'in' in scheme: - scheme['location'] = scheme.pop('in') - - mapped_name = type_mapping[scheme_type] - new_scheme_wrapper = {mapped_name: scheme.copy()} - scheme.clear() - scheme.update(new_scheme_wrapper) diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index 8b4b26a79..580034e9b 100644 --- 
a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -41,7 +41,7 @@ ServerCallContextBuilder, ) from a2a.utils import constants -from a2a.utils.helpers import validate_version +from a2a.utils.version_validator import validate_version logger = logging.getLogger(__name__) diff --git a/src/a2a/compat/v0_3/rest_handler.py b/src/a2a/compat/v0_3/rest_handler.py index 0c64506cb..bd5fcd2e6 100644 --- a/src/a2a/compat/v0_3/rest_handler.py +++ b/src/a2a/compat/v0_3/rest_handler.py @@ -28,10 +28,8 @@ from a2a.compat.v0_3.request_handler import RequestHandler03 from a2a.server.context import ServerCallContext from a2a.utils import constants -from a2a.utils.helpers import ( - validate_version, -) from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils.version_validator import validate_version logger = logging.getLogger(__name__) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index fea5184d6..e803b567f 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -58,7 +58,6 @@ TaskNotFoundError, UnsupportedOperationError, ) -from a2a.utils.helpers import maybe_await from a2a.utils.task import ( apply_history_length, validate_history_length, @@ -100,7 +99,7 @@ def __init__( # noqa: PLR0913 request_context_builder: RequestContextBuilder | None = None, extended_agent_card: AgentCard | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] ] | None = None, ) -> None: @@ -695,8 +694,8 @@ async def on_get_extended_agent_card( raise ExtendedAgentCardNotConfiguredError if self.extended_card_modifier: - return await maybe_await( - self.extended_card_modifier(extended_card, context) + extended_card = await self.extended_card_modifier( + extended_card, context ) 
return extended_card diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index 1a8464687..c0c6b5445 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -47,7 +47,6 @@ TaskNotCancelableError, TaskNotFoundError, ) -from a2a.utils.helpers import maybe_await from a2a.utils.task import ( apply_history_length, validate_history_length, @@ -93,7 +92,7 @@ def __init__( # noqa: PLR0913 request_context_builder: RequestContextBuilder | None = None, extended_agent_card: AgentCard | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] ] | None = None, ) -> None: @@ -467,8 +466,8 @@ async def on_get_extended_agent_card( raise ExtendedAgentCardNotConfiguredError if self.extended_card_modifier: - return await maybe_await( - self.extended_card_modifier(extended_card, context) + extended_card = await self.extended_card_modifier( + extended_card, context ) return extended_card diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py index 9b850ff4f..924a3d9dc 100644 --- a/src/a2a/server/routes/agent_card_routes.py +++ b/src/a2a/server/routes/agent_card_routes.py @@ -26,13 +26,11 @@ from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.types.a2a_pb2 import AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH -from a2a.utils.helpers import maybe_await def create_agent_card_routes( agent_card: AgentCard, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard]] | None = None, card_url: str = AGENT_CARD_WELL_KNOWN_PATH, ) -> list['Route']: """Creates the Starlette Route for the A2A protocol agent card endpoint.""" @@ 
-45,7 +43,7 @@ def create_agent_card_routes( async def _get_agent_card(request: Request) -> Response: card_to_serve = agent_card if card_modifier: - card_to_serve = await maybe_await(card_modifier(card_to_serve)) + card_to_serve = await card_modifier(card_to_serve) return JSONResponse(agent_card_to_dict(card_to_serve)) return [ diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 3dc94488a..cb4e93bf1 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -49,8 +49,8 @@ TaskNotFoundError, UnsupportedOperationError, ) -from a2a.utils.helpers import validate_version from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils.version_validator import validate_version INTERNAL_ERROR_CODE = -32603 diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py index 8af384893..adbdba96e 100644 --- a/src/a2a/server/routes/rest_dispatcher.py +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -28,8 +28,8 @@ InvalidRequestError, TaskNotFoundError, ) -from a2a.utils.helpers import validate_version from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils.version_validator import validate_version if TYPE_CHECKING: diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/version_validator.py similarity index 94% rename from src/a2a/utils/helpers.py rename to src/a2a/utils/version_validator.py index 9a974a4c2..4a776c27e 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/version_validator.py @@ -4,7 +4,7 @@ import inspect import logging -from collections.abc import AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterator, Callable from typing import Any, TypeVar, cast from packaging.version import InvalidVersion, Version @@ -14,20 +14,12 @@ from a2a.utils.errors import VersionNotSupportedError -T = TypeVar('T') F = TypeVar('F', bound=Callable[..., Any]) logger = logging.getLogger(__name__) 
-async def maybe_await(value: T | Awaitable[T]) -> T: - """Awaits a value if it's awaitable, otherwise simply provides it back.""" - if inspect.isawaitable(value): - return await value - return value - - def validate_version(expected_version: str) -> Callable[[F], F]: """Decorator that validates the A2A-Version header in the request context. diff --git a/tests/client/test_card_resolver.py b/tests/client/test_card_resolver.py index 9a684a4ac..ff60632ad 100644 --- a/tests/client/test_card_resolver.py +++ b/tests/client/test_card_resolver.py @@ -1,13 +1,35 @@ +import copy +import difflib import json import logging - from unittest.mock import AsyncMock, MagicMock, Mock +from google.protobuf.json_format import MessageToDict import httpx import pytest from a2a.client import A2ACardResolver, AgentCardResolutionError +from a2a.client.card_resolver import parse_agent_card +from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.types import AgentCard +from a2a.types.a2a_pb2 import ( + APIKeySecurityScheme, + AgentCapabilities, + AgentCardSignature, + AgentInterface, + AgentProvider, + AgentSkill, + AuthorizationCodeOAuthFlow, + HTTPAuthSecurityScheme, + MutualTlsSecurityScheme, + OAuth2SecurityScheme, + OAuthFlows, + OpenIdConnectSecurityScheme, + Role, + SecurityRequirement, + SecurityScheme, + StringList, +) from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH @@ -388,3 +410,680 @@ async def test_get_agent_card_with_signature_verifier( ) mock_verifier.assert_called_once_with(agent_card) + + +class TestParseAgentCard: + """Tests for parse_agent_card function.""" + + @staticmethod + def _assert_agent_card_diff( + original_data: dict, serialized_data: dict + ) -> None: + """Helper to assert that the re-serialized 1.0.0 JSON payload contains all original 0.3.0 data (no dropped fields).""" + original_json_str = json.dumps(original_data, indent=2, sort_keys=True) + serialized_json_str = json.dumps( + serialized_data, indent=2, sort_keys=True + ) + + 
diff_lines = list( + difflib.unified_diff( + original_json_str.splitlines(), + serialized_json_str.splitlines(), + lineterm='', + ) + ) + + removed_lines = [] + for line in diff_lines: + if line.startswith('-') and not line.startswith('---'): + removed_lines.append(line) + + if removed_lines: + error_msg = ( + 'Re-serialization dropped fields from the original payload:\n' + + '\n'.join(removed_lines) + ) + raise AssertionError(error_msg) + + def test_parse_agent_card_legacy_support(self) -> None: + data = { + 'name': 'Legacy Agent', + 'description': 'Legacy Description', + 'version': '1.0', + 'supportsAuthenticatedExtendedCard': True, + } + card = parse_agent_card(data) + assert card.name == 'Legacy Agent' + assert card.capabilities.extended_agent_card is True + # Ensure it's popped from the dict + assert 'supportsAuthenticatedExtendedCard' not in data + + def test_parse_agent_card_new_support(self) -> None: + data = { + 'name': 'New Agent', + 'description': 'New Description', + 'version': '1.0', + 'capabilities': {'extendedAgentCard': True}, + } + card = parse_agent_card(data) + assert card.name == 'New Agent' + assert card.capabilities.extended_agent_card is True + + def test_parse_agent_card_no_support(self) -> None: + data = { + 'name': 'No Support Agent', + 'description': 'No Support Description', + 'version': '1.0', + 'capabilities': {'extendedAgentCard': False}, + } + card = parse_agent_card(data) + assert card.name == 'No Support Agent' + assert card.capabilities.extended_agent_card is False + + def test_parse_agent_card_both_legacy_and_new(self) -> None: + data = { + 'name': 'Mixed Agent', + 'description': 'Mixed Description', + 'version': '1.0', + 'supportsAuthenticatedExtendedCard': True, + 'capabilities': {'streaming': True}, + } + card = parse_agent_card(data) + assert card.name == 'Mixed Agent' + assert card.capabilities.streaming is True + assert card.capabilities.extended_agent_card is True + + def test_parse_typical_030_agent_card(self) -> None: + 
data = { + 'additionalInterfaces': [ + { + 'transport': 'GRPC', + 'url': 'http://agent.example.com/api/grpc', + } + ], + 'capabilities': {'streaming': True}, + 'defaultInputModes': ['text/plain'], + 'defaultOutputModes': ['application/json'], + 'description': 'A typical agent from 0.3.0', + 'name': 'Typical Agent 0.3', + 'preferredTransport': 'JSONRPC', + 'protocolVersion': '0.3.0', + 'security': [{'test_oauth': ['read', 'write']}], + 'securitySchemes': { + 'test_oauth': { + 'description': 'OAuth2 authentication', + 'flows': { + 'authorizationCode': { + 'authorizationUrl': 'http://auth.example.com', + 'scopes': { + 'read': 'Read access', + 'write': 'Write access', + }, + 'tokenUrl': 'http://token.example.com', + } + }, + 'type': 'oauth2', + } + }, + 'skills': [ + { + 'description': 'The first skill', + 'id': 'skill-1', + 'name': 'Skill 1', + 'security': [{'test_oauth': ['read']}], + 'tags': ['example'], + } + ], + 'supportsAuthenticatedExtendedCard': True, + 'url': 'http://agent.example.com/api', + 'version': '1.0', + } + original_data = copy.deepcopy(data) + card = parse_agent_card(data) + + expected_card = AgentCard( + name='Typical Agent 0.3', + description='A typical agent from 0.3.0', + version='1.0', + capabilities=AgentCapabilities( + extended_agent_card=True, streaming=True + ), + default_input_modes=['text/plain'], + default_output_modes=['application/json'], + supported_interfaces=[ + AgentInterface( + url='http://agent.example.com/api', + protocol_binding='JSONRPC', + protocol_version='0.3.0', + ), + AgentInterface( + url='http://agent.example.com/api/grpc', + protocol_binding='GRPC', + protocol_version='0.3.0', + ), + ], + security_requirements=[ + SecurityRequirement( + schemes={'test_oauth': StringList(list=['read', 'write'])} + ) + ], + security_schemes={ + 'test_oauth': SecurityScheme( + oauth2_security_scheme=OAuth2SecurityScheme( + description='OAuth2 authentication', + flows=OAuthFlows( + authorization_code=AuthorizationCodeOAuthFlow( + 
authorization_url='http://auth.example.com', + token_url='http://token.example.com', + scopes={ + 'read': 'Read access', + 'write': 'Write access', + }, + ) + ), + ) + ) + }, + skills=[ + AgentSkill( + id='skill-1', + name='Skill 1', + description='The first skill', + tags=['example'], + security_requirements=[ + SecurityRequirement( + schemes={'test_oauth': StringList(list=['read'])} + ) + ], + ) + ], + ) + + assert card == expected_card + + # Serialize back to JSON and compare + serialized_data = agent_card_to_dict(card) + + self._assert_agent_card_diff(original_data, serialized_data) + assert 'preferredTransport' in serialized_data + + # Re-parse from the serialized payload and verify identical to original parsing + re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) + assert re_parsed_card == card + + def test_parse_agent_card_security_scheme_without_in(self) -> None: + data = { + 'name': 'API Key Agent', + 'description': 'API Key without in param', + 'version': '1.0', + 'securitySchemes': { + 'test_api_key': {'type': 'apiKey', 'name': 'X-API-KEY'} + }, + } + card = parse_agent_card(data) + assert 'test_api_key' in card.security_schemes + assert ( + card.security_schemes['test_api_key'].api_key_security_scheme.name + == 'X-API-KEY' + ) + assert ( + card.security_schemes[ + 'test_api_key' + ].api_key_security_scheme.location + == '' + ) + + def test_parse_agent_card_security_scheme_unknown_type(self) -> None: + data = { + 'name': 'Unknown Scheme Agent', + 'description': 'Has unknown scheme type', + 'version': '1.0', + 'securitySchemes': { + 'test_unknown': { + 'type': 'someFutureType', + 'future_prop': 'value', + }, + 'test_missing_type': {'prop': 'value'}, + }, + } + card = parse_agent_card(data) + assert 'test_unknown' in card.security_schemes + assert not card.security_schemes['test_unknown'].WhichOneof('scheme') + + assert 'test_missing_type' in card.security_schemes + assert not card.security_schemes['test_missing_type'].WhichOneof( + 'scheme' 
+ ) + + def test_parse_030_agent_card_route_planner(self) -> None: + data = { + 'protocolVersion': '0.3', + 'name': 'GeoSpatial Route Planner Agent', + 'description': 'Provides advanced route planning.', + 'url': 'https://georoute-agent.example.com/a2a/v1', + 'preferredTransport': 'JSONRPC', + 'additionalInterfaces': [ + { + 'url': 'https://georoute-agent.example.com/a2a/v1', + 'transport': 'JSONRPC', + }, + { + 'url': 'https://georoute-agent.example.com/a2a/grpc', + 'transport': 'GRPC', + }, + { + 'url': 'https://georoute-agent.example.com/a2a/json', + 'transport': 'HTTP+JSON', + }, + ], + 'provider': { + 'organization': 'Example Geo Services Inc.', + 'url': 'https://www.examplegeoservices.com', + }, + 'iconUrl': 'https://georoute-agent.example.com/icon.png', + 'version': '1.2.0', + 'documentationUrl': 'https://docs.examplegeoservices.com/georoute-agent/api', + 'supportsAuthenticatedExtendedCard': True, + 'capabilities': { + 'streaming': True, + 'pushNotifications': True, + 'stateTransitionHistory': False, + }, + 'securitySchemes': { + 'google': { + 'type': 'openIdConnect', + 'openIdConnectUrl': 'https://accounts.google.com/.well-known/openid-configuration', + } + }, + 'security': [{'google': ['openid', 'profile', 'email']}], + 'defaultInputModes': ['application/json', 'text/plain'], + 'defaultOutputModes': ['application/json', 'image/png'], + 'skills': [ + { + 'id': 'route-optimizer-traffic', + 'name': 'Traffic-Aware Route Optimizer', + 'description': 'Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).', + 'tags': [ + 'maps', + 'routing', + 'navigation', + 'directions', + 'traffic', + ], + 'examples': [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + '{"origin": {"lat": 37.422, "lng": -122.084}, "destination": {"lat": 37.7749, "lng": -122.4194}, 
"preferences": ["avoid_ferries"]}', + ], + 'inputModes': ['application/json', 'text/plain'], + 'outputModes': [ + 'application/json', + 'application/vnd.geo+json', + 'text/html', + ], + 'security': [ + {'example': []}, + {'google': ['openid', 'profile', 'email']}, + ], + }, + { + 'id': 'custom-map-generator', + 'name': 'Personalized Map Generator', + 'description': 'Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.', + 'tags': [ + 'maps', + 'customization', + 'visualization', + 'cartography', + ], + 'examples': [ + 'Generate a map of my upcoming road trip with all planned stops highlighted.', + 'Show me a map visualizing all coffee shops within a 1-mile radius of my current location.', + ], + 'inputModes': ['application/json'], + 'outputModes': [ + 'image/png', + 'image/jpeg', + 'application/json', + 'text/html', + ], + }, + ], + 'signatures': [ + { + 'protected': 'eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0', + 'signature': 'QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ', + } + ], + } + + original_data = copy.deepcopy(data) + card = parse_agent_card(data) + + expected_card = AgentCard( + name='GeoSpatial Route Planner Agent', + description='Provides advanced route planning.', + version='1.2.0', + documentation_url='https://docs.examplegeoservices.com/georoute-agent/api', + icon_url='https://georoute-agent.example.com/icon.png', + provider=AgentProvider( + organization='Example Geo Services Inc.', + url='https://www.examplegeoservices.com', + ), + capabilities=AgentCapabilities( + extended_agent_card=True, + streaming=True, + push_notifications=True, + ), + default_input_modes=['application/json', 'text/plain'], + default_output_modes=['application/json', 'image/png'], + supported_interfaces=[ + AgentInterface( + 
url='https://georoute-agent.example.com/a2a/v1', + protocol_binding='JSONRPC', + protocol_version='0.3', + ), + AgentInterface( + url='https://georoute-agent.example.com/a2a/v1', + protocol_binding='JSONRPC', + protocol_version='0.3', + ), + AgentInterface( + url='https://georoute-agent.example.com/a2a/grpc', + protocol_binding='GRPC', + protocol_version='0.3', + ), + AgentInterface( + url='https://georoute-agent.example.com/a2a/json', + protocol_binding='HTTP+JSON', + protocol_version='0.3', + ), + ], + security_requirements=[ + SecurityRequirement( + schemes={ + 'google': StringList( + list=['openid', 'profile', 'email'] + ) + } + ) + ], + security_schemes={ + 'google': SecurityScheme( + open_id_connect_security_scheme=OpenIdConnectSecurityScheme( + open_id_connect_url='https://accounts.google.com/.well-known/openid-configuration' + ) + ) + }, + skills=[ + AgentSkill( + id='route-optimizer-traffic', + name='Traffic-Aware Route Optimizer', + description='Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).', + tags=[ + 'maps', + 'routing', + 'navigation', + 'directions', + 'traffic', + ], + examples=[ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + '{"origin": {"lat": 37.422, "lng": -122.084}, "destination": {"lat": 37.7749, "lng": -122.4194}, "preferences": ["avoid_ferries"]}', + ], + input_modes=['application/json', 'text/plain'], + output_modes=[ + 'application/json', + 'application/vnd.geo+json', + 'text/html', + ], + security_requirements=[ + SecurityRequirement(schemes={'example': StringList()}), + SecurityRequirement( + schemes={ + 'google': StringList( + list=['openid', 'profile', 'email'] + ) + } + ), + ], + ), + AgentSkill( + id='custom-map-generator', + name='Personalized Map Generator', + description='Creates custom map images or 
interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.', + tags=[ + 'maps', + 'customization', + 'visualization', + 'cartography', + ], + examples=[ + 'Generate a map of my upcoming road trip with all planned stops highlighted.', + 'Show me a map visualizing all coffee shops within a 1-mile radius of my current location.', + ], + input_modes=['application/json'], + output_modes=[ + 'image/png', + 'image/jpeg', + 'application/json', + 'text/html', + ], + ), + ], + signatures=[ + AgentCardSignature( + protected='eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0', + signature='QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ', + ) + ], + ) + + assert card == expected_card + serialized_data = agent_card_to_dict(card) + del original_data['capabilities']['stateTransitionHistory'] + self._assert_agent_card_diff(original_data, serialized_data) + re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) + assert re_parsed_card == card + + def test_parse_complex_030_agent_card(self) -> None: + data = { + 'additionalInterfaces': [ + { + 'transport': 'GRPC', + 'url': 'http://complex.agent.example.com/grpc', + }, + { + 'transport': 'JSONRPC', + 'url': 'http://complex.agent.example.com/jsonrpc', + }, + ], + 'capabilities': {'pushNotifications': True, 'streaming': True}, + 'defaultInputModes': ['text/plain', 'application/json'], + 'defaultOutputModes': ['application/json', 'image/png'], + 'description': 'A very complex agent from 0.3.0', + 'name': 'Complex Agent 0.3', + 'preferredTransport': 'HTTP+JSON', + 'protocolVersion': '0.3.0', + 'security': [ + {'test_oauth': ['read', 'write'], 'test_api_key': []}, + {'test_http': []}, + {'test_oidc': ['openid', 'profile']}, + {'test_mtls': []}, + ], + 'securitySchemes': { + 'test_oauth': { + 'description': 'OAuth2 authentication', + 'flows': { + 
'authorizationCode': { + 'authorizationUrl': 'http://auth.example.com', + 'scopes': { + 'read': 'Read access', + 'write': 'Write access', + }, + 'tokenUrl': 'http://token.example.com', + } + }, + 'type': 'oauth2', + }, + 'test_api_key': { + 'description': 'API Key auth', + 'in': 'header', + 'name': 'X-API-KEY', + 'type': 'apiKey', + }, + 'test_http': { + 'bearerFormat': 'JWT', + 'description': 'HTTP Basic auth', + 'scheme': 'basic', + 'type': 'http', + }, + 'test_oidc': { + 'description': 'OIDC Auth', + 'openIdConnectUrl': 'https://example.com/.well-known/openid-configuration', + 'type': 'openIdConnect', + }, + 'test_mtls': {'description': 'mTLS Auth', 'type': 'mutualTLS'}, + }, + 'skills': [ + { + 'description': 'The first complex skill', + 'id': 'skill-1', + 'inputModes': ['application/json'], + 'name': 'Complex Skill 1', + 'outputModes': ['application/json'], + 'security': [{'test_api_key': []}], + 'tags': ['example', 'complex'], + }, + { + 'description': 'The second complex skill', + 'id': 'skill-2', + 'name': 'Complex Skill 2', + 'security': [{'test_oidc': ['openid']}], + 'tags': ['example2'], + }, + ], + 'supportsAuthenticatedExtendedCard': True, + 'url': 'http://complex.agent.example.com/api', + 'version': '1.5.2', + } + original_data = copy.deepcopy(data) + card = parse_agent_card(data) + + expected_card = AgentCard( + name='Complex Agent 0.3', + description='A very complex agent from 0.3.0', + version='1.5.2', + capabilities=AgentCapabilities( + extended_agent_card=True, + streaming=True, + push_notifications=True, + ), + default_input_modes=['text/plain', 'application/json'], + default_output_modes=['application/json', 'image/png'], + supported_interfaces=[ + AgentInterface( + url='http://complex.agent.example.com/api', + protocol_binding='HTTP+JSON', + protocol_version='0.3.0', + ), + AgentInterface( + url='http://complex.agent.example.com/grpc', + protocol_binding='GRPC', + protocol_version='0.3.0', + ), + AgentInterface( + 
url='http://complex.agent.example.com/jsonrpc', + protocol_binding='JSONRPC', + protocol_version='0.3.0', + ), + ], + security_requirements=[ + SecurityRequirement( + schemes={ + 'test_oauth': StringList(list=['read', 'write']), + 'test_api_key': StringList(), + } + ), + SecurityRequirement(schemes={'test_http': StringList()}), + SecurityRequirement( + schemes={ + 'test_oidc': StringList(list=['openid', 'profile']) + } + ), + SecurityRequirement(schemes={'test_mtls': StringList()}), + ], + security_schemes={ + 'test_oauth': SecurityScheme( + oauth2_security_scheme=OAuth2SecurityScheme( + description='OAuth2 authentication', + flows=OAuthFlows( + authorization_code=AuthorizationCodeOAuthFlow( + authorization_url='http://auth.example.com', + token_url='http://token.example.com', + scopes={ + 'read': 'Read access', + 'write': 'Write access', + }, + ) + ), + ) + ), + 'test_api_key': SecurityScheme( + api_key_security_scheme=APIKeySecurityScheme( + description='API Key auth', + location='header', + name='X-API-KEY', + ) + ), + 'test_http': SecurityScheme( + http_auth_security_scheme=HTTPAuthSecurityScheme( + description='HTTP Basic auth', + scheme='basic', + bearer_format='JWT', + ) + ), + 'test_oidc': SecurityScheme( + open_id_connect_security_scheme=OpenIdConnectSecurityScheme( + description='OIDC Auth', + open_id_connect_url='https://example.com/.well-known/openid-configuration', + ) + ), + 'test_mtls': SecurityScheme( + mtls_security_scheme=MutualTlsSecurityScheme( + description='mTLS Auth' + ) + ), + }, + skills=[ + AgentSkill( + id='skill-1', + name='Complex Skill 1', + description='The first complex skill', + tags=['example', 'complex'], + input_modes=['application/json'], + output_modes=['application/json'], + security_requirements=[ + SecurityRequirement( + schemes={'test_api_key': StringList()} + ) + ], + ), + AgentSkill( + id='skill-2', + name='Complex Skill 2', + description='The second complex skill', + tags=['example2'], + security_requirements=[ + 
SecurityRequirement( + schemes={'test_oidc': StringList(list=['openid'])} + ) + ], + ), + ], + ) + + assert card == expected_card + serialized_data = agent_card_to_dict(card) + self._assert_agent_card_diff(original_data, serialized_data) + re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) + assert re_parsed_card == card diff --git a/tests/client/test_client_helpers.py b/tests/client/test_client_helpers.py deleted file mode 100644 index 0eb394f43..000000000 --- a/tests/client/test_client_helpers.py +++ /dev/null @@ -1,696 +0,0 @@ -import copy -import difflib -import json -from google.protobuf.json_format import MessageToDict - -from a2a.client.helpers import parse_agent_card -from a2a.helpers.proto_helpers import new_text_message -from a2a.server.request_handlers.response_helpers import agent_card_to_dict -from a2a.types.a2a_pb2 import ( - APIKeySecurityScheme, - AgentCapabilities, - AgentCard, - AgentCardSignature, - AgentInterface, - AgentProvider, - AgentSkill, - AuthorizationCodeOAuthFlow, - HTTPAuthSecurityScheme, - MutualTlsSecurityScheme, - OAuth2SecurityScheme, - OAuthFlows, - OpenIdConnectSecurityScheme, - Role, - SecurityRequirement, - SecurityScheme, - StringList, -) - - -def test_parse_agent_card_legacy_support() -> None: - data = { - 'name': 'Legacy Agent', - 'description': 'Legacy Description', - 'version': '1.0', - 'supportsAuthenticatedExtendedCard': True, - } - card = parse_agent_card(data) - assert card.name == 'Legacy Agent' - assert card.capabilities.extended_agent_card is True - # Ensure it's popped from the dict - assert 'supportsAuthenticatedExtendedCard' not in data - - -def test_parse_agent_card_new_support() -> None: - data = { - 'name': 'New Agent', - 'description': 'New Description', - 'version': '1.0', - 'capabilities': {'extendedAgentCard': True}, - } - card = parse_agent_card(data) - assert card.name == 'New Agent' - assert card.capabilities.extended_agent_card is True - - -def test_parse_agent_card_no_support() -> 
None: - data = { - 'name': 'No Support Agent', - 'description': 'No Support Description', - 'version': '1.0', - 'capabilities': {'extendedAgentCard': False}, - } - card = parse_agent_card(data) - assert card.name == 'No Support Agent' - assert card.capabilities.extended_agent_card is False - - -def test_parse_agent_card_both_legacy_and_new() -> None: - data = { - 'name': 'Mixed Agent', - 'description': 'Mixed Description', - 'version': '1.0', - 'supportsAuthenticatedExtendedCard': True, - 'capabilities': {'streaming': True}, - } - card = parse_agent_card(data) - assert card.name == 'Mixed Agent' - assert card.capabilities.streaming is True - assert card.capabilities.extended_agent_card is True - - -def _assert_agent_card_diff(original_data: dict, serialized_data: dict) -> None: - """Helper to assert that the re-serialized 1.0.0 JSON payload contains all original 0.3.0 data (no dropped fields).""" - original_json_str = json.dumps(original_data, indent=2, sort_keys=True) - serialized_json_str = json.dumps(serialized_data, indent=2, sort_keys=True) - - diff_lines = list( - difflib.unified_diff( - original_json_str.splitlines(), - serialized_json_str.splitlines(), - lineterm='', - ) - ) - - removed_lines = [] - for line in diff_lines: - if line.startswith('-') and not line.startswith('---'): - removed_lines.append(line) - - if removed_lines: - error_msg = ( - 'Re-serialization dropped fields from the original payload:\n' - + '\n'.join(removed_lines) - ) - raise AssertionError(error_msg) - - -def test_parse_typical_030_agent_card() -> None: - data = { - 'additionalInterfaces': [ - {'transport': 'GRPC', 'url': 'http://agent.example.com/api/grpc'} - ], - 'capabilities': {'streaming': True}, - 'defaultInputModes': ['text/plain'], - 'defaultOutputModes': ['application/json'], - 'description': 'A typical agent from 0.3.0', - 'name': 'Typical Agent 0.3', - 'preferredTransport': 'JSONRPC', - 'protocolVersion': '0.3.0', - 'security': [{'test_oauth': ['read', 'write']}], - 
'securitySchemes': { - 'test_oauth': { - 'description': 'OAuth2 authentication', - 'flows': { - 'authorizationCode': { - 'authorizationUrl': 'http://auth.example.com', - 'scopes': { - 'read': 'Read access', - 'write': 'Write access', - }, - 'tokenUrl': 'http://token.example.com', - } - }, - 'type': 'oauth2', - } - }, - 'skills': [ - { - 'description': 'The first skill', - 'id': 'skill-1', - 'name': 'Skill 1', - 'security': [{'test_oauth': ['read']}], - 'tags': ['example'], - } - ], - 'supportsAuthenticatedExtendedCard': True, - 'url': 'http://agent.example.com/api', - 'version': '1.0', - } - original_data = copy.deepcopy(data) - card = parse_agent_card(data) - - expected_card = AgentCard( - name='Typical Agent 0.3', - description='A typical agent from 0.3.0', - version='1.0', - capabilities=AgentCapabilities( - extended_agent_card=True, streaming=True - ), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - supported_interfaces=[ - AgentInterface( - url='http://agent.example.com/api', - protocol_binding='JSONRPC', - protocol_version='0.3.0', - ), - AgentInterface( - url='http://agent.example.com/api/grpc', - protocol_binding='GRPC', - protocol_version='0.3.0', - ), - ], - security_requirements=[ - SecurityRequirement( - schemes={'test_oauth': StringList(list=['read', 'write'])} - ) - ], - security_schemes={ - 'test_oauth': SecurityScheme( - oauth2_security_scheme=OAuth2SecurityScheme( - description='OAuth2 authentication', - flows=OAuthFlows( - authorization_code=AuthorizationCodeOAuthFlow( - authorization_url='http://auth.example.com', - token_url='http://token.example.com', - scopes={ - 'read': 'Read access', - 'write': 'Write access', - }, - ) - ), - ) - ) - }, - skills=[ - AgentSkill( - id='skill-1', - name='Skill 1', - description='The first skill', - tags=['example'], - security_requirements=[ - SecurityRequirement( - schemes={'test_oauth': StringList(list=['read'])} - ) - ], - ) - ], - ) - - assert card == expected_card - - # 
Serialize back to JSON and compare - serialized_data = agent_card_to_dict(card) - - _assert_agent_card_diff(original_data, serialized_data) - assert 'preferredTransport' in serialized_data - - # Re-parse from the serialized payload and verify identical to original parsing - re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) - assert re_parsed_card == card - - -def test_parse_agent_card_security_scheme_without_in() -> None: - data = { - 'name': 'API Key Agent', - 'description': 'API Key without in param', - 'version': '1.0', - 'securitySchemes': { - 'test_api_key': {'type': 'apiKey', 'name': 'X-API-KEY'} - }, - } - card = parse_agent_card(data) - assert 'test_api_key' in card.security_schemes - assert ( - card.security_schemes['test_api_key'].api_key_security_scheme.name - == 'X-API-KEY' - ) - assert ( - card.security_schemes['test_api_key'].api_key_security_scheme.location - == '' - ) - - -def test_parse_agent_card_security_scheme_unknown_type() -> None: - data = { - 'name': 'Unknown Scheme Agent', - 'description': 'Has unknown scheme type', - 'version': '1.0', - 'securitySchemes': { - 'test_unknown': {'type': 'someFutureType', 'future_prop': 'value'}, - 'test_missing_type': {'prop': 'value'}, - }, - } - card = parse_agent_card(data) - # the ParseDict ignore_unknown_fields=True handles the unknown fields. - # Because there is no mapping logic for 'someFutureType', the Protobuf - # creates an empty SecurityScheme message under those keys. 
- assert 'test_unknown' in card.security_schemes - assert not card.security_schemes['test_unknown'].WhichOneof('scheme') - - assert 'test_missing_type' in card.security_schemes - assert not card.security_schemes['test_missing_type'].WhichOneof('scheme') - - -def test_create_text_message_object() -> None: - msg = new_text_message(text='Hello', role=Role.ROLE_AGENT) - assert msg.role == Role.ROLE_AGENT - assert len(msg.parts) == 1 - assert msg.parts[0].text == 'Hello' - assert msg.message_id != '' - - -def test_parse_030_agent_card_route_planner() -> None: - data = { - 'protocolVersion': '0.3', - 'name': 'GeoSpatial Route Planner Agent', - 'description': 'Provides advanced route planning.', - 'url': 'https://georoute-agent.example.com/a2a/v1', - 'preferredTransport': 'JSONRPC', - 'additionalInterfaces': [ - { - 'url': 'https://georoute-agent.example.com/a2a/v1', - 'transport': 'JSONRPC', - }, - { - 'url': 'https://georoute-agent.example.com/a2a/grpc', - 'transport': 'GRPC', - }, - { - 'url': 'https://georoute-agent.example.com/a2a/json', - 'transport': 'HTTP+JSON', - }, - ], - 'provider': { - 'organization': 'Example Geo Services Inc.', - 'url': 'https://www.examplegeoservices.com', - }, - 'iconUrl': 'https://georoute-agent.example.com/icon.png', - 'version': '1.2.0', - 'documentationUrl': 'https://docs.examplegeoservices.com/georoute-agent/api', - 'supportsAuthenticatedExtendedCard': True, - 'capabilities': { - 'streaming': True, - 'pushNotifications': True, - 'stateTransitionHistory': False, - }, - 'securitySchemes': { - 'google': { - 'type': 'openIdConnect', - 'openIdConnectUrl': 'https://accounts.google.com/.well-known/openid-configuration', - } - }, - 'security': [{'google': ['openid', 'profile', 'email']}], - 'defaultInputModes': ['application/json', 'text/plain'], - 'defaultOutputModes': ['application/json', 'image/png'], - 'skills': [ - { - 'id': 'route-optimizer-traffic', - 'name': 'Traffic-Aware Route Optimizer', - 'description': 'Calculates the optimal 
driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).', - 'tags': [ - 'maps', - 'routing', - 'navigation', - 'directions', - 'traffic', - ], - 'examples': [ - "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", - '{"origin": {"lat": 37.422, "lng": -122.084}, "destination": {"lat": 37.7749, "lng": -122.4194}, "preferences": ["avoid_ferries"]}', - ], - 'inputModes': ['application/json', 'text/plain'], - 'outputModes': [ - 'application/json', - 'application/vnd.geo+json', - 'text/html', - ], - 'security': [ - {'example': []}, - {'google': ['openid', 'profile', 'email']}, - ], - }, - { - 'id': 'custom-map-generator', - 'name': 'Personalized Map Generator', - 'description': 'Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.', - 'tags': [ - 'maps', - 'customization', - 'visualization', - 'cartography', - ], - 'examples': [ - 'Generate a map of my upcoming road trip with all planned stops highlighted.', - 'Show me a map visualizing all coffee shops within a 1-mile radius of my current location.', - ], - 'inputModes': ['application/json'], - 'outputModes': [ - 'image/png', - 'image/jpeg', - 'application/json', - 'text/html', - ], - }, - ], - 'signatures': [ - { - 'protected': 'eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0', - 'signature': 'QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ', - } - ], - } - - original_data = copy.deepcopy(data) - card = parse_agent_card(data) - - expected_card = AgentCard( - name='GeoSpatial Route Planner Agent', - description='Provides advanced route planning.', - version='1.2.0', - 
documentation_url='https://docs.examplegeoservices.com/georoute-agent/api', - icon_url='https://georoute-agent.example.com/icon.png', - provider=AgentProvider( - organization='Example Geo Services Inc.', - url='https://www.examplegeoservices.com', - ), - capabilities=AgentCapabilities( - extended_agent_card=True, streaming=True, push_notifications=True - ), - default_input_modes=['application/json', 'text/plain'], - default_output_modes=['application/json', 'image/png'], - supported_interfaces=[ - AgentInterface( - url='https://georoute-agent.example.com/a2a/v1', - protocol_binding='JSONRPC', - protocol_version='0.3', - ), - AgentInterface( - url='https://georoute-agent.example.com/a2a/v1', - protocol_binding='JSONRPC', - protocol_version='0.3', - ), - AgentInterface( - url='https://georoute-agent.example.com/a2a/grpc', - protocol_binding='GRPC', - protocol_version='0.3', - ), - AgentInterface( - url='https://georoute-agent.example.com/a2a/json', - protocol_binding='HTTP+JSON', - protocol_version='0.3', - ), - ], - security_requirements=[ - SecurityRequirement( - schemes={ - 'google': StringList(list=['openid', 'profile', 'email']) - } - ) - ], - security_schemes={ - 'google': SecurityScheme( - open_id_connect_security_scheme=OpenIdConnectSecurityScheme( - open_id_connect_url='https://accounts.google.com/.well-known/openid-configuration' - ) - ) - }, - skills=[ - AgentSkill( - id='route-optimizer-traffic', - name='Traffic-Aware Route Optimizer', - description='Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).', - tags=['maps', 'routing', 'navigation', 'directions', 'traffic'], - examples=[ - "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", - '{"origin": {"lat": 37.422, "lng": -122.084}, "destination": {"lat": 37.7749, "lng": -122.4194}, "preferences": 
["avoid_ferries"]}', - ], - input_modes=['application/json', 'text/plain'], - output_modes=[ - 'application/json', - 'application/vnd.geo+json', - 'text/html', - ], - security_requirements=[ - SecurityRequirement(schemes={'example': StringList()}), - SecurityRequirement( - schemes={ - 'google': StringList( - list=['openid', 'profile', 'email'] - ) - } - ), - ], - ), - AgentSkill( - id='custom-map-generator', - name='Personalized Map Generator', - description='Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.', - tags=['maps', 'customization', 'visualization', 'cartography'], - examples=[ - 'Generate a map of my upcoming road trip with all planned stops highlighted.', - 'Show me a map visualizing all coffee shops within a 1-mile radius of my current location.', - ], - input_modes=['application/json'], - output_modes=[ - 'image/png', - 'image/jpeg', - 'application/json', - 'text/html', - ], - ), - ], - signatures=[ - AgentCardSignature( - protected='eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0', - signature='QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ', - ) - ], - ) - - assert card == expected_card - - # Serialize back to JSON and compare - serialized_data = agent_card_to_dict(card) - - # Remove deprecated stateTransitionHistory before diffing - del original_data['capabilities']['stateTransitionHistory'] - - _assert_agent_card_diff(original_data, serialized_data) - - # Re-parse from the serialized payload and verify identical to original parsing - re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) - assert re_parsed_card == card - - -def test_parse_complex_030_agent_card() -> None: - data = { - 'additionalInterfaces': [ - { - 'transport': 'GRPC', - 'url': 'http://complex.agent.example.com/grpc', - }, - { - 'transport': 'JSONRPC', - 'url': 
'http://complex.agent.example.com/jsonrpc', - }, - ], - 'capabilities': {'pushNotifications': True, 'streaming': True}, - 'defaultInputModes': ['text/plain', 'application/json'], - 'defaultOutputModes': ['application/json', 'image/png'], - 'description': 'A very complex agent from 0.3.0', - 'name': 'Complex Agent 0.3', - 'preferredTransport': 'HTTP+JSON', - 'protocolVersion': '0.3.0', - 'security': [ - {'test_oauth': ['read', 'write'], 'test_api_key': []}, - {'test_http': []}, - {'test_oidc': ['openid', 'profile']}, - {'test_mtls': []}, - ], - 'securitySchemes': { - 'test_oauth': { - 'description': 'OAuth2 authentication', - 'flows': { - 'authorizationCode': { - 'authorizationUrl': 'http://auth.example.com', - 'scopes': { - 'read': 'Read access', - 'write': 'Write access', - }, - 'tokenUrl': 'http://token.example.com', - } - }, - 'type': 'oauth2', - }, - 'test_api_key': { - 'description': 'API Key auth', - 'in': 'header', - 'name': 'X-API-KEY', - 'type': 'apiKey', - }, - 'test_http': { - 'bearerFormat': 'JWT', - 'description': 'HTTP Basic auth', - 'scheme': 'basic', - 'type': 'http', - }, - 'test_oidc': { - 'description': 'OIDC Auth', - 'openIdConnectUrl': 'https://example.com/.well-known/openid-configuration', - 'type': 'openIdConnect', - }, - 'test_mtls': {'description': 'mTLS Auth', 'type': 'mutualTLS'}, - }, - 'skills': [ - { - 'description': 'The first complex skill', - 'id': 'skill-1', - 'inputModes': ['application/json'], - 'name': 'Complex Skill 1', - 'outputModes': ['application/json'], - 'security': [{'test_api_key': []}], - 'tags': ['example', 'complex'], - }, - { - 'description': 'The second complex skill', - 'id': 'skill-2', - 'name': 'Complex Skill 2', - 'security': [{'test_oidc': ['openid']}], - 'tags': ['example2'], - }, - ], - 'supportsAuthenticatedExtendedCard': True, - 'url': 'http://complex.agent.example.com/api', - 'version': '1.5.2', - } - original_data = copy.deepcopy(data) - card = parse_agent_card(data) - - expected_card = AgentCard( - 
name='Complex Agent 0.3', - description='A very complex agent from 0.3.0', - version='1.5.2', - capabilities=AgentCapabilities( - extended_agent_card=True, streaming=True, push_notifications=True - ), - default_input_modes=['text/plain', 'application/json'], - default_output_modes=['application/json', 'image/png'], - supported_interfaces=[ - AgentInterface( - url='http://complex.agent.example.com/api', - protocol_binding='HTTP+JSON', - protocol_version='0.3.0', - ), - AgentInterface( - url='http://complex.agent.example.com/grpc', - protocol_binding='GRPC', - protocol_version='0.3.0', - ), - AgentInterface( - url='http://complex.agent.example.com/jsonrpc', - protocol_binding='JSONRPC', - protocol_version='0.3.0', - ), - ], - security_requirements=[ - SecurityRequirement( - schemes={ - 'test_oauth': StringList(list=['read', 'write']), - 'test_api_key': StringList(), - } - ), - SecurityRequirement(schemes={'test_http': StringList()}), - SecurityRequirement( - schemes={'test_oidc': StringList(list=['openid', 'profile'])} - ), - SecurityRequirement(schemes={'test_mtls': StringList()}), - ], - security_schemes={ - 'test_oauth': SecurityScheme( - oauth2_security_scheme=OAuth2SecurityScheme( - description='OAuth2 authentication', - flows=OAuthFlows( - authorization_code=AuthorizationCodeOAuthFlow( - authorization_url='http://auth.example.com', - token_url='http://token.example.com', - scopes={ - 'read': 'Read access', - 'write': 'Write access', - }, - ) - ), - ) - ), - 'test_api_key': SecurityScheme( - api_key_security_scheme=APIKeySecurityScheme( - description='API Key auth', - location='header', - name='X-API-KEY', - ) - ), - 'test_http': SecurityScheme( - http_auth_security_scheme=HTTPAuthSecurityScheme( - description='HTTP Basic auth', - scheme='basic', - bearer_format='JWT', - ) - ), - 'test_oidc': SecurityScheme( - open_id_connect_security_scheme=OpenIdConnectSecurityScheme( - description='OIDC Auth', - 
open_id_connect_url='https://example.com/.well-known/openid-configuration', - ) - ), - 'test_mtls': SecurityScheme( - mtls_security_scheme=MutualTlsSecurityScheme( - description='mTLS Auth' - ) - ), - }, - skills=[ - AgentSkill( - id='skill-1', - name='Complex Skill 1', - description='The first complex skill', - tags=['example', 'complex'], - input_modes=['application/json'], - output_modes=['application/json'], - security_requirements=[ - SecurityRequirement(schemes={'test_api_key': StringList()}) - ], - ), - AgentSkill( - id='skill-2', - name='Complex Skill 2', - description='The second complex skill', - tags=['example2'], - security_requirements=[ - SecurityRequirement( - schemes={'test_oidc': StringList(list=['openid'])} - ) - ], - ), - ], - ) - - assert card == expected_card - - # Serialize back to JSON and compare - serialized_data = agent_card_to_dict(card) - _assert_agent_card_diff(original_data, serialized_data) - - # Re-parse from the serialized payload and verify identical to original parsing - re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) - assert re_parsed_card == card diff --git a/tests/integration/cross_version/test_cross_version_card_validation.py b/tests/integration/cross_version/test_cross_version_card_validation.py index 85379c3a3..25972b075 100644 --- a/tests/integration/cross_version/test_cross_version_card_validation.py +++ b/tests/integration/cross_version/test_cross_version_card_validation.py @@ -18,7 +18,7 @@ SecurityScheme, StringList, ) -from a2a.client.helpers import parse_agent_card +from a2a.client.card_resolver import parse_agent_card from google.protobuf.json_format import MessageToDict, ParseDict diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 76da2e20f..1711ac810 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -709,8 +709,11 @@ async def 
test_json_transport_get_signed_base_card( }, ) + async def async_signer(card: AgentCard) -> AgentCard: + return signer(card) + agent_card_routes = create_agent_card_routes( - agent_card=agent_card, card_url='/', card_modifier=signer + agent_card=agent_card, card_url='/', card_modifier=async_signer ) jsonrpc_routes = create_jsonrpc_routes( request_handler=mock_request_handler, rpc_url='/' @@ -863,8 +866,12 @@ async def get_extended_agent_card_mock_3(*args, **kwargs): mock_request_handler.on_get_extended_agent_card.side_effect = ( get_extended_agent_card_mock_3 # type: ignore[union-attr] ) + + async def async_signer(card: AgentCard) -> AgentCard: + return signer(card) + agent_card_routes = create_agent_card_routes( - agent_card=agent_card, card_url='/', card_modifier=signer + agent_card=agent_card, card_url='/', card_modifier=async_signer ) jsonrpc_routes = create_jsonrpc_routes( request_handler=mock_request_handler, rpc_url='/' diff --git a/tests/server/tasks/test_task_manager.py b/tests/server/tasks/test_task_manager.py index bdfbf525c..eba8d2f14 100644 --- a/tests/server/tasks/test_task_manager.py +++ b/tests/server/tasks/test_task_manager.py @@ -6,6 +6,7 @@ from a2a.auth.user import User from a2a.server.context import ServerCallContext from a2a.server.tasks import TaskManager +from a2a.server.tasks.task_manager import append_artifact_to_task from a2a.types.a2a_pb2 import ( Artifact, Message, @@ -345,3 +346,99 @@ async def test_save_task_event_no_task_existing( assert saved_task.status.state == TaskState.TASK_STATE_COMPLETED assert task_manager_without_id.task_id == 'event-task-id' assert task_manager_without_id.context_id == 'some-context' + + +def test_append_artifact_to_task(): + # Prepare base task + task = create_minimal_task() + assert task.id == 'task-abc' + assert task.context_id == 'session-xyz' + assert task.status.state == TaskState.TASK_STATE_SUBMITTED + assert len(task.history) == 0 # proto repeated fields are empty, not None + assert 
len(task.artifacts) == 0 + + # Prepare appending artifact and event + artifact_1 = Artifact( + artifact_id='artifact-123', parts=[Part(text='Hello')] + ) + append_event_1 = TaskArtifactUpdateEvent( + artifact=artifact_1, append=False, task_id='123', context_id='123' + ) + + # Test adding a new artifact (not appending) + append_artifact_to_task(task, append_event_1) + assert len(task.artifacts) == 1 + assert task.artifacts[0].artifact_id == 'artifact-123' + assert task.artifacts[0].name == '' # proto default for string + assert len(task.artifacts[0].parts) == 1 + assert task.artifacts[0].parts[0].text == 'Hello' + + # Test replacing the artifact + artifact_2 = Artifact( + artifact_id='artifact-123', + name='updated name', + parts=[Part(text='Updated')], + metadata={'existing_key': 'existing_value'}, + ) + append_event_2 = TaskArtifactUpdateEvent( + artifact=artifact_2, append=False, task_id='123', context_id='123' + ) + append_artifact_to_task(task, append_event_2) + assert len(task.artifacts) == 1 # Should still have one artifact + assert task.artifacts[0].artifact_id == 'artifact-123' + assert task.artifacts[0].name == 'updated name' + assert len(task.artifacts[0].parts) == 1 + assert task.artifacts[0].parts[0].text == 'Updated' + assert task.artifacts[0].metadata['existing_key'] == 'existing_value' + + # Test appending parts to an existing artifact + artifact_with_parts = Artifact( + artifact_id='artifact-123', + parts=[Part(text='Part 2')], + metadata={'new_key': 'new_value'}, + ) + append_event_3 = TaskArtifactUpdateEvent( + artifact=artifact_with_parts, + append=True, + task_id='123', + context_id='123', + ) + append_artifact_to_task(task, append_event_3) + assert len(task.artifacts[0].parts) == 2 + assert task.artifacts[0].parts[0].text == 'Updated' + assert task.artifacts[0].parts[1].text == 'Part 2' + assert task.artifacts[0].metadata['existing_key'] == 'existing_value' + assert task.artifacts[0].metadata['new_key'] == 'new_value' + + # Test adding another 
new artifact + another_artifact_with_parts = Artifact( + artifact_id='new_artifact', + parts=[Part(text='new artifact Part 1')], + ) + append_event_4 = TaskArtifactUpdateEvent( + artifact=another_artifact_with_parts, + append=False, + task_id='123', + context_id='123', + ) + append_artifact_to_task(task, append_event_4) + assert len(task.artifacts) == 2 + assert task.artifacts[0].artifact_id == 'artifact-123' + assert task.artifacts[1].artifact_id == 'new_artifact' + assert len(task.artifacts[0].parts) == 2 + assert len(task.artifacts[1].parts) == 1 + + # Test appending part to a task that does not have a matching artifact + non_existing_artifact_with_parts = Artifact( + artifact_id='artifact-456', parts=[Part(text='Part 1')] + ) + append_event_5 = TaskArtifactUpdateEvent( + artifact=non_existing_artifact_with_parts, + append=True, + task_id='123', + context_id='123', + ) + append_artifact_to_task(task, append_event_5) + assert len(task.artifacts) == 2 + assert len(task.artifacts[0].parts) == 2 + assert len(task.artifacts[1].parts) == 1 diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index ddab2661a..56663e7e9 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -775,7 +775,7 @@ def test_dynamic_agent_card_modifier_sync( ): """Test that a synchronous card_modifier dynamically alters the public agent card.""" - def modifier(card: AgentCard) -> AgentCard: + async def modifier(card: AgentCard) -> AgentCard: modified_card = AgentCard() modified_card.CopyFrom(card) modified_card.name = 'Dynamically Modified Agent' @@ -818,7 +818,7 @@ def test_fastapi_dynamic_agent_card_modifier_sync( ): """Test that a synchronous card_modifier dynamically alters the public agent card for FastAPI.""" - def modifier(card: AgentCard) -> AgentCard: + async def modifier(card: AgentCard) -> AgentCard: modified_card = AgentCard() modified_card.CopyFrom(card) modified_card.name = 'Dynamically Modified Agent' diff --git 
a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py deleted file mode 100644 index c2c990c0d..000000000 --- a/tests/utils/test_helpers.py +++ /dev/null @@ -1,312 +0,0 @@ -import uuid - -from typing import Any -from unittest.mock import patch - -import pytest - -from a2a.types import ( - AgentCapabilities, - AgentCard, - AgentCardSignature, - AgentInterface, - AgentSkill, - Artifact, - Message, - Part, - Role, - SendMessageRequest, - Task, - TaskArtifactUpdateEvent, - TaskState, - TaskStatus, -) -from a2a.utils.errors import UnsupportedOperationError - -from a2a.utils.signing import _clean_empty, _canonicalize_agent_card -from a2a.server.tasks.task_manager import append_artifact_to_task - - -# --- Helper Functions --- -def create_test_message( - role: Role = Role.ROLE_USER, - text: str = 'Hello', - message_id: str = 'msg-123', -) -> Message: - return Message( - role=role, - parts=[Part(text=text)], - message_id=message_id, - ) - - -def create_test_task( - task_id: str = 'task-abc', - context_id: str = 'session-xyz', -) -> Task: - return Task( - id=task_id, - context_id=context_id, - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - - -SAMPLE_AGENT_CARD: dict[str, Any] = { - 'name': 'Test Agent', - 'description': 'A test agent', - 'supported_interfaces': [ - AgentInterface( - url='http://localhost', - protocol_binding='HTTP+JSON', - ) - ], - 'version': '1.0.0', - 'capabilities': AgentCapabilities( - streaming=None, - push_notifications=True, - ), - 'default_input_modes': ['text/plain'], - 'default_output_modes': ['text/plain'], - 'documentation_url': None, - 'icon_url': '', - 'skills': [ - AgentSkill( - id='skill1', - name='Test Skill', - description='A test skill', - tags=['test'], - ) - ], - 'signatures': [ - AgentCardSignature( - protected='protected_header', signature='test_signature' - ) - ], -} - - -# Test append_artifact_to_task -def test_append_artifact_to_task(): - # Prepare base task - task = create_test_task() - assert task.id == 
'task-abc' - assert task.context_id == 'session-xyz' - assert task.status.state == TaskState.TASK_STATE_SUBMITTED - assert len(task.history) == 0 # proto repeated fields are empty, not None - assert len(task.artifacts) == 0 - - # Prepare appending artifact and event - artifact_1 = Artifact( - artifact_id='artifact-123', parts=[Part(text='Hello')] - ) - append_event_1 = TaskArtifactUpdateEvent( - artifact=artifact_1, append=False, task_id='123', context_id='123' - ) - - # Test adding a new artifact (not appending) - append_artifact_to_task(task, append_event_1) - assert len(task.artifacts) == 1 - assert task.artifacts[0].artifact_id == 'artifact-123' - assert task.artifacts[0].name == '' # proto default for string - assert len(task.artifacts[0].parts) == 1 - assert task.artifacts[0].parts[0].text == 'Hello' - - # Test replacing the artifact - artifact_2 = Artifact( - artifact_id='artifact-123', - name='updated name', - parts=[Part(text='Updated')], - metadata={'existing_key': 'existing_value'}, - ) - append_event_2 = TaskArtifactUpdateEvent( - artifact=artifact_2, append=False, task_id='123', context_id='123' - ) - append_artifact_to_task(task, append_event_2) - assert len(task.artifacts) == 1 # Should still have one artifact - assert task.artifacts[0].artifact_id == 'artifact-123' - assert task.artifacts[0].name == 'updated name' - assert len(task.artifacts[0].parts) == 1 - assert task.artifacts[0].parts[0].text == 'Updated' - assert task.artifacts[0].metadata['existing_key'] == 'existing_value' - - # Test appending parts to an existing artifact - artifact_with_parts = Artifact( - artifact_id='artifact-123', - parts=[Part(text='Part 2')], - metadata={'new_key': 'new_value'}, - ) - append_event_3 = TaskArtifactUpdateEvent( - artifact=artifact_with_parts, - append=True, - task_id='123', - context_id='123', - ) - append_artifact_to_task(task, append_event_3) - assert len(task.artifacts[0].parts) == 2 - assert task.artifacts[0].parts[0].text == 'Updated' - assert 
task.artifacts[0].parts[1].text == 'Part 2' - assert task.artifacts[0].metadata['existing_key'] == 'existing_value' - assert task.artifacts[0].metadata['new_key'] == 'new_value' - - # Test adding another new artifact - another_artifact_with_parts = Artifact( - artifact_id='new_artifact', - parts=[Part(text='new artifact Part 1')], - ) - append_event_4 = TaskArtifactUpdateEvent( - artifact=another_artifact_with_parts, - append=False, - task_id='123', - context_id='123', - ) - append_artifact_to_task(task, append_event_4) - assert len(task.artifacts) == 2 - assert task.artifacts[0].artifact_id == 'artifact-123' - assert task.artifacts[1].artifact_id == 'new_artifact' - assert len(task.artifacts[0].parts) == 2 - assert len(task.artifacts[1].parts) == 1 - - # Test appending part to a task that does not have a matching artifact - non_existing_artifact_with_parts = Artifact( - artifact_id='artifact-456', parts=[Part(text='Part 1')] - ) - append_event_5 = TaskArtifactUpdateEvent( - artifact=non_existing_artifact_with_parts, - append=True, - task_id='123', - context_id='123', - ) - append_artifact_to_task(task, append_event_5) - assert len(task.artifacts) == 2 - assert len(task.artifacts[0].parts) == 2 - assert len(task.artifacts[1].parts) == 1 - - -def build_text_artifact(text: str, artifact_id: str) -> Artifact: - return Artifact(artifact_id=artifact_id, parts=[Part(text=text)]) - - -# Test build_text_artifact -def test_build_text_artifact(): - artifact_id = 'text_artifact' - text = 'This is a sample text' - artifact = build_text_artifact(text, artifact_id) - - assert artifact.artifact_id == artifact_id - assert len(artifact.parts) == 1 - assert artifact.parts[0].text == text - - -def test_canonicalize_agent_card(): - """Test canonicalize_agent_card with defaults, optionals, and exceptions. - - - extensions is omitted as it's not set and optional. - - protocolVersion is included because it's always added by canonicalize_agent_card. - - signatures should be omitted. 
- """ - agent_card = AgentCard(**SAMPLE_AGENT_CARD) - expected_jcs = ( - '{"capabilities":{"pushNotifications":true},' - '"defaultInputModes":["text/plain"],"defaultOutputModes":["text/plain"],' - '"description":"A test agent","name":"Test Agent",' - '"skills":[{"description":"A test skill","id":"skill1","name":"Test Skill","tags":["test"]}],' - '"supportedInterfaces":[{"protocolBinding":"HTTP+JSON","url":"http://localhost"}],' - '"version":"1.0.0"}' - ) - result = _canonicalize_agent_card(agent_card) - assert result == expected_jcs - - -def test_canonicalize_agent_card_preserves_false_capability(): - """Regression #692: streaming=False must not be stripped from canonical JSON.""" - card = AgentCard( - **{ - **SAMPLE_AGENT_CARD, - 'capabilities': AgentCapabilities( - streaming=False, - push_notifications=True, - ), - } - ) - result = _canonicalize_agent_card(card) - assert '"streaming":false' in result - - -@pytest.mark.parametrize( - 'input_val', - [ - pytest.param({'a': ''}, id='empty-string'), - pytest.param({'a': []}, id='empty-list'), - pytest.param({'a': {}}, id='empty-dict'), - pytest.param({'a': {'b': []}}, id='nested-empty'), - pytest.param({'a': '', 'b': [], 'c': {}}, id='all-empties'), - pytest.param({'a': {'b': {'c': ''}}}, id='deeply-nested'), - ], -) -def test_clean_empty_removes_empties(input_val): - """_clean_empty removes empty strings, lists, and dicts recursively.""" - assert _clean_empty(input_val) is None - - -def test_clean_empty_top_level_list_becomes_none(): - """Top-level list that becomes empty after cleaning should return None.""" - assert _clean_empty(['', {}, []]) is None - - -@pytest.mark.parametrize( - 'input_val,expected', - [ - pytest.param({'retries': 0}, {'retries': 0}, id='int-zero'), - pytest.param({'enabled': False}, {'enabled': False}, id='bool-false'), - pytest.param({'score': 0.0}, {'score': 0.0}, id='float-zero'), - pytest.param([0, 1, 2], [0, 1, 2], id='zero-in-list'), - pytest.param([False, True], [False, True], 
id='false-in-list'), - pytest.param( - {'config': {'max_retries': 0, 'name': 'agent'}}, - {'config': {'max_retries': 0, 'name': 'agent'}}, - id='nested-zero', - ), - ], -) -def test_clean_empty_preserves_falsy_values(input_val, expected): - """_clean_empty preserves legitimate falsy values (0, False, 0.0).""" - assert _clean_empty(input_val) == expected - - -@pytest.mark.parametrize( - 'input_val,expected', - [ - pytest.param( - {'count': 0, 'label': '', 'items': []}, - {'count': 0}, - id='falsy-with-empties', - ), - pytest.param( - {'a': 0, 'b': 'hello', 'c': False, 'd': ''}, - {'a': 0, 'b': 'hello', 'c': False}, - id='mixed-types', - ), - pytest.param( - {'name': 'agent', 'retries': 0, 'tags': [], 'desc': ''}, - {'name': 'agent', 'retries': 0}, - id='realistic-mixed', - ), - ], -) -def test_clean_empty_mixed(input_val, expected): - """_clean_empty handles mixed empty and falsy values correctly.""" - assert _clean_empty(input_val) == expected - - -def test_clean_empty_does_not_mutate_input(): - """_clean_empty should not mutate the original input object.""" - original = {'a': '', 'b': 1, 'c': {'d': ''}} - original_copy = { - 'a': '', - 'b': 1, - 'c': {'d': ''}, - } - - _clean_empty(original) - - assert original == original_copy diff --git a/tests/utils/test_signing.py b/tests/utils/test_signing.py index 162f28e28..2a09943fe 100644 --- a/tests/utils/test_signing.py +++ b/tests/utils/test_signing.py @@ -178,3 +178,111 @@ def test_signer_and_verifier_asymmetric(sample_agent_card: AgentCard): ) with pytest.raises(signing.InvalidSignaturesError): verifier_wrong_key(signed_card) + + +def test_canonicalize_agent_card(sample_agent_card: AgentCard): + """Test canonicalize_agent_card with defaults, optionals, and exceptions. + + - extensions is omitted as it's not set and optional. + - protocolVersion is included because it's always added by canonicalize_agent_card. + - signatures should be omitted. 
+ """ + expected_jcs = ( + '{"capabilities":{"pushNotifications":true},' + '"defaultInputModes":["text/plain"],"defaultOutputModes":["text/plain"],' + '"description":"A test agent","name":"Test Agent",' + '"skills":[{"description":"A test skill","id":"skill1","name":"Test Skill","tags":["test"]}],' + '"supportedInterfaces":[{"protocolBinding":"HTTP+JSON","url":"http://localhost"}],' + '"version":"1.0.0"}' + ) + result = signing._canonicalize_agent_card(sample_agent_card) + assert result == expected_jcs + + +def test_canonicalize_agent_card_preserves_false_capability( + sample_agent_card: AgentCard, +): + """Regression #692: streaming=False must not be stripped from canonical JSON.""" + sample_agent_card.capabilities.streaming = False + result = signing._canonicalize_agent_card(sample_agent_card) + assert '"streaming":false' in result + + +@pytest.mark.parametrize( + 'input_val', + [ + pytest.param({'a': ''}, id='empty-string'), + pytest.param({'a': []}, id='empty-list'), + pytest.param({'a': {}}, id='empty-dict'), + pytest.param({'a': {'b': []}}, id='nested-empty'), + pytest.param({'a': '', 'b': [], 'c': {}}, id='all-empties'), + pytest.param({'a': {'b': {'c': ''}}}, id='deeply-nested'), + ], +) +def test_clean_empty_removes_empties(input_val): + """_clean_empty removes empty strings, lists, and dicts recursively.""" + assert signing._clean_empty(input_val) is None + + +def test_clean_empty_top_level_list_becomes_none(): + """Top-level list that becomes empty after cleaning should return None.""" + assert signing._clean_empty(['', {}, []]) is None + + +@pytest.mark.parametrize( + 'input_val,expected', + [ + pytest.param({'retries': 0}, {'retries': 0}, id='int-zero'), + pytest.param({'enabled': False}, {'enabled': False}, id='bool-false'), + pytest.param({'score': 0.0}, {'score': 0.0}, id='float-zero'), + pytest.param([0, 1, 2], [0, 1, 2], id='zero-in-list'), + pytest.param([False, True], [False, True], id='false-in-list'), + pytest.param( + {'config': 
{'max_retries': 0, 'name': 'agent'}}, + {'config': {'max_retries': 0, 'name': 'agent'}}, + id='nested-zero', + ), + ], +) +def test_clean_empty_preserves_falsy_values(input_val, expected): + """_clean_empty preserves legitimate falsy values (0, False, 0.0).""" + assert signing._clean_empty(input_val) == expected + + +@pytest.mark.parametrize( + 'input_val,expected', + [ + pytest.param( + {'count': 0, 'label': '', 'items': []}, + {'count': 0}, + id='falsy-with-empties', + ), + pytest.param( + {'a': 0, 'b': 'hello', 'c': False, 'd': ''}, + {'a': 0, 'b': 'hello', 'c': False}, + id='mixed-types', + ), + pytest.param( + {'name': 'agent', 'retries': 0, 'tags': [], 'desc': ''}, + {'name': 'agent', 'retries': 0}, + id='realistic-mixed', + ), + ], +) +def test_clean_empty_mixed(input_val, expected): + """_clean_empty handles mixed empty and falsy values correctly.""" + assert signing._clean_empty(input_val) == expected + + +def test_clean_empty_does_not_mutate_input(): + """_clean_empty should not mutate the original input object.""" + original = {'a': '', 'b': 1, 'c': {'d': ''}} + original_copy = { + 'a': '', + 'b': 1, + 'c': {'d': ''}, + } + + signing._clean_empty(original) + + assert original == original_copy diff --git a/tests/utils/test_helpers_validation.py b/tests/utils/test_version_validation.py similarity index 98% rename from tests/utils/test_helpers_validation.py rename to tests/utils/test_version_validation.py index 571f8ae9b..b2ae0594e 100644 --- a/tests/utils/test_helpers_validation.py +++ b/tests/utils/test_version_validation.py @@ -6,7 +6,7 @@ from a2a.server.context import ServerCallContext from a2a.utils import constants from a2a.utils.errors import VersionNotSupportedError -from a2a.utils.helpers import validate_version +from a2a.utils.version_validator import validate_version class TestHandler: From 934b59536756641076dc9ad407da4b891d774074 Mon Sep 17 00:00:00 2001 From: "Agent2Agent (A2A) Bot" Date: Fri, 17 Apr 2026 08:47:49 -0500 Subject: [PATCH 39/67] 
chore(1.0-dev): release 1.0.0-alpha.2 (#971) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit :robot: I have created a release *beep* *boop* --- ## [1.0.0-alpha.2](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.1...v1.0.0-alpha.2) (2026-04-17) ### ⚠ BREAKING CHANGES * clean helpers and utils folders structure ([#983](https://github.com/a2aproject/a2a-python/issues/983)) * Raise errors on invalid AgentExecutor behavior. ([#979](https://github.com/a2aproject/a2a-python/issues/979)) * extract developer helpers in helpers folder ([#978](https://github.com/a2aproject/a2a-python/issues/978)) ### Features * Raise errors on invalid AgentExecutor behavior. ([#979](https://github.com/a2aproject/a2a-python/issues/979)) ([f4a0bcd](https://github.com/a2aproject/a2a-python/commit/f4a0bcdf68107c95e6c0a5e6696e4a7d6e01a03f)) * **utils:** add `display_agent_card()` utility for human-readable AgentCard inspection ([#972](https://github.com/a2aproject/a2a-python/issues/972)) ([3468180](https://github.com/a2aproject/a2a-python/commit/3468180ac7396d453d99ce3e74cdd7f5a0afb5ab)) ### Bug Fixes * Don't generate empty metadata change events in VertexTaskStore ([#974](https://github.com/a2aproject/a2a-python/issues/974)) ([b58b03e](https://github.com/a2aproject/a2a-python/commit/b58b03ef58bd806db3accbe6dca8fc444a43bc18)), closes [#802](https://github.com/a2aproject/a2a-python/issues/802) * **extensions:** support both header names and remove "activation" concept ([#984](https://github.com/a2aproject/a2a-python/issues/984)) ([b8df210](https://github.com/a2aproject/a2a-python/commit/b8df210b00d0f249ca68f0d814191c4205e18b35)) ### Documentation * AgentExecutor interface documentation ([#976](https://github.com/a2aproject/a2a-python/issues/976)) ([d667e4f](https://github.com/a2aproject/a2a-python/commit/d667e4fa55e99225eb3c02e009b426a3bc2d449d)) * move `ai_learnings.md` to local-only and update `GEMINI.md` 
([#982](https://github.com/a2aproject/a2a-python/issues/982)) ([f6610fa](https://github.com/a2aproject/a2a-python/commit/f6610fa35e1f5fbc3e7e6cd9e29a5177a538eb4e)) ### Code Refactoring * clean helpers and utils folders structure ([#983](https://github.com/a2aproject/a2a-python/issues/983)) ([c87e87c](https://github.com/a2aproject/a2a-python/commit/c87e87c76c004c73c9d6b9bd8cacfd4e590598e6)) * extract developer helpers in helpers folder ([#978](https://github.com/a2aproject/a2a-python/issues/978)) ([5f3ea29](https://github.com/a2aproject/a2a-python/commit/5f3ea292389cf72a25a7cf2792caceb4af45f6da)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --- .release-please-manifest.json | 2 +- CHANGELOG.md | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 6415ed078..68a1b65c2 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1 +1 @@ -{".":"1.0.0-alpha.1"} +{".":"1.0.0-alpha.2"} diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e4715609..7e3297eac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,37 @@ # Changelog +## [1.0.0-alpha.2](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.1...v1.0.0-alpha.2) (2026-04-17) + + +### ⚠ BREAKING CHANGES + +* clean helpers and utils folders structure ([#983](https://github.com/a2aproject/a2a-python/issues/983)) +* Raise errors on invalid AgentExecutor behavior. ([#979](https://github.com/a2aproject/a2a-python/issues/979)) +* extract developer helpers in helpers folder ([#978](https://github.com/a2aproject/a2a-python/issues/978)) + +### Features + +* Raise errors on invalid AgentExecutor behavior. 
([#979](https://github.com/a2aproject/a2a-python/issues/979)) ([f4a0bcd](https://github.com/a2aproject/a2a-python/commit/f4a0bcdf68107c95e6c0a5e6696e4a7d6e01a03f)) +* **utils:** add `display_agent_card()` utility for human-readable AgentCard inspection ([#972](https://github.com/a2aproject/a2a-python/issues/972)) ([3468180](https://github.com/a2aproject/a2a-python/commit/3468180ac7396d453d99ce3e74cdd7f5a0afb5ab)) + + +### Bug Fixes + +* Don't generate empty metadata change events in VertexTaskStore ([#974](https://github.com/a2aproject/a2a-python/issues/974)) ([b58b03e](https://github.com/a2aproject/a2a-python/commit/b58b03ef58bd806db3accbe6dca8fc444a43bc18)), closes [#802](https://github.com/a2aproject/a2a-python/issues/802) +* **extensions:** support both header names and remove "activation" concept ([#984](https://github.com/a2aproject/a2a-python/issues/984)) ([b8df210](https://github.com/a2aproject/a2a-python/commit/b8df210b00d0f249ca68f0d814191c4205e18b35)) + + +### Documentation + +* AgentExecutor interface documentation ([#976](https://github.com/a2aproject/a2a-python/issues/976)) ([d667e4f](https://github.com/a2aproject/a2a-python/commit/d667e4fa55e99225eb3c02e009b426a3bc2d449d)) +* move `ai_learnings.md` to local-only and update `GEMINI.md` ([#982](https://github.com/a2aproject/a2a-python/issues/982)) ([f6610fa](https://github.com/a2aproject/a2a-python/commit/f6610fa35e1f5fbc3e7e6cd9e29a5177a538eb4e)) + + +### Code Refactoring + +* clean helpers and utils folders structure ([#983](https://github.com/a2aproject/a2a-python/issues/983)) ([c87e87c](https://github.com/a2aproject/a2a-python/commit/c87e87c76c004c73c9d6b9bd8cacfd4e590598e6)) +* extract developer helpers in helpers folder ([#978](https://github.com/a2aproject/a2a-python/issues/978)) ([5f3ea29](https://github.com/a2aproject/a2a-python/commit/5f3ea292389cf72a25a7cf2792caceb4af45f6da)) + ## [1.0.0-alpha.1](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.0...v1.0.0-alpha.1) (2026-04-10) 
From e1d0e7a72e2b9633be0b76c952f6c2e6fe11e3e5 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 17 Apr 2026 17:45:02 +0200 Subject: [PATCH 40/67] fix: update `with_a2a_extensions` to append instead of overwriting (#985) Existing extensions are kept, enables better modularity of service parameters updates by (for instance) multiple interceptors. --- src/a2a/client/service_parameters.py | 22 +++++----- tests/client/test_service_parameters.py | 53 +++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 9 deletions(-) create mode 100644 tests/client/test_service_parameters.py diff --git a/src/a2a/client/service_parameters.py b/src/a2a/client/service_parameters.py index cef250807..39fe79ce1 100644 --- a/src/a2a/client/service_parameters.py +++ b/src/a2a/client/service_parameters.py @@ -1,7 +1,10 @@ from collections.abc import Callable from typing import TypeAlias -from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, + get_requested_extensions, +) ServiceParameters: TypeAlias = dict[str, str] @@ -44,17 +47,18 @@ def create_from( def with_a2a_extensions(extensions: list[str]) -> ServiceParametersUpdate: - """Create a ServiceParametersUpdate that adds A2A extensions. + """Create a ServiceParametersUpdate that merges A2A extension URIs. - Args: - extensions: List of extension strings. - - Returns: - A function that updates ServiceParameters with the extensions header. + Unions the supplied URIs with any already present in the A2A-Extensions + parameter, deduplicating and emitting them in sorted order. Repeated + calls accumulate rather than overwrite. 
""" def update(parameters: ServiceParameters) -> None: - if extensions: - parameters[HTTP_EXTENSION_HEADER] = ','.join(extensions) + if not extensions: + return + existing = parameters.get(HTTP_EXTENSION_HEADER, '') + merged = sorted(get_requested_extensions([existing, *extensions])) + parameters[HTTP_EXTENSION_HEADER] = ','.join(merged) return update diff --git a/tests/client/test_service_parameters.py b/tests/client/test_service_parameters.py new file mode 100644 index 000000000..fbabd9719 --- /dev/null +++ b/tests/client/test_service_parameters.py @@ -0,0 +1,53 @@ +"""Tests for a2a.client.service_parameters module.""" + +from a2a.client.service_parameters import ( + ServiceParametersFactory, + with_a2a_extensions, +) +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +def test_with_a2a_extensions_merges_dedupes_and_sorts(): + """Repeated calls accumulate; duplicates collapse; output is sorted.""" + parameters = ServiceParametersFactory.create( + [ + with_a2a_extensions(['ext-c', 'ext-a']), + with_a2a_extensions(['ext-b', 'ext-a']), + ] + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c' + + +def test_with_a2a_extensions_merges_existing_header_value(): + """Pre-existing comma-separated header values are parsed and merged.""" + parameters = ServiceParametersFactory.create_from( + {HTTP_EXTENSION_HEADER: 'ext-a, ext-b'}, + [with_a2a_extensions(['ext-c'])], + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c' + + +def test_with_a2a_extensions_empty_is_noop(): + """An empty extensions list leaves the header untouched / absent.""" + parameters = ServiceParametersFactory.create( + [ + with_a2a_extensions(['ext-a']), + with_a2a_extensions([]), + ] + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a' + assert HTTP_EXTENSION_HEADER not in ServiceParametersFactory.create( + [with_a2a_extensions([])] + ) + + +def test_with_a2a_extensions_normalizes_input_strings(): + """Input strings are split on commas and stripped, 
like header values.""" + parameters = ServiceParametersFactory.create( + [with_a2a_extensions(['ext-a, ext-b', ' ext-c '])] + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c' From 25e2a7d620524a2325744b7a559662b6c6d24c48 Mon Sep 17 00:00:00 2001 From: "Agent2Agent (A2A) Bot" Date: Fri, 17 Apr 2026 10:47:32 -0500 Subject: [PATCH 41/67] chore(1.0-dev): release 1.0.0-alpha.3 (#986) :robot: I have created a release *beep* *boop* --- ## [1.0.0-alpha.3](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.2...v1.0.0-alpha.3) (2026-04-17) ### Bug Fixes * update `with_a2a_extensions` to append instead of overwriting ([#985](https://github.com/a2aproject/a2a-python/issues/985)) ([e1d0e7a](https://github.com/a2aproject/a2a-python/commit/e1d0e7a72e2b9633be0b76c952f6c2e6fe11e3e5)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --- .release-please-manifest.json | 2 +- CHANGELOG.md | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 68a1b65c2..160cadc01 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1 +1 @@ -{".":"1.0.0-alpha.2"} +{".":"1.0.0-alpha.3"} diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e3297eac..33ca3f9d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [1.0.0-alpha.3](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.2...v1.0.0-alpha.3) (2026-04-17) + + +### Bug Fixes + +* update `with_a2a_extensions` to append instead of overwriting ([#985](https://github.com/a2aproject/a2a-python/issues/985)) ([e1d0e7a](https://github.com/a2aproject/a2a-python/commit/e1d0e7a72e2b9633be0b76c952f6c2e6fe11e3e5)) + ## [1.0.0-alpha.2](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.1...v1.0.0-alpha.2) (2026-04-17) From 
d77cd68f5e69b0ffccaca5e3deab4c1a397cfe9c Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 20 Apr 2026 11:10:25 +0200 Subject: [PATCH 42/67] fix: rely on agent executor implementation for stream termination (#988) `active_task.py` already contains agent executor behavior validation, do not terminate the stream so that those errors can be raised, tests are updated to cover invalid behavior conditions. --- .../default_request_handler_v2.py | 17 +- .../test_default_request_handler_v2.py | 167 ++++++++++++++++++ 2 files changed, 180 insertions(+), 4 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index c0c6b5445..ecdc0cfef 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -271,11 +271,17 @@ async def on_message_send( # noqa: D102 ): self._validate_task_id_match(task_id, event.id) result = event + # DO break here as it's "return_immediately". + # AgentExecutor will continue to run in the background. break if isinstance(event, Message): result = event - break + # Do NOT break here as Message is supposed to be the only + # event in "Message-only" interaction. + # ActiveTask consumer (see active_task.py) validates the event + # stream and raises InvalidAgentResponseError if more events are + # pushed after a Message. if result is None: logger.debug('Missing result for task %s', request_context.task_id) @@ -311,15 +317,18 @@ async def on_message_send_stream( # noqa: D102 request=request_context, include_initial_task=False, ): + # Do NOT break here as we rely on AgentExecutor to yield control. 
+ # ActiveTask consumer (see active_task.py) validates the event + # stream and raises InvalidAgentResponseError on misbehaving agents: + # - an event after a Message + # - Message after entering task mode + # - an event after a terminal state if isinstance(event, Task): self._validate_task_id_match(task_id, event.id) yield apply_history_length(event, params.configuration) else: yield event - if isinstance(event, Message): - break - @validate_request_params @validate( lambda self: self._agent_card.capabilities.push_notifications, diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index fda1ab960..e35b8f720 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -28,6 +28,7 @@ ) from a2a.types import ( InternalError, + InvalidAgentResponseError, InvalidParamsError, TaskNotFoundError, PushNotificationNotSupportedError, @@ -1244,3 +1245,169 @@ async def test_on_message_send_with_push_notification(): push_store.set_info.assert_awaited_once_with( result.id, push_config, context ) + + +class MultipleMessagesAgentExecutor(AgentExecutor): + """Misbehaving agent that yields more than one Message.""" + + async def execute(self, context: RequestContext, event_queue: EventQueue): + await event_queue.enqueue_event( + new_text_message('first', role=Role.ROLE_AGENT) + ) + await event_queue.enqueue_event( + new_text_message('second', role=Role.ROLE_AGENT) + ) + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +class MessageAfterTaskEventAgentExecutor(AgentExecutor): + """Misbehaving agent that yields a task-mode event then a Message.""" + + async def execute(self, context: RequestContext, event_queue: EventQueue): + task = new_task_from_user_message(context.message) + await event_queue.enqueue_event(task) + updater = TaskUpdater(event_queue, task.id, 
task.context_id) + await updater.update_status(TaskState.TASK_STATE_WORKING) + await event_queue.enqueue_event( + new_text_message('stray message', role=Role.ROLE_AGENT) + ) + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +class TaskEventAfterMessageAgentExecutor(AgentExecutor): + """Misbehaving agent that yields a Message and then a task-mode event.""" + + async def execute(self, context: RequestContext, event_queue: EventQueue): + await event_queue.enqueue_event( + new_text_message('only message', role=Role.ROLE_AGENT) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=str(context.task_id or ''), + context_id=str(context.context_id or ''), + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +class EventAfterTerminalStateAgentExecutor(AgentExecutor): + """Misbehaving agent that yields an event after reaching a terminal state.""" + + async def execute(self, context: RequestContext, event_queue: EventQueue): + task = new_task_from_user_message(context.message) + await event_queue.enqueue_event(task) + updater = TaskUpdater(event_queue, task.id, task.context_id) + await updater.complete() + await event_queue.enqueue_event( + new_text_message('after terminal', role=Role.ROLE_AGENT) + ) + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_stream_rejects_multiple_messages(): + """Stream surfaces InvalidAgentResponseError when the agent yields a + second Message after the first one (see comment in on_message_send_stream).""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MultipleMessagesAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + 
message_id='msg_multi_stream', + parts=[Part(text='Hi')], + ) + ) + with pytest.raises(InvalidAgentResponseError, match='Multiple Message'): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_stream_rejects_message_after_task_event(): + """Stream surfaces InvalidAgentResponseError when the agent yields a + Message after entering task mode (see comment in on_message_send_stream).""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MessageAfterTaskEventAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_after_task_stream', + parts=[Part(text='Hi')], + ) + ) + with pytest.raises( + InvalidAgentResponseError, match='Message object in task mode' + ): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_stream_rejects_task_event_after_message(): + """Stream surfaces InvalidAgentResponseError when the agent yields a + task-mode event after a Message (see comment in on_message_send_stream).""" + request_handler = DefaultRequestHandlerV2( + agent_executor=TaskEventAfterMessageAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_then_task_stream', + parts=[Part(text='Hi')], + ) + ) + with pytest.raises(InvalidAgentResponseError, match='in message mode'): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_stream_rejects_event_after_terminal_state(): + """Stream surfaces 
InvalidAgentResponseError when the agent yields an event + after reaching a terminal state (see comment in on_message_send_stream).""" + request_handler = DefaultRequestHandlerV2( + agent_executor=EventAfterTerminalStateAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_after_terminal_stream', + parts=[Part(text='Hi')], + ) + ) + with pytest.raises( + InvalidAgentResponseError, match='Message object in task mode' + ): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass From 6d0080cccedf6a76dd4c6e898f34fc3a4f89e3ef Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 20 Apr 2026 11:48:34 +0200 Subject: [PATCH 43/67] test: add E2E smoke test for the sample (#991) 1. Fix gRPC setup. 1. Add E2E test with subprocess. --- samples/cli.py | 4 +- samples/hello_world_agent.py | 16 ++- tests/integration/test_samples_smoke.py | 134 ++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 tests/integration/test_samples_smoke.py diff --git a/samples/cli.py b/samples/cli.py index 935834dd3..beff26aa9 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -73,7 +73,9 @@ async def main() -> None: ) args = parser.parse_args() - config = ClientConfig() + config = ClientConfig( + grpc_channel_factory=grpc.aio.insecure_channel, + ) if args.transport: config.supported_protocol_bindings = [args.transport] diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 4c9e6f18a..a6e589ac0 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -1,3 +1,4 @@ +import argparse import asyncio import contextlib import logging @@ -257,5 +258,18 @@ async def serve( if __name__ == '__main__': logging.basicConfig(level=logging.INFO) + parser = argparse.ArgumentParser(description='Sample A2A agent server') + 
parser.add_argument('--host', default='127.0.0.1') + parser.add_argument('--port', type=int, default=41241) + parser.add_argument('--grpc-port', type=int, default=50051) + parser.add_argument('--compat-grpc-port', type=int, default=50052) + args = parser.parse_args() with contextlib.suppress(KeyboardInterrupt): - asyncio.run(serve()) + asyncio.run( + serve( + host=args.host, + port=args.port, + grpc_port=args.grpc_port, + compat_grpc_port=args.compat_grpc_port, + ) + ) diff --git a/tests/integration/test_samples_smoke.py b/tests/integration/test_samples_smoke.py new file mode 100644 index 000000000..fcb49a003 --- /dev/null +++ b/tests/integration/test_samples_smoke.py @@ -0,0 +1,134 @@ +"""End-to-end smoke test for `samples/hello_world_agent.py` and `samples/cli.py`. + +Boots the sample agent as a subprocess on free ports, then runs the sample CLI +against it once per supported transport, asserting the expected greeting reply +flows through. +""" + +from __future__ import annotations + +import asyncio +import socket +import sys + +from pathlib import Path +from typing import TYPE_CHECKING + +import httpx +import pytest +import pytest_asyncio + + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SAMPLES_DIR = REPO_ROOT / 'samples' +AGENT_SCRIPT = SAMPLES_DIR / 'hello_world_agent.py' +CLI_SCRIPT = SAMPLES_DIR / 'cli.py' + +STARTUP_TIMEOUT_S = 30.0 +CLI_TIMEOUT_S = 30.0 +EXPECTED_REPLY = 'Hello World! Nice to meet you!' 
+ + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(('127.0.0.1', 0)) + return sock.getsockname()[1] + + +async def _wait_for_agent_card(url: str) -> None: + deadline = asyncio.get_running_loop().time() + STARTUP_TIMEOUT_S + async with httpx.AsyncClient(timeout=2.0) as client: + while asyncio.get_running_loop().time() < deadline: + try: + response = await client.get(url) + if response.status_code == 200: + return + except httpx.RequestError: + pass + await asyncio.sleep(0.2) + raise TimeoutError(f'Agent did not become ready at {url}') + + +@pytest_asyncio.fixture +async def running_sample_agent() -> AsyncGenerator[str, None]: + """Start `hello_world_agent.py` as a subprocess on free ports.""" + host = '127.0.0.1' + http_port = _free_port() + grpc_port = _free_port() + compat_grpc_port = _free_port() + base_url = f'http://{host}:{http_port}' + + proc = await asyncio.create_subprocess_exec( + sys.executable, + str(AGENT_SCRIPT), + '--host', + host, + '--port', + str(http_port), + '--grpc-port', + str(grpc_port), + '--compat-grpc-port', + str(compat_grpc_port), + cwd=str(REPO_ROOT), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + + try: + await _wait_for_agent_card(f'{base_url}/.well-known/agent-card.json') + yield base_url + finally: + if proc.returncode is None: + proc.terminate() + try: + await asyncio.wait_for(proc.wait(), timeout=10.0) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + + +async def _run_cli(base_url: str, transport: str) -> str: + """Run `cli.py --transport `, send `hello`, return combined output.""" + proc = await asyncio.create_subprocess_exec( + sys.executable, + str(CLI_SCRIPT), + '--url', + base_url, + '--transport', + transport, + cwd=str(REPO_ROOT), + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + try: + stdout, _ = await asyncio.wait_for( + proc.communicate(b'hello\n/quit\n'), + 
timeout=CLI_TIMEOUT_S, + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + raise + output = stdout.decode('utf-8', errors='replace') + assert proc.returncode == 0, ( + f'CLI exited with {proc.returncode} for transport {transport!r}.\n' + f'Output:\n{output}' + ) + return output + + +@pytest.mark.asyncio +@pytest.mark.parametrize('transport', ['JSONRPC', 'HTTP+JSON', 'GRPC']) +async def test_cli_against_sample_agent( + running_sample_agent: str, transport: str +) -> None: + """The CLI should successfully exchange a greeting over each transport.""" + output = await _run_cli(running_sample_agent, transport) + + assert 'TASK_STATE_COMPLETED' in output, output + assert EXPECTED_REPLY in output, output From 1b6c3f1ae886f52c130d1e6803dd67892afb6d65 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 20 Apr 2026 13:09:26 +0200 Subject: [PATCH 44/67] ci: prepare release please for 1.0 (#995) - Remove custom config. - Remove 1.0-dev branch from triggers. --- .github/workflows/release-please.yml | 4 ---- .release-please-manifest.json | 1 - release-please-config.json | 9 --------- 3 files changed, 14 deletions(-) delete mode 100644 .release-please-manifest.json delete mode 100644 release-please-config.json diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 98ac3bf2d..8f4e5102f 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -2,7 +2,6 @@ on: push: branches: - main - - 1.0-dev permissions: contents: write @@ -17,6 +16,3 @@ jobs: - uses: googleapis/release-please-action@16a9c90856f42705d54a6fda1823352bdc62cf38 # v4 with: token: ${{ secrets.A2A_BOT_PAT }} - target-branch: ${{ github.ref_name }} - config-file: release-please-config.json - manifest-file: .release-please-manifest.json diff --git a/.release-please-manifest.json b/.release-please-manifest.json deleted file mode 100644 index 160cadc01..000000000 --- a/.release-please-manifest.json +++ /dev/null @@ -1 +0,0 @@ 
-{".":"1.0.0-alpha.3"} diff --git a/release-please-config.json b/release-please-config.json deleted file mode 100644 index 063b8435a..000000000 --- a/release-please-config.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "release-type": "python", - "prerelease": true, - "prerelease-type": "alpha", - "versioning": "prerelease", - "packages": { - ".": {} - } -} From 367536b2a4c911c8223d14fa27cd8c703aecfdf7 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 20 Apr 2026 13:13:21 +0200 Subject: [PATCH 45/67] ci: fix release please (#996) Set `release-type` as we removed configs in #995. --- .github/workflows/release-please.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 8f4e5102f..1668691e8 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -16,3 +16,4 @@ jobs: - uses: googleapis/release-please-action@16a9c90856f42705d54a6fda1823352bdc62cf38 # v4 with: token: ${{ secrets.A2A_BOT_PAT }} + release-type: python From 530ec37f4c4580095c2411e40740ca0186fd1240 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 20 Apr 2026 11:18:04 +0000 Subject: [PATCH 46/67] chore: release 1.0.0 Release-As: 1.0.0 From 7fce2ada1eb331e230925993758e8c7663da9a13 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 20 Apr 2026 16:03:46 +0200 Subject: [PATCH 47/67] chore!: remove Vertex AI Task Store integration (#999) It can not be used at the moment, will be readded after 1.0 release. 
--- README.md | 1 - pyproject.toml | 2 - src/a2a/contrib/__init__.py | 0 src/a2a/contrib/tasks/__init__.py | 0 .../contrib/tasks/vertex_task_converter.py | 312 --------- src/a2a/contrib/tasks/vertex_task_store.py | 263 ------- tests/contrib/__init__.py | 0 tests/contrib/tasks/__init__.py | 0 tests/contrib/tasks/fake_vertex_client.py | 143 ---- tests/contrib/tasks/run_vertex_tests.sh | 17 - .../tasks/test_vertex_task_converter.py | 485 ------------- tests/contrib/tasks/test_vertex_task_store.py | 654 ------------------ uv.lock | 310 +-------- 13 files changed, 1 insertion(+), 2186 deletions(-) delete mode 100644 src/a2a/contrib/__init__.py delete mode 100644 src/a2a/contrib/tasks/__init__.py delete mode 100644 src/a2a/contrib/tasks/vertex_task_converter.py delete mode 100644 src/a2a/contrib/tasks/vertex_task_store.py delete mode 100644 tests/contrib/__init__.py delete mode 100644 tests/contrib/tasks/__init__.py delete mode 100644 tests/contrib/tasks/fake_vertex_client.py delete mode 100755 tests/contrib/tasks/run_vertex_tests.sh delete mode 100644 tests/contrib/tasks/test_vertex_task_converter.py delete mode 100644 tests/contrib/tasks/test_vertex_task_store.py diff --git a/README.md b/README.md index 8ac1cfef4..b7a60fe3b 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,6 @@ Install the core SDK and any desired extras using your preferred package manager | **gRPC Support** | `uv add "a2a-sdk[grpc]"` | `pip install "a2a-sdk[grpc]"` | | **OpenTelemetry Tracing**| `uv add "a2a-sdk[telemetry]"` | `pip install "a2a-sdk[telemetry]"` | | **Encryption** | `uv add "a2a-sdk[encryption]"` | `pip install "a2a-sdk[encryption]"` | -| **Vertex AI Task Store** | `uv add "a2a-sdk[vertex]"` | `pip install "a2a-sdk[vertex]"` | | | | | | **Database Drivers** | | | | **PostgreSQL** | `uv add "a2a-sdk[postgresql]"` | `pip install "a2a-sdk[postgresql]"` | diff --git a/pyproject.toml b/pyproject.toml index 724749865..abaa9f1ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 
+43,6 @@ mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] signing = ["PyJWT>=2.0.0"] sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] db-cli = ["alembic>=1.14.0"] -vertex = ["google-cloud-aiplatform>=1.140.0"] sql = ["a2a-sdk[postgresql,mysql,sqlite]"] @@ -55,7 +54,6 @@ all = [ "a2a-sdk[telemetry]", "a2a-sdk[signing]", "a2a-sdk[db-cli]", - "a2a-sdk[vertex]", ] [project.urls] diff --git a/src/a2a/contrib/__init__.py b/src/a2a/contrib/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/a2a/contrib/tasks/__init__.py b/src/a2a/contrib/tasks/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/a2a/contrib/tasks/vertex_task_converter.py b/src/a2a/contrib/tasks/vertex_task_converter.py deleted file mode 100644 index 9441d2153..000000000 --- a/src/a2a/contrib/tasks/vertex_task_converter.py +++ /dev/null @@ -1,312 +0,0 @@ -try: - from google.genai import types as genai_types - from vertexai import types as vertexai_types -except ImportError as e: - raise ImportError( - 'vertex_task_converter requires vertexai. 
' - 'Install with: ' - "'pip install a2a-sdk[vertex]'" - ) from e - -import base64 -import json - -from dataclasses import dataclass -from typing import Any - -from a2a.compat.v0_3.types import ( - Artifact, - DataPart, - FilePart, - FileWithBytes, - FileWithUri, - Message, - Part, - Role, - Task, - TaskState, - TaskStatus, - TextPart, -) - - -_ORIGINAL_METADATA_KEY = 'originalMetadata' -_EXTENSIONS_KEY = 'extensions' -_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds' -_PART_METADATA_KEY = 'partMetadata' -_METADATA_VERSION_KEY = '__vertex_compat_v' -_METADATA_VERSION_NUMBER = 1.0 - -_DATA_PART_MIME_TYPE = 'application/x-a2a-datapart' - - -_TO_SDK_TASK_STATE = { - vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown, - vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted, - vertexai_types.A2aTaskState.WORKING: TaskState.working, - vertexai_types.A2aTaskState.COMPLETED: TaskState.completed, - vertexai_types.A2aTaskState.CANCELLED: TaskState.canceled, - vertexai_types.A2aTaskState.FAILED: TaskState.failed, - vertexai_types.A2aTaskState.REJECTED: TaskState.rejected, - vertexai_types.A2aTaskState.INPUT_REQUIRED: TaskState.input_required, - vertexai_types.A2aTaskState.AUTH_REQUIRED: TaskState.auth_required, -} - -_SDK_TO_STORED_TASK_STATE = {v: k for k, v in _TO_SDK_TASK_STATE.items()} - - -def to_sdk_task_state(stored_state: vertexai_types.A2aTaskState) -> TaskState: - """Converts a proto A2aTask.State to a TaskState enum.""" - return _TO_SDK_TASK_STATE.get(stored_state, TaskState.unknown) - - -def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState: - """Converts a TaskState enum to a proto A2aTask.State enum value.""" - return _SDK_TO_STORED_TASK_STATE.get( - task_state, vertexai_types.A2aTaskState.STATE_UNSPECIFIED - ) - - -def to_stored_metadata( - original_metadata: dict[str, Any] | None, - extensions: list[str] | None, - reference_task_ids: list[str] | None, - parts: list[Part], -) -> dict[str, Any]: - """Packs original 
metadata, extensions, and part types/metadata into a storage dictionary.""" - metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER} - if original_metadata: - metadata[_ORIGINAL_METADATA_KEY] = original_metadata - if extensions: - metadata[_EXTENSIONS_KEY] = extensions - if reference_task_ids: - metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids - - metadata[_PART_METADATA_KEY] = [part.root.metadata for part in parts] - - return metadata - - -@dataclass -class _UnpackedMetadata: - original_metadata: dict[str, Any] | None = None - extensions: list[str] | None = None - reference_task_ids: list[str] | None = None - part_metadata: list[dict[str, Any] | None] | None = None - - -def to_sdk_metadata( - stored_metadata: dict[str, Any] | None, -) -> _UnpackedMetadata: - """Unpacks metadata, extensions, and part types/metadata from a storage dictionary.""" - if not stored_metadata: - return _UnpackedMetadata() - - version = stored_metadata.get(_METADATA_VERSION_KEY) - if version is None: - return _UnpackedMetadata(original_metadata=stored_metadata) - if version > _METADATA_VERSION_NUMBER: - raise ValueError(f'Unsupported metadata version: {version}') - - return _UnpackedMetadata( - original_metadata=stored_metadata.get(_ORIGINAL_METADATA_KEY), - extensions=stored_metadata.get(_EXTENSIONS_KEY), - reference_task_ids=stored_metadata.get(_REFERENCE_TASK_IDS_KEY), - part_metadata=stored_metadata.get(_PART_METADATA_KEY), - ) - - -def to_stored_part(part: Part) -> genai_types.Part: - """Converts a SDK Part to a proto Part.""" - if isinstance(part.root, TextPart): - return genai_types.Part(text=part.root.text) - if isinstance(part.root, DataPart): - data_bytes = json.dumps(part.root.data).encode('utf-8') - return genai_types.Part( - inline_data=genai_types.Blob( - mime_type=_DATA_PART_MIME_TYPE, data=data_bytes - ) - ) - if isinstance(part.root, FilePart): - file_content = part.root.file - if isinstance(file_content, FileWithBytes): - decoded_bytes = 
base64.b64decode(file_content.bytes) - return genai_types.Part( - inline_data=genai_types.Blob( - mime_type=file_content.mime_type or '', data=decoded_bytes - ) - ) - if isinstance(file_content, FileWithUri): - return genai_types.Part( - file_data=genai_types.FileData( - mime_type=file_content.mime_type or '', - file_uri=file_content.uri, - ) - ) - raise ValueError(f'Unsupported part type: {type(part.root)}') - - -def to_sdk_part( - stored_part: genai_types.Part, - part_metadata: dict[str, Any] | None = None, -) -> Part: - """Converts a proto Part to a SDK Part.""" - if stored_part.text: - return Part( - root=TextPart(text=stored_part.text, metadata=part_metadata) - ) - if stored_part.inline_data: - mime_type = stored_part.inline_data.mime_type - if mime_type == _DATA_PART_MIME_TYPE: - data_dict = json.loads(stored_part.inline_data.data or b'{}') - return Part(root=DataPart(data=data_dict, metadata=part_metadata)) - - encoded_bytes = base64.b64encode( - stored_part.inline_data.data or b'' - ).decode('utf-8') - return Part( - root=FilePart( - file=FileWithBytes( - mime_type=mime_type, - bytes=encoded_bytes, - ), - metadata=part_metadata, - ) - ) - if stored_part.file_data and stored_part.file_data.file_uri: - return Part( - root=FilePart( - file=FileWithUri( - mime_type=stored_part.file_data.mime_type, - uri=stored_part.file_data.file_uri or '', - ), - metadata=part_metadata, - ) - ) - - raise ValueError(f'Unsupported part: {stored_part}') - - -def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact: - """Converts a SDK Artifact to a proto TaskArtifact.""" - return vertexai_types.TaskArtifact( - artifact_id=artifact.artifact_id, - display_name=artifact.name, - description=artifact.description, - parts=[to_stored_part(part) for part in artifact.parts], - metadata=to_stored_metadata( - original_metadata=artifact.metadata, - extensions=artifact.extensions, - reference_task_ids=None, - parts=artifact.parts, - ), - ) - - -def 
to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact: - """Converts a proto TaskArtifact to a SDK Artifact.""" - unpacked_meta = to_sdk_metadata(stored_artifact.metadata) - part_metadata_list = unpacked_meta.part_metadata or [] - - parts = [] - for i, part in enumerate(stored_artifact.parts or []): - meta: dict[str, Any] | None = None - if i < len(part_metadata_list): - meta = part_metadata_list[i] - parts.append(to_sdk_part(part, part_metadata=meta)) - - return Artifact( - artifact_id=stored_artifact.artifact_id, - name=stored_artifact.display_name, - description=stored_artifact.description, - extensions=unpacked_meta.extensions, - metadata=unpacked_meta.original_metadata, - parts=parts, - ) - - -def to_stored_message( - message: Message | None, -) -> vertexai_types.TaskMessage | None: - """Converts a SDK Message to a proto Message.""" - if not message: - return None - role = message.role.value if message.role else '' - return vertexai_types.TaskMessage( - message_id=message.message_id, - role=role, - parts=[to_stored_part(part) for part in message.parts], - metadata=to_stored_metadata( - original_metadata=message.metadata, - extensions=message.extensions, - reference_task_ids=message.reference_task_ids, - parts=message.parts, - ), - ) - - -def to_sdk_message( - stored_msg: vertexai_types.TaskMessage | None, -) -> Message | None: - """Converts a proto Message to a SDK Message.""" - if not stored_msg: - return None - unpacked_meta = to_sdk_metadata(stored_msg.metadata) - part_metadata_list = unpacked_meta.part_metadata or [] - - parts = [] - for i, part in enumerate(stored_msg.parts or []): - part_metadata: dict[str, Any] | None = None - if i < len(part_metadata_list): - part_metadata = part_metadata_list[i] - parts.append(to_sdk_part(part, part_metadata=part_metadata)) - - return Message( - message_id=stored_msg.message_id, - role=Role(stored_msg.role), - extensions=unpacked_meta.extensions, - 
reference_task_ids=unpacked_meta.reference_task_ids, - metadata=unpacked_meta.original_metadata, - parts=parts, - ) - - -def to_stored_task(task: Task) -> vertexai_types.A2aTask: - """Converts a SDK Task to a proto A2aTask.""" - return vertexai_types.A2aTask( - context_id=task.context_id, - metadata=task.metadata, - state=to_stored_task_state(task.status.state), - status_details=vertexai_types.TaskStatusDetails( - task_message=to_stored_message(task.status.message) - ) - if task.status.message - else None, - output=vertexai_types.TaskOutput( - artifacts=[ - to_stored_artifact(artifact) - for artifact in task.artifacts or [] - ] - ), - ) - - -def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task: - """Converts a proto A2aTask to a SDK Task.""" - msg: Message | None = None - if a2a_task.status_details and a2a_task.status_details.task_message: - msg = to_sdk_message(a2a_task.status_details.task_message) - - return Task( - id=a2a_task.name.split('/')[-1], - context_id=a2a_task.context_id, - status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg), - metadata=a2a_task.metadata or {}, - artifacts=[ - to_sdk_artifact(artifact) - for artifact in a2a_task.output.artifacts or [] - ] - if a2a_task.output - else [], - history=[], - ) diff --git a/src/a2a/contrib/tasks/vertex_task_store.py b/src/a2a/contrib/tasks/vertex_task_store.py deleted file mode 100644 index 602d5c6fd..000000000 --- a/src/a2a/contrib/tasks/vertex_task_store.py +++ /dev/null @@ -1,263 +0,0 @@ -import logging - - -try: - import vertexai - - from google.genai import errors as genai_errors - from vertexai import types as vertexai_types -except ImportError as e: - raise ImportError( - 'VertexTaskStore requires vertexai. 
' - 'Install with: ' - "'pip install a2a-sdk[vertex]'" - ) from e - -from a2a.compat.v0_3.conversions import to_compat_task, to_core_task -from a2a.compat.v0_3.types import Task as CompatTask -from a2a.contrib.tasks import vertex_task_converter -from a2a.server.context import ServerCallContext -from a2a.server.tasks.task_store import TaskStore -from a2a.types.a2a_pb2 import ListTasksRequest, ListTasksResponse, Task - - -logger = logging.getLogger(__name__) - - -class VertexTaskStore(TaskStore): - """Implementation of TaskStore using Vertex AI Agent Engine Task Store. - - Stores task objects in Vertex AI Agent Engine Task Store. - """ - - def __init__( - self, - client: vertexai.Client, # type: ignore - agent_engine_resource_id: str, - ) -> None: - """Initializes the VertexTaskStore. - - Args: - client: The Vertex AI client. - agent_engine_resource_id: The resource ID of the agent engine. - """ - self._client = client - self._agent_engine_resource_id = agent_engine_resource_id - - async def save(self, task: Task, context: ServerCallContext) -> None: - """Saves or updates a task in the store.""" - compat_task = to_compat_task(task) - previous_task = await self._get_stored_task(compat_task.id) - if previous_task is None: - await self._create(compat_task) - else: - await self._update(previous_task, compat_task) - - async def _create(self, sdk_task: CompatTask) -> None: - stored_task = vertex_task_converter.to_stored_task(sdk_task) - await self._client.aio.agent_engines.a2a_tasks.create( - name=self._agent_engine_resource_id, - a2a_task_id=sdk_task.id, - config=vertexai_types.CreateAgentEngineTaskConfig( - context_id=stored_task.context_id, - metadata=stored_task.metadata, - output=stored_task.output, - ), - ) - - def _get_status_change_event( - self, - previous_task: CompatTask, - task: CompatTask, - event_sequence_number: int, - ) -> vertexai_types.TaskEvent | None: - if task.status.state != previous_task.status.state: - return vertexai_types.TaskEvent( - 
event_data=vertexai_types.TaskEventData( - state_change=vertexai_types.TaskStateChange( - new_state=vertex_task_converter.to_stored_task_state( - task.status.state - ), - ), - ), - event_sequence_number=event_sequence_number, - ) - return None - - def _get_status_details_change_event( - self, - previous_task: CompatTask, - task: CompatTask, - event_sequence_number: int, - ) -> vertexai_types.TaskEvent | None: - if task.status.message != previous_task.status.message: - status_details = ( - vertexai_types.TaskStatusDetails( - task_message=vertex_task_converter.to_stored_message( - task.status.message - ) - ) - if task.status.message - else vertexai_types.TaskStatusDetails() - ) - return vertexai_types.TaskEvent( - event_data=vertexai_types.TaskEventData( - status_details_change=vertexai_types.TaskStatusDetailsChange( - new_task_status=status_details, - ), - ), - event_sequence_number=event_sequence_number, - ) - return None - - def _get_metadata_change_event( - self, - previous_task: CompatTask, - task: CompatTask, - event_sequence_number: int, - ) -> vertexai_types.TaskEvent | None: - # We generate metadata change events if the metadata was changed. - # We don't generate events if the metadata was changed from - # one empty value to another, e.g. {} to None. 
- if task.metadata != previous_task.metadata and ( - task.metadata or previous_task.metadata - ): - return vertexai_types.TaskEvent( - event_data=vertexai_types.TaskEventData( - metadata_change=vertexai_types.TaskMetadataChange( - new_metadata=task.metadata, - ) - ), - event_sequence_number=event_sequence_number, - ) - return None - - def _get_artifacts_change_event( - self, - previous_task: CompatTask, - task: CompatTask, - event_sequence_number: int, - ) -> vertexai_types.TaskEvent | None: - if task.artifacts != previous_task.artifacts: - task_artifact_change = vertexai_types.TaskArtifactChange() - event = vertexai_types.TaskEvent( - event_data=vertexai_types.TaskEventData( - output_change=vertexai_types.TaskOutputChange( - task_artifact_change=task_artifact_change - ) - ), - event_sequence_number=event_sequence_number, - ) - task_artifacts = ( - {artifact.artifact_id: artifact for artifact in task.artifacts} - if task.artifacts - else {} - ) - previous_task_artifacts = ( - { - artifact.artifact_id: artifact - for artifact in previous_task.artifacts - } - if previous_task.artifacts - else {} - ) - for artifact in previous_task_artifacts.values(): - if artifact.artifact_id not in task_artifacts: - if not task_artifact_change.deleted_artifact_ids: - task_artifact_change.deleted_artifact_ids = [] - task_artifact_change.deleted_artifact_ids.append( - artifact.artifact_id - ) - for artifact in task_artifacts.values(): - if artifact.artifact_id not in previous_task_artifacts: - if not task_artifact_change.added_artifacts: - task_artifact_change.added_artifacts = [] - task_artifact_change.added_artifacts.append( - vertex_task_converter.to_stored_artifact(artifact) - ) - elif artifact != previous_task_artifacts[artifact.artifact_id]: - if not task_artifact_change.updated_artifacts: - task_artifact_change.updated_artifacts = [] - task_artifact_change.updated_artifacts.append( - vertex_task_converter.to_stored_artifact(artifact) - ) - if task_artifact_change != 
vertexai_types.TaskArtifactChange(): - return event - return None - - async def _update( - self, previous_stored_task: vertexai_types.A2aTask, task: CompatTask - ) -> None: - previous_task = vertex_task_converter.to_sdk_task(previous_stored_task) - events = [] - event_sequence_number = previous_stored_task.next_event_sequence_number - - status_event = self._get_status_change_event( - previous_task, task, event_sequence_number - ) - if status_event: - events.append(status_event) - event_sequence_number += 1 - - status_details_event = self._get_status_details_change_event( - previous_task, task, event_sequence_number - ) - if status_details_event: - events.append(status_details_event) - event_sequence_number += 1 - - metadata_event = self._get_metadata_change_event( - previous_task, task, event_sequence_number - ) - if metadata_event: - events.append(metadata_event) - event_sequence_number += 1 - - artifacts_event = self._get_artifacts_change_event( - previous_task, task, event_sequence_number - ) - if artifacts_event: - events.append(artifacts_event) - event_sequence_number += 1 - - if not events: - return - await self._client.aio.agent_engines.a2a_tasks.events.append( - name=self._agent_engine_resource_id + '/a2aTasks/' + task.id, - task_events=events, - ) - - async def _get_stored_task( - self, task_id: str - ) -> vertexai_types.A2aTask | None: - try: - a2a_task = await self._client.aio.agent_engines.a2a_tasks.get( - name=self._agent_engine_resource_id + '/a2aTasks/' + task_id, - ) - except genai_errors.APIError as e: - if e.status == 'NOT_FOUND': - logger.debug('Task %s not found in store.', task_id) - return None - raise - return a2a_task - - async def get( - self, task_id: str, context: ServerCallContext - ) -> Task | None: - """Retrieves a task from the database by ID.""" - a2a_task = await self._get_stored_task(task_id) - if a2a_task is None: - return None - return to_core_task(vertex_task_converter.to_sdk_task(a2a_task)) - - async def list( - self, - params: 
ListTasksRequest, - context: ServerCallContext, - ) -> ListTasksResponse: - """Retrieves a list of tasks from the store.""" - raise NotImplementedError - - async def delete(self, task_id: str, context: ServerCallContext) -> None: - """The backend doesn't support deleting tasks, so this is not implemented.""" - raise NotImplementedError diff --git a/tests/contrib/__init__.py b/tests/contrib/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/contrib/tasks/__init__.py b/tests/contrib/tasks/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/contrib/tasks/fake_vertex_client.py b/tests/contrib/tasks/fake_vertex_client.py deleted file mode 100644 index 8a4a86903..000000000 --- a/tests/contrib/tasks/fake_vertex_client.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Fake Vertex AI Client implementations for testing.""" - -import copy - -from google.genai import errors as genai_errors -from vertexai import types as vertexai_types - - -class FakeAgentEnginesA2aTasksEventsClient: - def __init__(self, parent_client): - self.parent_client = parent_client - - async def append( - self, name: str, task_events: list[vertexai_types.TaskEvent] - ) -> None: - task = self.parent_client.tasks.get(name) - if not task: - raise genai_errors.APIError( - code=404, - response_json={ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'Task not found', - } - }, - ) - - task = copy.deepcopy(task) - if ( - not hasattr(task, 'next_event_sequence_number') - or not task.next_event_sequence_number - ): - task.next_event_sequence_number = 0 - - for event in task_events: - data = event.event_data - if getattr(data, 'state_change', None): - task.state = getattr(data.state_change, 'new_state', task.state) - if getattr(data, 'status_details_change', None): - task.status_details = getattr( - data.status_details_change, - 'new_task_status', - getattr(task, 'status_details', None), - ) - if getattr(data, 'metadata_change', None): - task.metadata = 
getattr( - data.metadata_change, 'new_metadata', task.metadata - ) - if getattr(data, 'output_change', None): - change = getattr( - data.output_change, 'task_artifact_change', None - ) - if not change: - continue - if not getattr(task, 'output', None): - task.output = vertexai_types.TaskOutput() - - current_artifacts = ( - list(task.output.artifacts) - if getattr(task.output, 'artifacts', None) - else [] - ) - - deleted_ids = getattr(change, 'deleted_artifact_ids', []) or [] - if deleted_ids: - current_artifacts = [ - a - for a in current_artifacts - if a.artifact_id not in deleted_ids - ] - - added = getattr(change, 'added_artifacts', []) or [] - if added: - current_artifacts.extend(added) - - updated = getattr(change, 'updated_artifacts', []) or [] - if updated: - updated_map = {a.artifact_id: a for a in updated} - current_artifacts = [ - updated_map.get(a.artifact_id, a) - for a in current_artifacts - ] - - try: - del task.output.artifacts[:] - task.output.artifacts.extend(current_artifacts) - except Exception: - task.output.artifacts = current_artifacts - task.next_event_sequence_number += 1 - - self.parent_client.tasks[name] = task - - -class FakeAgentEnginesA2aTasksClient: - def __init__(self): - self.tasks: dict[str, vertexai_types.A2aTask] = {} - self.events = FakeAgentEnginesA2aTasksEventsClient(self) - - async def create( - self, - name: str, - a2a_task_id: str, - config: vertexai_types.CreateAgentEngineTaskConfig, - ) -> vertexai_types.A2aTask: - full_name = f'{name}/a2aTasks/{a2a_task_id}' - task = vertexai_types.A2aTask( - name=full_name, - context_id=config.context_id, - metadata=config.metadata, - output=config.output, - state=vertexai_types.State.SUBMITTED, - ) - task.next_event_sequence_number = 1 - self.tasks[full_name] = task - return task - - async def get(self, name: str) -> vertexai_types.A2aTask: - if name not in self.tasks: - raise genai_errors.APIError( - code=404, - response_json={ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'Task 
not found', - } - }, - ) - return copy.deepcopy(self.tasks[name]) - - -class FakeAgentEnginesClient: - def __init__(self): - self.a2a_tasks = FakeAgentEnginesA2aTasksClient() - - -class FakeAioClient: - def __init__(self): - self.agent_engines = FakeAgentEnginesClient() - - -class FakeVertexClient: - def __init__(self): - self.aio = FakeAioClient() diff --git a/tests/contrib/tasks/run_vertex_tests.sh b/tests/contrib/tasks/run_vertex_tests.sh deleted file mode 100755 index 12c0395d2..000000000 --- a/tests/contrib/tasks/run_vertex_tests.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash -set -e - -for var in VERTEX_PROJECT VERTEX_LOCATION VERTEX_BASE_URL VERTEX_API_VERSION; do - if [ -z "${!var}" ]; then - echo "Error: Environment variable $var is undefined or empty." >&2 - exit 1 - fi -done - -PYTEST_ARGS=("$@") - -echo "Running Vertex tests..." - -cd $(git rev-parse --show-toplevel) - -uv run pytest -v "${PYTEST_ARGS[@]}" tests/contrib/tasks/test_vertex_task_store.py tests/contrib/tasks/test_vertex_task_converter.py diff --git a/tests/contrib/tasks/test_vertex_task_converter.py b/tests/contrib/tasks/test_vertex_task_converter.py deleted file mode 100644 index 3d260c599..000000000 --- a/tests/contrib/tasks/test_vertex_task_converter.py +++ /dev/null @@ -1,485 +0,0 @@ -import base64 - -import pytest - - -pytest.importorskip( - 'vertexai', reason='Vertex Task Converter tests require vertexai' -) -from vertexai import types as vertexai_types -from google.genai import types as genai_types -from a2a.contrib.tasks.vertex_task_converter import ( - _DATA_PART_MIME_TYPE, - to_sdk_artifact, - to_sdk_message, - to_sdk_part, - to_sdk_task, - to_sdk_task_state, - to_stored_artifact, - to_stored_message, - to_stored_part, - to_stored_task, - to_stored_task_state, -) -from a2a.compat.v0_3.types import ( - Artifact, - DataPart, - FilePart, - FileWithBytes, - FileWithUri, - Message, - Part, - Role, - Task, - TaskState, - TaskStatus, - TextPart, -) - - -def test_to_sdk_task_state() -> 
None: - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.STATE_UNSPECIFIED) - == TaskState.unknown - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.SUBMITTED) - == TaskState.submitted - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.WORKING) - == TaskState.working - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.COMPLETED) - == TaskState.completed - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.CANCELLED) - == TaskState.canceled - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.FAILED) - == TaskState.failed - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.REJECTED) - == TaskState.rejected - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.INPUT_REQUIRED) - == TaskState.input_required - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.AUTH_REQUIRED) - == TaskState.auth_required - ) - assert to_sdk_task_state(999) == TaskState.unknown # type: ignore - - -def test_to_stored_task_state() -> None: - assert ( - to_stored_task_state(TaskState.unknown) - == vertexai_types.A2aTaskState.STATE_UNSPECIFIED - ) - assert ( - to_stored_task_state(TaskState.submitted) - == vertexai_types.A2aTaskState.SUBMITTED - ) - assert ( - to_stored_task_state(TaskState.working) - == vertexai_types.A2aTaskState.WORKING - ) - assert ( - to_stored_task_state(TaskState.completed) - == vertexai_types.A2aTaskState.COMPLETED - ) - assert ( - to_stored_task_state(TaskState.canceled) - == vertexai_types.A2aTaskState.CANCELLED - ) - assert ( - to_stored_task_state(TaskState.failed) - == vertexai_types.A2aTaskState.FAILED - ) - assert ( - to_stored_task_state(TaskState.rejected) - == vertexai_types.A2aTaskState.REJECTED - ) - assert ( - to_stored_task_state(TaskState.input_required) - == vertexai_types.A2aTaskState.INPUT_REQUIRED - ) - assert ( - to_stored_task_state(TaskState.auth_required) - == vertexai_types.A2aTaskState.AUTH_REQUIRED - ) - - -def test_to_stored_part_text() -> None: - 
sdk_part = Part(root=TextPart(text='hello world')) - stored_part = to_stored_part(sdk_part) - assert stored_part.text == 'hello world' - assert not stored_part.inline_data - assert not stored_part.file_data - - -def test_to_stored_part_data() -> None: - sdk_part = Part(root=DataPart(data={'key': 'value'})) - stored_part = to_stored_part(sdk_part) - assert stored_part.inline_data is not None - assert stored_part.inline_data.mime_type == _DATA_PART_MIME_TYPE - assert stored_part.inline_data.data == b'{"key": "value"}' - - -def test_to_stored_part_file_bytes() -> None: - encoded_b64 = base64.b64encode(b'test data').decode('utf-8') - sdk_part = Part( - root=FilePart( - file=FileWithBytes( - bytes=encoded_b64, - mime_type='text/plain', - ) - ) - ) - stored_part = to_stored_part(sdk_part) - assert stored_part.inline_data is not None - assert stored_part.inline_data.mime_type == 'text/plain' - assert stored_part.inline_data.data == b'test data' - - -def test_to_stored_part_file_uri() -> None: - sdk_part = Part( - root=FilePart( - file=FileWithUri( - uri='gs://test-bucket/file.txt', - mime_type='text/plain', - ) - ) - ) - stored_part = to_stored_part(sdk_part) - assert stored_part.file_data is not None - assert stored_part.file_data.mime_type == 'text/plain' - assert stored_part.file_data.file_uri == 'gs://test-bucket/file.txt' - - -def test_to_stored_part_unsupported() -> None: - class BadPart: - pass - - part = Part(root=TextPart(text='t')) - part.root = BadPart() # type: ignore - with pytest.raises(ValueError, match='Unsupported part type'): - to_stored_part(part) - - -def test_to_sdk_part_text() -> None: - stored_part = genai_types.Part(text='hello back') - sdk_part = to_sdk_part(stored_part) - assert isinstance(sdk_part.root, TextPart) - assert sdk_part.root.text == 'hello back' - - -def test_to_sdk_part_inline_data() -> None: - stored_part = genai_types.Part( - inline_data=genai_types.Blob( - mime_type='application/json', - data=b'{"key": "val"}', - ) - ) - sdk_part 
= to_sdk_part(stored_part) - assert isinstance(sdk_part.root, FilePart) - assert isinstance(sdk_part.root.file, FileWithBytes) - expected_b64 = base64.b64encode(b'{"key": "val"}').decode('utf-8') - assert sdk_part.root.file.mime_type == 'application/json' - assert sdk_part.root.file.bytes == expected_b64 - - -def test_to_sdk_part_inline_data_datapart() -> None: - stored_part = genai_types.Part( - inline_data=genai_types.Blob( - mime_type=_DATA_PART_MIME_TYPE, - data=b'{"key": "val"}', - ) - ) - sdk_part = to_sdk_part(stored_part) - assert isinstance(sdk_part.root, DataPart) - assert sdk_part.root.data == {'key': 'val'} - - -def test_to_sdk_part_file_data() -> None: - stored_part = genai_types.Part( - file_data=genai_types.FileData( - mime_type='image/jpeg', - file_uri='gs://bucket/image.jpg', - ) - ) - sdk_part = to_sdk_part(stored_part) - assert isinstance(sdk_part.root, FilePart) - assert isinstance(sdk_part.root.file, FileWithUri) - assert sdk_part.root.file.mime_type == 'image/jpeg' - assert sdk_part.root.file.uri == 'gs://bucket/image.jpg' - - -def test_to_sdk_part_unsupported() -> None: - stored_part = genai_types.Part() - with pytest.raises(ValueError, match='Unsupported part:'): - to_sdk_part(stored_part) - - -def test_to_stored_artifact() -> None: - sdk_artifact = Artifact( - artifact_id='art-123', - parts=[Part(root=TextPart(text='part_1'))], - ) - stored_artifact = to_stored_artifact(sdk_artifact) - assert stored_artifact.artifact_id == 'art-123' - assert len(stored_artifact.parts) == 1 - assert stored_artifact.parts[0].text == 'part_1' - - -def test_to_sdk_artifact() -> None: - stored_artifact = vertexai_types.TaskArtifact( - artifact_id='art-456', - parts=[genai_types.Part(text='part_2')], - ) - sdk_artifact = to_sdk_artifact(stored_artifact) - assert sdk_artifact.artifact_id == 'art-456' - assert len(sdk_artifact.parts) == 1 - assert isinstance(sdk_artifact.parts[0].root, TextPart) - assert sdk_artifact.parts[0].root.text == 'part_2' - - -def 
test_to_stored_task() -> None: - sdk_task = Task( - id='task-1', - context_id='ctx-1', - status=TaskStatus(state=TaskState.working), - metadata={'foo': 'bar'}, - artifacts=[ - Artifact( - artifact_id='art-1', - parts=[Part(root=TextPart(text='stuff'))], - ) - ], - history=[], - ) - stored_task = to_stored_task(sdk_task) - assert stored_task.context_id == 'ctx-1' - assert stored_task.metadata == {'foo': 'bar'} - assert stored_task.state == vertexai_types.A2aTaskState.WORKING - assert stored_task.output is not None - assert stored_task.output.artifacts is not None - assert len(stored_task.output.artifacts) == 1 - assert stored_task.output.artifacts[0].artifact_id == 'art-1' - - -def test_to_sdk_task() -> None: - stored_task = vertexai_types.A2aTask( - name='projects/123/locations/us-central1/agentEngines/456/tasks/task-2', - context_id='ctx-2', - state=vertexai_types.A2aTaskState.COMPLETED, - metadata={'a': 'b'}, - output=vertexai_types.TaskOutput( - artifacts=[ - vertexai_types.TaskArtifact( - artifact_id='art-2', - parts=[genai_types.Part(text='result')], - ) - ] - ), - ) - sdk_task = to_sdk_task(stored_task) - assert sdk_task.id == 'task-2' - assert sdk_task.context_id == 'ctx-2' - assert sdk_task.status.state == TaskState.completed - assert sdk_task.metadata == {'a': 'b'} - assert sdk_task.history == [] - assert sdk_task.artifacts is not None - assert len(sdk_task.artifacts) == 1 - assert sdk_task.artifacts[0].artifact_id == 'art-2' - assert isinstance(sdk_task.artifacts[0].parts[0].root, TextPart) - assert sdk_task.artifacts[0].parts[0].root.text == 'result' - - -def test_to_sdk_task_no_output() -> None: - stored_task = vertexai_types.A2aTask( - name='tasks/task-3', - context_id='ctx-3', - state=vertexai_types.A2aTaskState.SUBMITTED, - metadata=None, - ) - sdk_task = to_sdk_task(stored_task) - assert sdk_task.id == 'task-3' - assert sdk_task.metadata == {} - assert sdk_task.artifacts == [] - - -def test_sdk_task_state_conversion_round_trip() -> None: - for state 
in TaskState: - stored_state = to_stored_task_state(state) - round_trip_state = to_sdk_task_state(stored_state) - assert round_trip_state == state - - -def test_sdk_part_text_conversion_round_trip() -> None: - sdk_part = Part(root=TextPart(text='hello world')) - stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) - assert round_trip_sdk_part == sdk_part - - -def test_sdk_part_data_conversion_round_trip() -> None: - sdk_part = Part(root=DataPart(data={'key': 'value'})) - stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part, part_metadata=None) - - assert round_trip_sdk_part == sdk_part - - -def test_sdk_part_file_bytes_conversion_round_trip() -> None: - encoded_b64 = base64.b64encode(b'test data').decode('utf-8') - sdk_part = Part( - root=FilePart( - file=FileWithBytes( - bytes=encoded_b64, - mime_type='text/plain', - ) - ) - ) - stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) - assert round_trip_sdk_part == sdk_part - - -def test_sdk_part_file_uri_conversion_round_trip() -> None: - sdk_part = Part( - root=FilePart( - file=FileWithUri( - uri='gs://test-bucket/file.txt', - mime_type='text/plain', - ) - ) - ) - stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) - assert round_trip_sdk_part == sdk_part - - -def test_sdk_task_conversion_round_trip() -> None: - sdk_task = Task( - id='task-1', - context_id='ctx-1', - status=TaskStatus(state=TaskState.working), - metadata={'foo': 'bar'}, - artifacts=[ - Artifact( - artifact_id='art-1', - parts=[Part(root=TextPart(text='stuff'))], - ) - ], - history=[ - # History is not yet implemented and later will be supported - # via events. - ], - ) - stored_task = to_stored_task(sdk_task) - # Simulate Vertex storing the ID in the fully qualified resource name. - # The task ID during creation gets appended to the parent name. 
- stored_task.name = ( - f'projects/p/locations/l/agentEngines/e/tasks/{sdk_task.id}' - ) - - round_trip_sdk_task = to_sdk_task(stored_task) - - assert round_trip_sdk_task.id == sdk_task.id - assert round_trip_sdk_task.context_id == sdk_task.context_id - assert round_trip_sdk_task.status == sdk_task.status - assert round_trip_sdk_task.metadata == sdk_task.metadata - assert round_trip_sdk_task.artifacts == sdk_task.artifacts - assert round_trip_sdk_task.history == [] - - -def test_stored_artifact_conversion_round_trip() -> None: - """Test converting an Artifact to TaskArtifact and back restores everything.""" - original_artifact = Artifact( - artifact_id='art123', - name='My cool artifact', - description='A very interesting description', - extensions=['ext1', 'ext2'], - metadata={'custom': 'value'}, - parts=[ - Part( - root=TextPart( - text='hello', metadata={'part_meta': 'hello_meta'} - ) - ), - Part(root=DataPart(data={'foo': 'bar'})), # no metadata - ], - ) - - stored = to_stored_artifact(original_artifact) - assert isinstance(stored, vertexai_types.TaskArtifact) - - # ensure it was populated correctly - assert stored.display_name == 'My cool artifact' - assert stored.description == 'A very interesting description' - assert stored.metadata['__vertex_compat_v'] == 1.0 - - restored_artifact = to_sdk_artifact(stored) - - assert restored_artifact.artifact_id == original_artifact.artifact_id - assert restored_artifact.name == original_artifact.name - assert restored_artifact.description == original_artifact.description - assert restored_artifact.extensions == original_artifact.extensions - assert restored_artifact.metadata == original_artifact.metadata - - assert len(restored_artifact.parts) == 2 - assert isinstance(restored_artifact.parts[0].root, TextPart) - assert restored_artifact.parts[0].root.text == 'hello' - assert restored_artifact.parts[0].root.metadata == { - 'part_meta': 'hello_meta' - } - - assert isinstance(restored_artifact.parts[1].root, DataPart) - 
assert restored_artifact.parts[1].root.data == {'foo': 'bar'} - assert restored_artifact.parts[1].root.metadata is None - - -def test_stored_message_conversion_round_trip() -> None: - """Test converting a Message to TaskMessage and back restores everything.""" - original_message = Message( - message_id='msg456', - role=Role.agent, - reference_task_ids=['tsk2', 'tsk3'], - extensions=['ext_msg'], - metadata={'msg_meta': 42}, - parts=[ - Part(root=TextPart(text='message text')), - ], - ) - - stored = to_stored_message(original_message) - assert stored is not None - assert isinstance(stored, vertexai_types.TaskMessage) - - assert stored.message_id == 'msg456' - assert stored.role == 'agent' - assert stored.metadata['__vertex_compat_v'] == 1.0 - - restored_message = to_sdk_message(stored) - assert restored_message is not None - - assert restored_message.message_id == original_message.message_id - assert restored_message.role == original_message.role - assert ( - restored_message.reference_task_ids - == original_message.reference_task_ids - ) - assert restored_message.extensions == original_message.extensions - assert restored_message.metadata == original_message.metadata - - assert len(restored_message.parts) == 1 - assert isinstance(restored_message.parts[0].root, TextPart) - assert restored_message.parts[0].root.text == 'message text' - assert restored_message.parts[0].root.metadata is None diff --git a/tests/contrib/tasks/test_vertex_task_store.py b/tests/contrib/tasks/test_vertex_task_store.py deleted file mode 100644 index c77493022..000000000 --- a/tests/contrib/tasks/test_vertex_task_store.py +++ /dev/null @@ -1,654 +0,0 @@ -""" -Tests for the VertexTaskStore. - -These tests can be run with a real or fake Vertex AI Agent Engine as a backend. -The real ones are skipped by default unless the necessary environment variables\ -are set, which prevents them from failing in GitHub Actions. 
- -To run these tests locally, you can use the provided script: - ./run_vertex_tests.sh - -The following environment variables are required for the real backend: - VERTEX_PROJECT="your-project" \ - VERTEX_LOCATION="your-location" \ - VERTEX_BASE_URL="your-base-url" \ - VERTEX_API_VERSION="your-api-version" \ -""" - -import os - -from collections.abc import AsyncGenerator - -import pytest -import pytest_asyncio - -from .fake_vertex_client import FakeVertexClient - - -# Skip the entire test module if vertexai is not installed -pytest.importorskip( - 'vertexai', reason='Vertex Task Store tests require vertexai' -) -import vertexai - - -# Skip the real backend tests if required environment variables are not set -missing_env_vars = not all( - os.environ.get(var) - for var in [ - 'VERTEX_PROJECT', - 'VERTEX_LOCATION', - 'VERTEX_BASE_URL', - 'VERTEX_API_VERSION', - ] -) - - -@pytest.fixture( - scope='module', - params=[ - 'fake', - pytest.param( - 'real', - marks=pytest.mark.skipif( - missing_env_vars, - reason='Missing required environment variables for real Vertex Task Store.', - ), - ), - ], -) -def backend_type(request) -> str: - return request.param - - -from a2a.contrib.tasks.vertex_task_store import VertexTaskStore -from a2a.server.context import ServerCallContext -from a2a.types.a2a_pb2 import ( - Artifact, - Message, - Part, - Role, - Task, - TaskState, - TaskStatus, -) - - -# Minimal Task object for testing -MINIMAL_TASK_OBJ = Task( - id='task-abc', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), -) -MINIMAL_TASK_OBJ.metadata['test_key'] = 'test_value' - - -from collections.abc import Generator - - -@pytest.fixture(scope='module') -def agent_engine_resource_id(backend_type: str) -> Generator[str, None, None]: - """ - Module-scoped fixture that creates and deletes a single Agent Engine - for all the tests. For fake backend, it yields a mock resource. 
- """ - if backend_type == 'fake': - yield 'projects/mock-project/locations/mock-location/agentEngines/mock-engine' - return - - project = os.environ.get('VERTEX_PROJECT') - location = os.environ.get('VERTEX_LOCATION') - base_url = os.environ.get('VERTEX_BASE_URL') - - client = vertexai.Client(project=project, location=location) - client._api_client._http_options.base_url = base_url - - agent_engine = client.agent_engines.create() - yield agent_engine.api_resource.name - agent_engine.delete() - - -@pytest_asyncio.fixture -async def vertex_store( - backend_type: str, - agent_engine_resource_id: str, -) -> AsyncGenerator[VertexTaskStore, None]: - """ - Function-scoped fixture providing a fresh VertexTaskStore per test, - reusing the module-scoped engine. Uses fake client for 'fake' backend. - """ - if backend_type == 'fake': - client = FakeVertexClient() - else: - project = os.environ.get('VERTEX_PROJECT') - location = os.environ.get('VERTEX_LOCATION') - base_url = os.environ.get('VERTEX_BASE_URL') - api_version = os.environ.get('VERTEX_API_VERSION') - - client = vertexai.Client(project=project, location=location) - client._api_client._http_options.base_url = base_url - client._api_client._http_options.api_version = api_version - - store = VertexTaskStore( - client=client, # type: ignore - agent_engine_resource_id=agent_engine_resource_id, - ) - yield store - - -@pytest.mark.asyncio -async def test_save_task(vertex_store: VertexTaskStore) -> None: - """Test saving a task to the VertexTaskStore.""" - # Ensure unique ID for parameterized tests if needed, or rely on table isolation - task_to_save = Task() - task_to_save.CopyFrom(MINIMAL_TASK_OBJ) - task_to_save.id = 'save-test-task-2' - await vertex_store.save(task_to_save, ServerCallContext()) - - retrieved_task = await vertex_store.get( - task_to_save.id, ServerCallContext() - ) - assert retrieved_task is not None - assert retrieved_task.id == task_to_save.id - - assert retrieved_task == task_to_save - - 
-@pytest.mark.asyncio -async def test_get_task(vertex_store: VertexTaskStore) -> None: - """Test retrieving a task from the VertexTaskStore.""" - task_id = 'get-test-task-1' - task_to_save = Task() - task_to_save.CopyFrom(MINIMAL_TASK_OBJ) - task_to_save.id = task_id - await vertex_store.save(task_to_save, ServerCallContext()) - - retrieved_task = await vertex_store.get( - task_to_save.id, ServerCallContext() - ) - assert retrieved_task is not None - assert retrieved_task.id == task_to_save.id - assert retrieved_task.context_id == task_to_save.context_id - assert retrieved_task.status.state == TaskState.TASK_STATE_SUBMITTED - - -@pytest.mark.asyncio -async def test_get_nonexistent_task( - vertex_store: VertexTaskStore, -) -> None: - """Test retrieving a nonexistent task.""" - retrieved_task = await vertex_store.get( - 'nonexistent-task-id', ServerCallContext() - ) - assert retrieved_task is None - - -@pytest.mark.asyncio -async def test_save_and_get_detailed_task( - vertex_store: VertexTaskStore, -) -> None: - """Test saving and retrieving a task with more fields populated.""" - task_id = 'detailed-task-test-vertex' - test_task = Task( - id=task_id, - context_id='test-session-1', - status=TaskStatus( - state=TaskState.TASK_STATE_SUBMITTED, - ), - artifacts=[ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ) - ], - ) - test_task.metadata['key1'] = 'value1' - test_task.metadata['key2'] = 123 - - await vertex_store.save(test_task, ServerCallContext()) - retrieved_task = await vertex_store.get(test_task.id, ServerCallContext()) - - assert retrieved_task is not None - assert retrieved_task.id == test_task.id - assert retrieved_task.context_id == test_task.context_id - assert retrieved_task.status.state == TaskState.TASK_STATE_SUBMITTED - assert retrieved_task.metadata['key1'] == 'value1' - assert retrieved_task.metadata['key2'] == 123 - assert retrieved_task.artifacts == test_task.artifacts - - -@pytest.mark.asyncio -async def 
test_update_task_status_and_metadata( - vertex_store: VertexTaskStore, -) -> None: - """Test updating an existing task.""" - task_id = 'update-test-task-1' - original_task = Task( - id=task_id, - context_id='session-update', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - artifacts=[], - history=[], - ) - await vertex_store.save(original_task, ServerCallContext()) - - retrieved_before_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_before_update is not None - assert ( - retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED - ) - assert retrieved_before_update.metadata == {} - - updated_task = Task() - updated_task.CopyFrom(original_task) - updated_task.status.state = TaskState.TASK_STATE_COMPLETED - updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') - updated_task.metadata.update({'update_key': 'update_value'}) - - await vertex_store.save(updated_task, ServerCallContext()) - - retrieved_after_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_after_update is not None - assert retrieved_after_update.status.state == TaskState.TASK_STATE_COMPLETED - assert retrieved_after_update.metadata == {'update_key': 'update_value'} - - -@pytest.mark.asyncio -async def test_update_task_add_artifact(vertex_store: VertexTaskStore) -> None: - """Test updating an existing task by adding an artifact.""" - task_id = 'update-test-task-2' - original_task = Task( - id=task_id, - context_id='session-update', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - artifacts=[ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ) - ], - history=[], - ) - await vertex_store.save(original_task, ServerCallContext()) - - retrieved_before_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_before_update is not None - assert ( - retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED - ) - assert 
retrieved_before_update.metadata == {} - - updated_task = Task() - updated_task.CopyFrom(original_task) - updated_task.status.state = TaskState.TASK_STATE_WORKING - updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') - - updated_task.artifacts.append( - Artifact( - artifact_id='artifact-2', - parts=[Part(text='world')], - ) - ) - - await vertex_store.save(updated_task, ServerCallContext()) - - retrieved_after_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_after_update is not None - assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING - - assert retrieved_after_update.artifacts == [ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ), - Artifact( - artifact_id='artifact-2', - parts=[Part(text='world')], - ), - ] - - -@pytest.mark.asyncio -async def test_update_task_update_artifact( - vertex_store: VertexTaskStore, -) -> None: - """Test updating an existing task by changing an artifact.""" - task_id = 'update-test-task-3' - original_task = Task( - id=task_id, - context_id='session-update', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - artifacts=[ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ), - Artifact( - artifact_id='artifact-2', - parts=[Part(text='world')], - ), - ], - history=[], - ) - await vertex_store.save(original_task, ServerCallContext()) - - retrieved_before_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_before_update is not None - assert ( - retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED - ) - assert retrieved_before_update.metadata == {} - - updated_task = Task() - updated_task.CopyFrom(original_task) - updated_task.status.state = TaskState.TASK_STATE_WORKING - updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') - - updated_task.artifacts[0].parts[0].text = 'ahoy' - - await vertex_store.save(updated_task, ServerCallContext()) - - 
retrieved_after_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_after_update is not None - assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING - - assert retrieved_after_update.artifacts == [ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='ahoy')], - ), - Artifact( - artifact_id='artifact-2', - parts=[Part(text='world')], - ), - ] - - -@pytest.mark.asyncio -async def test_update_task_delete_artifact( - vertex_store: VertexTaskStore, -) -> None: - """Test updating an existing task by deleting an artifact.""" - task_id = 'update-test-task-4' - original_task = Task( - id=task_id, - context_id='session-update', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - artifacts=[ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ), - Artifact( - artifact_id='artifact-2', - parts=[Part(text='world')], - ), - ], - history=[], - ) - await vertex_store.save(original_task, ServerCallContext()) - - retrieved_before_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_before_update is not None - assert ( - retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED - ) - assert retrieved_before_update.metadata == {} - - updated_task = Task() - updated_task.CopyFrom(original_task) - updated_task.status.state = TaskState.TASK_STATE_WORKING - updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') - - del updated_task.artifacts[1] - - await vertex_store.save(updated_task, ServerCallContext()) - - retrieved_after_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_after_update is not None - assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING - - assert retrieved_after_update.artifacts == [ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ) - ] - - -@pytest.mark.asyncio -async def test_metadata_field_mapping( - vertex_store: VertexTaskStore, -) -> 
None: - """Test that metadata field is correctly mapped between the core types and vertex. - - This test verifies: - 1. Metadata can be None - 2. Metadata can be a simple dict - 3. Metadata can contain nested structures - 4. Metadata is correctly saved and retrieved - 5. The mapping between task.metadata and task_metadata column works - """ - # Test 1: Task with no metadata (None) - task_no_metadata = Task( - id='task-metadata-test-1', - context_id='session-meta-1', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - await vertex_store.save(task_no_metadata, ServerCallContext()) - retrieved_no_metadata = await vertex_store.get( - 'task-metadata-test-1', ServerCallContext() - ) - assert retrieved_no_metadata is not None - assert retrieved_no_metadata.metadata == {} - - # Test 2: Task with simple metadata - simple_metadata = {'key': 'value', 'number': 42, 'boolean': True} - task_simple_metadata = Task( - id='task-metadata-test-2', - context_id='session-meta-2', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - metadata=simple_metadata, - ) - await vertex_store.save(task_simple_metadata, ServerCallContext()) - retrieved_simple = await vertex_store.get( - 'task-metadata-test-2', ServerCallContext() - ) - assert retrieved_simple is not None - assert retrieved_simple.metadata == simple_metadata - - # Test 3: Task with complex nested metadata - complex_metadata = { - 'level1': { - 'level2': { - 'level3': ['a', 'b', 'c'], - 'numeric': 3.14159, - }, - 'array': [1, 2, {'nested': 'value'}], - }, - 'special_chars': 'Hello\nWorld\t!', - 'unicode': '🚀 Unicode test 你好', - 'null_value': None, - } - task_complex_metadata = Task( - id='task-metadata-test-3', - context_id='session-meta-3', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - metadata=complex_metadata, - ) - await vertex_store.save(task_complex_metadata, ServerCallContext()) - retrieved_complex = await vertex_store.get( - 'task-metadata-test-3', ServerCallContext() - ) - assert 
retrieved_complex is not None - assert retrieved_complex.metadata == complex_metadata - - # Test 4: Update metadata from None to dict - task_update_metadata = Task( - id='task-metadata-test-4', - context_id='session-meta-4', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - await vertex_store.save(task_update_metadata, ServerCallContext()) - - # Update metadata - task_update_metadata.metadata.Clear() - task_update_metadata.metadata.update( - {'updated': True, 'timestamp': '2024-01-01'} - ) - await vertex_store.save(task_update_metadata, ServerCallContext()) - - retrieved_updated = await vertex_store.get( - 'task-metadata-test-4', ServerCallContext() - ) - assert retrieved_updated is not None - assert retrieved_updated.metadata == { - 'updated': True, - 'timestamp': '2024-01-01', - } - - # Test 5: Update metadata from dict to None - task_update_metadata.metadata.Clear() - await vertex_store.save(task_update_metadata, ServerCallContext()) - - retrieved_none = await vertex_store.get( - 'task-metadata-test-4', ServerCallContext() - ) - assert retrieved_none is not None - assert retrieved_none.metadata == {} - - -@pytest.mark.asyncio -async def test_metadata_empty_transitions( - vertex_store: VertexTaskStore, -) -> None: - """Test that updating metadata between {} and None does not generate events.""" - task_id = 'task-metadata-empty-test' - - # Step 1: Create task with metadata={} - task = Task( - id=task_id, - context_id='session-meta-empty', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - metadata={}, - ) - await vertex_store.save(task, ServerCallContext()) - - full_name = f'{vertex_store._agent_engine_resource_id}/a2aTasks/{task_id}' - - # Get initial event sequence number - stored_task_before = ( - await vertex_store._client.aio.agent_engines.a2a_tasks.get( - name=full_name - ) - ) - initial_seq = stored_task_before.next_event_sequence_number - - # Step 2: Update metadata to None - updated_task = Task() - updated_task.CopyFrom(task) - 
updated_task.metadata.Clear() - await vertex_store.save(updated_task, ServerCallContext()) - - # Step 3: Update back to {} - task_back = Task() - task_back.CopyFrom(updated_task) - task_back.metadata = {} - await vertex_store.save(task_back, ServerCallContext()) - - # Verify that retrieved task still has {} (due to mapping) - retrieved = await vertex_store.get(task_id, ServerCallContext()) - assert retrieved is not None - assert retrieved.metadata == {} - - # Verify that next_event_sequence_number did NOT increase (no events generated) - stored_task_after = ( - await vertex_store._client.aio.agent_engines.a2a_tasks.get( - name=full_name - ) - ) - assert stored_task_after.next_event_sequence_number == initial_seq - - -@pytest.mark.asyncio -async def test_update_task_status_details( - vertex_store: VertexTaskStore, -) -> None: - """Test updating an existing task by changing the status details (message) with part metadata.""" - task_id = 'update-test-task-status-details' - original_task = Task( - id=task_id, - context_id='session-update', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - metadata=None, - artifacts=[], - history=[], - ) - await vertex_store.save(original_task, ServerCallContext()) - - retrieved_before_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_before_update is not None - assert ( - retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED - ) - - updated_task = Task() - updated_task.CopyFrom(original_task) - updated_task.status.state = TaskState.TASK_STATE_FAILED - updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') - updated_task.status.message.CopyFrom( - Message( - message_id='msg-error-1', - role=Role.ROLE_AGENT, - parts=[ - Part( - text='Task failed due to an unknown error', - metadata={'error_code': 'UNKNOWN', 'retryable': False}, - ) - ], - ) - ) - - await vertex_store.save(updated_task, ServerCallContext()) - - retrieved_after_update = await vertex_store.get( - 
task_id, ServerCallContext() - ) - assert retrieved_after_update is not None - assert retrieved_after_update.status.state == TaskState.TASK_STATE_FAILED - assert retrieved_after_update.status.message is not None - assert retrieved_after_update.status.message.message_id == 'msg-error-1' - assert retrieved_after_update.status.message.role == Role.ROLE_AGENT - assert len(retrieved_after_update.status.message.parts) == 1 - - part = retrieved_after_update.status.message.parts[0] - assert part.text == 'Task failed due to an unknown error' - assert part.metadata == {'error_code': 'UNKNOWN', 'retryable': False} - - # Also test clearing the message - cleared_task = Task() - cleared_task.CopyFrom(updated_task) - cleared_task.status.ClearField('message') - - await vertex_store.save(cleared_task, ServerCallContext()) - retrieved_cleared = await vertex_store.get(task_id, ServerCallContext()) - assert retrieved_cleared is not None - assert not retrieved_cleared.status.HasField('message') diff --git a/uv.lock b/uv.lock index dc87a7b6d..0a1a7e13e 100644 --- a/uv.lock +++ b/uv.lock @@ -27,7 +27,6 @@ dependencies = [ all = [ { name = "alembic" }, { name = "cryptography" }, - { name = "google-cloud-aiplatform" }, { name = "grpcio" }, { name = "grpcio-reflection" }, { name = "grpcio-status" }, @@ -74,9 +73,6 @@ telemetry = [ { name = "opentelemetry-api" }, { name = "opentelemetry-sdk" }, ] -vertex = [ - { name = "google-cloud-aiplatform" }, -] [package.dev-dependencies] dev = [ @@ -109,8 +105,6 @@ requires-dist = [ { name = "cryptography", marker = "extra == 'encryption'", specifier = ">=43.0.0" }, { name = "culsans", marker = "python_full_version < '3.13'", specifier = ">=0.11.0" }, { name = "google-api-core", specifier = ">=1.26.0" }, - { name = "google-cloud-aiplatform", marker = "extra == 'all'", specifier = ">=1.140.0" }, - { name = "google-cloud-aiplatform", marker = "extra == 'vertex'", specifier = ">=1.140.0" }, { name = "googleapis-common-protos", specifier = ">=1.70.0" }, { 
name = "grpcio", marker = "extra == 'all'", specifier = ">=1.60" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.60" }, @@ -146,7 +140,7 @@ requires-dist = [ { name = "starlette", marker = "extra == 'all'" }, { name = "starlette", marker = "extra == 'http-server'" }, ] -provides-extras = ["all", "db-cli", "encryption", "grpc", "http-server", "mysql", "postgresql", "signing", "sql", "sqlite", "telemetry", "vertex"] +provides-extras = ["all", "db-cli", "encryption", "grpc", "http-server", "mysql", "postgresql", "signing", "sql", "sqlite", "telemetry"] [package.metadata.requires-dev] dev = [ @@ -765,24 +759,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] -[[package]] -name = "distro" -version = "1.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722, upload-time = "2023-12-24T09:54:32.31Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, -] - -[[package]] -name = "docstring-parser" -version = "0.17.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, 
upload-time = "2025-07-21T07:35:01.868Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, -] - [[package]] name = "dunamai" version = "1.26.0" @@ -857,12 +833,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/09c33d67f7e0dcf06d7ac17d196594e66989299374bfb0d4331d1038e76b/google_api_core-2.30.0-py3-none-any.whl", hash = "sha256:80be49ee937ff9aba0fd79a6eddfde35fe658b9953ab9b79c57dd7061afa8df5", size = 173288, upload-time = "2026-02-18T20:28:10.367Z" }, ] -[package.optional-dependencies] -grpc = [ - { name = "grpcio" }, - { name = "grpcio-status" }, -] - [[package]] name = "google-auth" version = "2.49.1" @@ -876,167 +846,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/eb/c6c2478d8a8d633460be40e2a8a6f8f429171997a35a96f81d3b680dec83/google_auth-2.49.1-py3-none-any.whl", hash = "sha256:195ebe3dca18eddd1b3db5edc5189b76c13e96f29e73043b923ebcf3f1a860f7", size = 240737, upload-time = "2026-03-12T19:30:53.159Z" }, ] -[package.optional-dependencies] -requests = [ - { name = "requests" }, -] - -[[package]] -name = "google-cloud-aiplatform" -version = "1.141.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "docstring-parser" }, - { name = "google-api-core", extra = ["grpc"] }, - { name = "google-auth" }, - { name = "google-cloud-bigquery" }, - { name = "google-cloud-resource-manager" }, - { name = "google-cloud-storage" }, - { name = "google-genai" }, - { name = "packaging" }, - { name = "proto-plus" }, - { name = "protobuf" }, - { name = "pydantic" }, - { name = "typing-extensions" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/ac/dc/1209c7aab43bd7233cf631165a3b1b4284d22fc7fe7387c66228d07868ab/google_cloud_aiplatform-1.141.0.tar.gz", hash = "sha256:e3b1cdb28865dd862aac9c685dfc5ac076488705aba0a5354016efadcddd59c6", size = 10152688, upload-time = "2026-03-10T22:20:08.692Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/fc/428af69a69ff2e477e7f5e12d227b31fe5790f1a8234aacd54297f49c836/google_cloud_aiplatform-1.141.0-py2.py3-none-any.whl", hash = "sha256:6bd25b4d514c40b8181ca703e1b313ad6d0454ab8006fc9907fb3e9f672f31d1", size = 8358409, upload-time = "2026-03-10T22:20:04.871Z" }, -] - -[[package]] -name = "google-cloud-bigquery" -version = "3.40.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core", extra = ["grpc"] }, - { name = "google-auth" }, - { name = "google-cloud-core" }, - { name = "google-resumable-media" }, - { name = "packaging" }, - { name = "python-dateutil" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/11/0c/153ee546c288949fcc6794d58811ab5420f3ecad5fa7f9e73f78d9512a6e/google_cloud_bigquery-3.40.1.tar.gz", hash = "sha256:75afcfb6e007238fe1deefb2182105249321145ff921784fe7b1de2b4ba24506", size = 511761, upload-time = "2026-02-12T18:44:18.958Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7c/f5/081cf5b90adfe524ae0d671781b0d497a75a0f2601d075af518828e22d8f/google_cloud_bigquery-3.40.1-py3-none-any.whl", hash = "sha256:9082a6b8193aba87bed6a2c79cf1152b524c99bb7e7ac33a785e333c09eac868", size = 262018, upload-time = "2026-02-12T18:44:16.913Z" }, -] - -[[package]] -name = "google-cloud-core" -version = "2.5.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a6/03/ef0bc99d0e0faf4fdbe67ac445e18cdaa74824fd93cd069e7bb6548cb52d/google_cloud_core-2.5.0.tar.gz", hash = 
"sha256:7c1b7ef5c92311717bd05301aa1a91ffbc565673d3b0b4163a52d8413a186963", size = 36027, upload-time = "2025-10-29T23:17:39.513Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/89/20/bfa472e327c8edee00f04beecc80baeddd2ab33ee0e86fd7654da49d45e9/google_cloud_core-2.5.0-py3-none-any.whl", hash = "sha256:67d977b41ae6c7211ee830c7912e41003ea8194bff15ae7d72fd6f51e57acabc", size = 29469, upload-time = "2025-10-29T23:17:38.548Z" }, -] - -[[package]] -name = "google-cloud-resource-manager" -version = "1.16.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core", extra = ["grpc"] }, - { name = "google-auth" }, - { name = "grpc-google-iam-v1" }, - { name = "grpcio" }, - { name = "proto-plus" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4e/7f/db00b2820475793a52958dc55fe9ec2eb8e863546e05fcece9b921f86ebe/google_cloud_resource_manager-1.16.0.tar.gz", hash = "sha256:cc938f87cc36c2672f062b1e541650629e0d954c405a4dac35ceedee70c267c3", size = 459840, upload-time = "2026-01-15T13:04:07.726Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/94/ff/4b28bcc791d9d7e4ac8fea00fbd90ccb236afda56746a3b4564d2ae45df3/google_cloud_resource_manager-1.16.0-py3-none-any.whl", hash = "sha256:fb9a2ad2b5053c508e1c407ac31abfd1a22e91c32876c1892830724195819a28", size = 400218, upload-time = "2026-01-15T13:02:47.378Z" }, -] - -[[package]] -name = "google-cloud-storage" -version = "3.10.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, - { name = "google-cloud-core" }, - { name = "google-crc32c" }, - { name = "google-resumable-media" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/7a/e3/747759eebc72e420c25903d6bc231d0ceb110b66ac7e6ee3f350417152cd/google_cloud_storage-3.10.0.tar.gz", hash = "sha256:1aeebf097c27d718d84077059a28d7e87f136f3700212215f1ceeae1d1c5d504", size 
= 17309829, upload-time = "2026-03-18T15:54:11.875Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/29/e2/d58442f4daee5babd9255cf492a1f3d114357164072f8339a22a3ad460a2/google_cloud_storage-3.10.0-py3-none-any.whl", hash = "sha256:0072e7783b201e45af78fd9779894cdb6bec2bf922ee932f3fcc16f8bce9b9a3", size = 324382, upload-time = "2026-03-18T15:54:10.091Z" }, -] - -[[package]] -name = "google-crc32c" -version = "1.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/95/ac/6f7bc93886a823ab545948c2dd48143027b2355ad1944c7cf852b338dc91/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0470b8c3d73b5f4e3300165498e4cf25221c7eb37f1159e221d1825b6df8a7ff", size = 31296, upload-time = "2025-12-16T00:19:07.261Z" }, - { url = "https://files.pythonhosted.org/packages/f7/97/a5accde175dee985311d949cfcb1249dcbb290f5ec83c994ea733311948f/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:119fcd90c57c89f30040b47c211acee231b25a45d225e3225294386f5d258288", size = 30870, upload-time = "2025-12-16T00:29:17.669Z" }, - { url = "https://files.pythonhosted.org/packages/3d/63/bec827e70b7a0d4094e7476f863c0dbd6b5f0f1f91d9c9b32b76dcdfeb4e/google_crc32c-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6f35aaffc8ccd81ba3162443fabb920e65b1f20ab1952a31b13173a67811467d", size = 33214, upload-time = "2025-12-16T00:40:19.618Z" }, - { url = "https://files.pythonhosted.org/packages/63/bc/11b70614df04c289128d782efc084b9035ef8466b3d0a8757c1b6f5cf7ac/google_crc32c-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", 
hash = "sha256:864abafe7d6e2c4c66395c1eb0fe12dc891879769b52a3d56499612ca93b6092", size = 33589, upload-time = "2025-12-16T00:40:20.7Z" }, - { url = "https://files.pythonhosted.org/packages/3e/00/a08a4bc24f1261cc5b0f47312d8aebfbe4b53c2e6307f1b595605eed246b/google_crc32c-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:db3fe8eaf0612fc8b20fa21a5f25bd785bc3cd5be69f8f3412b0ac2ffd49e733", size = 34437, upload-time = "2025-12-16T00:35:19.437Z" }, - { url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" }, - { url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" }, - { url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" }, - { url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" }, - { url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = 
"sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, - { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, - { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, - { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, - { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, - { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = 
"sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, - { url = "https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, - { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, - { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, - { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, - { url = "https://files.pythonhosted.org/packages/21/88/8ecf3c2b864a490b9e7010c84fd203ec8cf3b280651106a3a74dd1b0ca72/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697", size = 31301, upload-time = "2025-12-16T00:24:48.527Z" }, - { url = "https://files.pythonhosted.org/packages/36/c6/f7ff6c11f5ca215d9f43d3629163727a272eabc356e5c9b2853df2bfe965/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = 
"sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651", size = 30868, upload-time = "2025-12-16T00:48:12.163Z" }, - { url = "https://files.pythonhosted.org/packages/56/15/c25671c7aad70f8179d858c55a6ae8404902abe0cdcf32a29d581792b491/google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2", size = 33381, upload-time = "2025-12-16T00:40:26.268Z" }, - { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, - { url = "https://files.pythonhosted.org/packages/08/a5/7b059810934a09fb3ccb657e0843813c1fee1183d3bc2c8041800374aa2c/google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2", size = 34878, upload-time = "2025-12-16T00:35:23.142Z" }, - { url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" }, - { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, -] - -[[package]] -name = "google-genai" -version = "1.68.0" -source = { registry = 
"https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "distro" }, - { name = "google-auth", extra = ["requests"] }, - { name = "httpx" }, - { name = "pydantic" }, - { name = "requests" }, - { name = "sniffio" }, - { name = "tenacity" }, - { name = "typing-extensions" }, - { name = "websockets" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9c/2c/f059982dbcb658cc535c81bbcbe7e2c040d675f4b563b03cdb01018a4bc3/google_genai-1.68.0.tar.gz", hash = "sha256:ac30c0b8bc630f9372993a97e4a11dae0e36f2e10d7c55eacdca95a9fa14ca96", size = 511285, upload-time = "2026-03-18T01:03:18.243Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/84/de/7d3ee9c94b74c3578ea4f88d45e8de9405902f857932334d81e89bce3dfa/google_genai-1.68.0-py3-none-any.whl", hash = "sha256:a1bc9919c0e2ea2907d1e319b65471d3d6d58c54822039a249fe1323e4178d15", size = 750912, upload-time = "2026-03-18T01:03:15.983Z" }, -] - -[[package]] -name = "google-resumable-media" -version = "2.8.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-crc32c" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/64/d7/520b62a35b23038ff005e334dba3ffc75fcf583bee26723f1fd8fd4b6919/google_resumable_media-2.8.0.tar.gz", hash = "sha256:f1157ed8b46994d60a1bc432544db62352043113684d4e030ee02e77ebe9a1ae", size = 2163265, upload-time = "2025-11-17T15:38:06.659Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/0b/93afde9cfe012260e9fe1522f35c9b72d6ee222f316586b1f23ecf44d518/google_resumable_media-2.8.0-py3-none-any.whl", hash = "sha256:dd14a116af303845a8d932ddae161a26e86cc229645bc98b39f026f9b1717582", size = 81340, upload-time = "2025-11-17T15:38:05.594Z" }, -] - [[package]] name = "googleapis-common-protos" version = "1.73.0" @@ -1049,11 +858,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/69/28/23eea8acd65972bbfe295ce3666b28ac510dfcb115fac089d3edb0feb00a/googleapis_common_protos-1.73.0-py3-none-any.whl", hash 
= "sha256:dfdaaa2e860f242046be561e6d6cb5c5f1541ae02cfbcb034371aadb2942b4e8", size = 297578, upload-time = "2026-03-06T21:52:33.933Z" }, ] -[package.optional-dependencies] -grpc = [ - { name = "grpcio" }, -] - [[package]] name = "greenlet" version = "3.3.2" @@ -1114,20 +918,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/4b/45d90626aef8e65336bed690106d1382f7a43665e2249017e9527df8823b/greenlet-3.3.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c04c5e06ec3e022cbfe2cd4a846e1d4e50087444f875ff6d2c2ad8445495cf1a", size = 237086, upload-time = "2026-02-20T20:20:45.786Z" }, ] -[[package]] -name = "grpc-google-iam-v1" -version = "0.14.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "googleapis-common-protos", extra = ["grpc"] }, - { name = "grpcio" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/76/1e/1011451679a983f2f5c6771a1682542ecb027776762ad031fd0d7129164b/grpc_google_iam_v1-0.14.3.tar.gz", hash = "sha256:879ac4ef33136c5491a6300e27575a9ec760f6cdf9a2518798c1b8977a5dc389", size = 23745, upload-time = "2025-10-15T21:14:53.318Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4a/bd/330a1bbdb1afe0b96311249e699b6dc9cfc17916394fd4503ac5aca2514b/grpc_google_iam_v1-0.14.3-py3-none-any.whl", hash = "sha256:7a7f697e017a067206a3dfef44e4c634a34d3dee135fe7d7a4613fe3e59217e6", size = 32690, upload-time = "2025-10-15T21:14:51.72Z" }, -] - [[package]] name = "grpcio" version = "1.78.0" @@ -2067,18 +1857,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, ] -[[package]] -name = "python-dateutil" -version = "2.9.0.post0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url 
= "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, -] - [[package]] name = "python-discovery" version = "1.2.0" @@ -2217,15 +1995,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" }, ] -[[package]] -name = "six" -version = "1.17.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, -] - [[package]] name = "sniffio" version = "1.3.1" @@ -2348,15 +2117,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = 
"sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, ] -[[package]] -name = "tenacity" -version = "9.1.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, -] - [[package]] name = "tomli" version = "2.4.0" @@ -2543,74 +2303,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/59/7d02447a55b2e55755011a647479041bc92a82e143f96a8195cb33bd0a1c/virtualenv-21.2.0-py3-none-any.whl", hash = "sha256:1bd755b504931164a5a496d217c014d098426cddc79363ad66ac78125f9d908f", size = 5825084, upload-time = "2026-03-09T17:24:35.378Z" }, ] -[[package]] -name = "websockets" -version = "16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/20/74/221f58decd852f4b59cc3354cccaf87e8ef695fede361d03dc9a7396573b/websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a", size = 177343, upload-time = "2026-01-10T09:22:21.28Z" }, - { url = 
"https://files.pythonhosted.org/packages/19/0f/22ef6107ee52ab7f0b710d55d36f5a5d3ef19e8a205541a6d7ffa7994e5a/websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0", size = 175021, upload-time = "2026-01-10T09:22:22.696Z" }, - { url = "https://files.pythonhosted.org/packages/10/40/904a4cb30d9b61c0e278899bf36342e9b0208eb3c470324a9ecbaac2a30f/websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957", size = 175320, upload-time = "2026-01-10T09:22:23.94Z" }, - { url = "https://files.pythonhosted.org/packages/9d/2f/4b3ca7e106bc608744b1cdae041e005e446124bebb037b18799c2d356864/websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72", size = 183815, upload-time = "2026-01-10T09:22:25.469Z" }, - { url = "https://files.pythonhosted.org/packages/86/26/d40eaa2a46d4302becec8d15b0fc5e45bdde05191e7628405a19cf491ccd/websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde", size = 185054, upload-time = "2026-01-10T09:22:27.101Z" }, - { url = "https://files.pythonhosted.org/packages/b0/ba/6500a0efc94f7373ee8fefa8c271acdfd4dca8bd49a90d4be7ccabfc397e/websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3", size = 184565, upload-time = "2026-01-10T09:22:28.293Z" }, - { url = "https://files.pythonhosted.org/packages/04/b4/96bf2cee7c8d8102389374a2616200574f5f01128d1082f44102140344cc/websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3", size = 183848, upload-time = "2026-01-10T09:22:30.394Z" }, - { url = 
"https://files.pythonhosted.org/packages/02/8e/81f40fb00fd125357814e8c3025738fc4ffc3da4b6b4a4472a82ba304b41/websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9", size = 178249, upload-time = "2026-01-10T09:22:32.083Z" }, - { url = "https://files.pythonhosted.org/packages/b4/5f/7e40efe8df57db9b91c88a43690ac66f7b7aa73a11aa6a66b927e44f26fa/websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35", size = 178685, upload-time = "2026-01-10T09:22:33.345Z" }, - { url = "https://files.pythonhosted.org/packages/f2/db/de907251b4ff46ae804ad0409809504153b3f30984daf82a1d84a9875830/websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8", size = 177340, upload-time = "2026-01-10T09:22:34.539Z" }, - { url = "https://files.pythonhosted.org/packages/f3/fa/abe89019d8d8815c8781e90d697dec52523fb8ebe308bf11664e8de1877e/websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad", size = 175022, upload-time = "2026-01-10T09:22:36.332Z" }, - { url = "https://files.pythonhosted.org/packages/58/5d/88ea17ed1ded2079358b40d31d48abe90a73c9e5819dbcde1606e991e2ad/websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d", size = 175319, upload-time = "2026-01-10T09:22:37.602Z" }, - { url = "https://files.pythonhosted.org/packages/d2/ae/0ee92b33087a33632f37a635e11e1d99d429d3d323329675a6022312aac2/websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe", size = 184631, upload-time = "2026-01-10T09:22:38.789Z" }, - { url = 
"https://files.pythonhosted.org/packages/c8/c5/27178df583b6c5b31b29f526ba2da5e2f864ecc79c99dae630a85d68c304/websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b", size = 185870, upload-time = "2026-01-10T09:22:39.893Z" }, - { url = "https://files.pythonhosted.org/packages/87/05/536652aa84ddc1c018dbb7e2c4cbcd0db884580bf8e95aece7593fde526f/websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5", size = 185361, upload-time = "2026-01-10T09:22:41.016Z" }, - { url = "https://files.pythonhosted.org/packages/6d/e2/d5332c90da12b1e01f06fb1b85c50cfc489783076547415bf9f0a659ec19/websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64", size = 184615, upload-time = "2026-01-10T09:22:42.442Z" }, - { url = "https://files.pythonhosted.org/packages/77/fb/d3f9576691cae9253b51555f841bc6600bf0a983a461c79500ace5a5b364/websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6", size = 178246, upload-time = "2026-01-10T09:22:43.654Z" }, - { url = "https://files.pythonhosted.org/packages/54/67/eaff76b3dbaf18dcddabc3b8c1dba50b483761cccff67793897945b37408/websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac", size = 178684, upload-time = "2026-01-10T09:22:44.941Z" }, - { url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" }, - { url = 
"https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" }, - { url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" }, - { url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" }, - { url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" }, - { url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" }, - { url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" }, - { url = 
"https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" }, - { url = "https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" }, - { url = "https://files.pythonhosted.org/packages/cc/9c/baa8456050d1c1b08dd0ec7346026668cbc6f145ab4e314d707bb845bf0d/websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9", size = 177364, upload-time = "2026-01-10T09:22:59.333Z" }, - { url = "https://files.pythonhosted.org/packages/7e/0c/8811fc53e9bcff68fe7de2bcbe75116a8d959ac699a3200f4847a8925210/websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230", size = 175039, upload-time = "2026-01-10T09:23:01.171Z" }, - { url = "https://files.pythonhosted.org/packages/aa/82/39a5f910cb99ec0b59e482971238c845af9220d3ab9fa76dd9162cda9d62/websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c", size = 175323, upload-time = "2026-01-10T09:23:02.341Z" }, - { url = "https://files.pythonhosted.org/packages/bd/28/0a25ee5342eb5d5f297d992a77e56892ecb65e7854c7898fb7d35e9b33bd/websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5", size = 184975, upload-time = "2026-01-10T09:23:03.756Z" }, - { url = 
"https://files.pythonhosted.org/packages/f9/66/27ea52741752f5107c2e41fda05e8395a682a1e11c4e592a809a90c6a506/websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82", size = 186203, upload-time = "2026-01-10T09:23:05.01Z" }, - { url = "https://files.pythonhosted.org/packages/37/e5/8e32857371406a757816a2b471939d51c463509be73fa538216ea52b792a/websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8", size = 185653, upload-time = "2026-01-10T09:23:06.301Z" }, - { url = "https://files.pythonhosted.org/packages/9b/67/f926bac29882894669368dc73f4da900fcdf47955d0a0185d60103df5737/websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f", size = 184920, upload-time = "2026-01-10T09:23:07.492Z" }, - { url = "https://files.pythonhosted.org/packages/3c/a1/3d6ccdcd125b0a42a311bcd15a7f705d688f73b2a22d8cf1c0875d35d34a/websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a", size = 178255, upload-time = "2026-01-10T09:23:09.245Z" }, - { url = "https://files.pythonhosted.org/packages/6b/ae/90366304d7c2ce80f9b826096a9e9048b4bb760e44d3b873bb272cba696b/websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156", size = 178689, upload-time = "2026-01-10T09:23:10.483Z" }, - { url = "https://files.pythonhosted.org/packages/f3/1d/e88022630271f5bd349ed82417136281931e558d628dd52c4d8621b4a0b2/websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0", size = 177406, upload-time = "2026-01-10T09:23:12.178Z" }, - { url = 
"https://files.pythonhosted.org/packages/f2/78/e63be1bf0724eeb4616efb1ae1c9044f7c3953b7957799abb5915bffd38e/websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904", size = 175085, upload-time = "2026-01-10T09:23:13.511Z" }, - { url = "https://files.pythonhosted.org/packages/bb/f4/d3c9220d818ee955ae390cf319a7c7a467beceb24f05ee7aaaa2414345ba/websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4", size = 175328, upload-time = "2026-01-10T09:23:14.727Z" }, - { url = "https://files.pythonhosted.org/packages/63/bc/d3e208028de777087e6fb2b122051a6ff7bbcca0d6df9d9c2bf1dd869ae9/websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e", size = 185044, upload-time = "2026-01-10T09:23:15.939Z" }, - { url = "https://files.pythonhosted.org/packages/ad/6e/9a0927ac24bd33a0a9af834d89e0abc7cfd8e13bed17a86407a66773cc0e/websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4", size = 186279, upload-time = "2026-01-10T09:23:17.148Z" }, - { url = "https://files.pythonhosted.org/packages/b9/ca/bf1c68440d7a868180e11be653c85959502efd3a709323230314fda6e0b3/websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1", size = 185711, upload-time = "2026-01-10T09:23:18.372Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f8/fdc34643a989561f217bb477cbc47a3a07212cbda91c0e4389c43c296ebf/websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3", size = 184982, upload-time = "2026-01-10T09:23:19.652Z" }, - { url = 
"https://files.pythonhosted.org/packages/dd/d1/574fa27e233764dbac9c52730d63fcf2823b16f0856b3329fc6268d6ae4f/websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8", size = 177915, upload-time = "2026-01-10T09:23:21.458Z" }, - { url = "https://files.pythonhosted.org/packages/8a/f1/ae6b937bf3126b5134ce1f482365fde31a357c784ac51852978768b5eff4/websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d", size = 178381, upload-time = "2026-01-10T09:23:22.715Z" }, - { url = "https://files.pythonhosted.org/packages/06/9b/f791d1db48403e1f0a27577a6beb37afae94254a8c6f08be4a23e4930bc0/websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244", size = 177737, upload-time = "2026-01-10T09:23:24.523Z" }, - { url = "https://files.pythonhosted.org/packages/bd/40/53ad02341fa33b3ce489023f635367a4ac98b73570102ad2cdd770dacc9a/websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e", size = 175268, upload-time = "2026-01-10T09:23:25.781Z" }, - { url = "https://files.pythonhosted.org/packages/74/9b/6158d4e459b984f949dcbbb0c5d270154c7618e11c01029b9bbd1bb4c4f9/websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641", size = 175486, upload-time = "2026-01-10T09:23:27.033Z" }, - { url = "https://files.pythonhosted.org/packages/e5/2d/7583b30208b639c8090206f95073646c2c9ffd66f44df967981a64f849ad/websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8", size = 185331, upload-time = "2026-01-10T09:23:28.259Z" }, - { url = 
"https://files.pythonhosted.org/packages/45/b0/cce3784eb519b7b5ad680d14b9673a31ab8dcb7aad8b64d81709d2430aa8/websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e", size = 186501, upload-time = "2026-01-10T09:23:29.449Z" }, - { url = "https://files.pythonhosted.org/packages/19/60/b8ebe4c7e89fb5f6cdf080623c9d92789a53636950f7abacfc33fe2b3135/websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944", size = 186062, upload-time = "2026-01-10T09:23:31.368Z" }, - { url = "https://files.pythonhosted.org/packages/88/a8/a080593f89b0138b6cba1b28f8df5673b5506f72879322288b031337c0b8/websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206", size = 185356, upload-time = "2026-01-10T09:23:32.627Z" }, - { url = "https://files.pythonhosted.org/packages/c2/b6/b9afed2afadddaf5ebb2afa801abf4b0868f42f8539bfe4b071b5266c9fe/websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6", size = 178085, upload-time = "2026-01-10T09:23:33.816Z" }, - { url = "https://files.pythonhosted.org/packages/9f/3e/28135a24e384493fa804216b79a6a6759a38cc4ff59118787b9fb693df93/websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd", size = 178531, upload-time = "2026-01-10T09:23:35.016Z" }, - { url = "https://files.pythonhosted.org/packages/72/07/c98a68571dcf256e74f1f816b8cc5eae6eb2d3d5cfa44d37f801619d9166/websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d", size = 174947, upload-time = "2026-01-10T09:23:36.166Z" }, - { url = 
"https://files.pythonhosted.org/packages/7e/52/93e166a81e0305b33fe416338be92ae863563fe7bce446b0f687b9df5aea/websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03", size = 175260, upload-time = "2026-01-10T09:23:37.409Z" }, - { url = "https://files.pythonhosted.org/packages/56/0c/2dbf513bafd24889d33de2ff0368190a0e69f37bcfa19009ef819fe4d507/websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da", size = 176071, upload-time = "2026-01-10T09:23:39.158Z" }, - { url = "https://files.pythonhosted.org/packages/a5/8f/aea9c71cc92bf9b6cc0f7f70df8f0b420636b6c96ef4feee1e16f80f75dd/websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c", size = 176968, upload-time = "2026-01-10T09:23:41.031Z" }, - { url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" }, - { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, -] - [[package]] name = "wrapt" version = "2.1.2" From 01b812b3e14cabba134a91955195271c399353ea Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 20 Apr 2026 16:32:46 +0200 Subject: [PATCH 48/67] test: remove stale pytest.skip (#1000) Skipped tests pass, remove the skip. 
--- tests/client/test_auth_interceptor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/client/test_auth_interceptor.py b/tests/client/test_auth_interceptor.py index 11d932090..560751fa8 100644 --- a/tests/client/test_auth_interceptor.py +++ b/tests/client/test_auth_interceptor.py @@ -240,7 +240,6 @@ class AuthTestCase: ) -@pytest.mark.skip(reason='Interceptors disabled by user request') @pytest.mark.asyncio @pytest.mark.parametrize( 'test_case', From cb95424cb2574faa81e9f0a44761e4ff6894f1a5 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 20 Apr 2026 16:44:37 +0200 Subject: [PATCH 49/67] ci: add some path filters (#1001) Do not run heavy jobs for docs and other unrelated changes. --- .github/workflows/linter.yaml | 11 +++++++++++ .github/workflows/minimal-install.yml | 22 ++++++++++++++++++++++ .github/workflows/run-tck.yaml | 18 ++++++++++++++++++ .github/workflows/unit-tests.yml | 22 ++++++++++++++++++++++ 4 files changed, 73 insertions(+) diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index 4263abb3c..ec4bd16fb 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -3,6 +3,17 @@ name: Lint Code Base on: pull_request: branches: [main, 1.0-dev] + paths-ignore: + - '**.md' + - 'LICENSE' + - 'docs/**' + - '.github/CODEOWNERS' + - '.github/ISSUE_TEMPLATE/**' + - '.github/PULL_REQUEST_TEMPLATE.md' + - '.github/dependabot.yml' + - '.gitignore' + - '.git-blame-ignore-revs' + - '.gemini/**' permissions: contents: read jobs: diff --git a/.github/workflows/minimal-install.yml b/.github/workflows/minimal-install.yml index 7e0f143c6..27afebe7e 100644 --- a/.github/workflows/minimal-install.yml +++ b/.github/workflows/minimal-install.yml @@ -3,7 +3,29 @@ name: Minimal Install Smoke Test on: push: branches: [main, 1.0-dev] + paths-ignore: + - '**.md' + - 'LICENSE' + - 'docs/**' + - '.github/CODEOWNERS' + - '.github/ISSUE_TEMPLATE/**' + - '.github/PULL_REQUEST_TEMPLATE.md' + - '.github/dependabot.yml' + - 
'.gitignore' + - '.git-blame-ignore-revs' + - '.gemini/**' pull_request: + paths-ignore: + - '**.md' + - 'LICENSE' + - 'docs/**' + - '.github/CODEOWNERS' + - '.github/ISSUE_TEMPLATE/**' + - '.github/PULL_REQUEST_TEMPLATE.md' + - '.github/dependabot.yml' + - '.gitignore' + - '.git-blame-ignore-revs' + - '.gemini/**' permissions: contents: read diff --git a/.github/workflows/run-tck.yaml b/.github/workflows/run-tck.yaml index 6d0df865f..62bbeebc0 100644 --- a/.github/workflows/run-tck.yaml +++ b/.github/workflows/run-tck.yaml @@ -3,12 +3,30 @@ name: Run TCK on: push: branches: [ "main" ] + paths-ignore: + - '**.md' + - 'LICENSE' + - 'docs/**' + - '.github/CODEOWNERS' + - '.github/ISSUE_TEMPLATE/**' + - '.github/PULL_REQUEST_TEMPLATE.md' + - '.github/dependabot.yml' + - '.gitignore' + - '.git-blame-ignore-revs' + - '.gemini/**' pull_request: branches: [ "main" ] paths-ignore: - '**.md' - 'LICENSE' + - 'docs/**' - '.github/CODEOWNERS' + - '.github/ISSUE_TEMPLATE/**' + - '.github/PULL_REQUEST_TEMPLATE.md' + - '.github/dependabot.yml' + - '.gitignore' + - '.git-blame-ignore-revs' + - '.gemini/**' permissions: contents: read diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index adabe0676..098a14ecc 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -3,7 +3,29 @@ name: Run Unit Tests on: push: branches: [main, 1.0-dev] + paths-ignore: + - '**.md' + - 'LICENSE' + - 'docs/**' + - '.github/CODEOWNERS' + - '.github/ISSUE_TEMPLATE/**' + - '.github/PULL_REQUEST_TEMPLATE.md' + - '.github/dependabot.yml' + - '.gitignore' + - '.git-blame-ignore-revs' + - '.gemini/**' pull_request: + paths-ignore: + - '**.md' + - 'LICENSE' + - 'docs/**' + - '.github/CODEOWNERS' + - '.github/ISSUE_TEMPLATE/**' + - '.github/PULL_REQUEST_TEMPLATE.md' + - '.github/dependabot.yml' + - '.gitignore' + - '.git-blame-ignore-revs' + - '.gemini/**' permissions: contents: read From 10dea8b4448c5cb7d9e72d74677fd60880cc38df Mon Sep 17 00:00:00 
2001 From: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:25:03 +0200 Subject: [PATCH 50/67] docs: add comprehensive v0.3 to v1.0 migration guide (#987) # Description This PR adds detailed documentation for migrating A2A-compliant applications from version v0.3 to v1.0. The guide covers the transition to full A2A Protocol v1.0 compatibility, including major architectural shifts and developer experience improvements. ### Key Areas Covered: * **Dependency Management**: Instructions for upgrading to a2a-sdk>=1.0.0 using uv or pip. * **Type System Transition**: Detailed mapping of the move from Pydantic models to Protobuf-based classes, including the standardization of enum values to `SCREAMING_SNAKE_CASE`. * **Server-Side Refactoring**: * Transition from application wrappers (A2AStarletteApplication, etc.) to flexible route factory functions. * Updated DefaultRequestHandler signature requiring agent_card. * **Client Improvements**: * New create_client() factory function replacing the legacy A2AClient class. * Standardization of send_message() return types to AsyncIterator[StreamResponse]. * **Backward Compatibility**: Strategies for running v1.0 servers that simultaneously support v0.3 clients during transition periods. * **New Helper Utilities**: Introduction of the a2a.helpers module to simplify object construction and data extraction. ### Why this is important: The v1.0 release introduces several breaking changes to align with the latest protocol specification and improve performance through Protobuf integration. This guide is essential for existing users to successfully navigate the upgrade path while minimizing downtime. 
--------- Co-authored-by: Sampath Kumar --- .github/actions/spelling/allow.txt | 1 + docs/migrations/v1_0/README.md | 442 +++++++++++++++++++++++++++++ 2 files changed, 443 insertions(+) create mode 100644 docs/migrations/v1_0/README.md diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index b3b2d56e8..03774d1f0 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -85,6 +85,7 @@ lifecycles linting Llm lstrips +mcp middleware mikeas mockurl diff --git a/docs/migrations/v1_0/README.md b/docs/migrations/v1_0/README.md new file mode 100644 index 000000000..34b2f1bed --- /dev/null +++ b/docs/migrations/v1_0/README.md @@ -0,0 +1,442 @@ +# A2A Python SDK Migration Guide: v0.3 → v1.0 + +The `a2a-sdk` has achieved a major milestone in stability and reliability with the update to full **A2A Protocol v1.0 compatibility**. This guide provides a detailed overview of the breaking changes in version `v1.0` and instructions for migrating your codebase. + +Beyond protocol support, `v1.0` enhances the developer experience by introducing unified helper utilities for easier object creation and adopting Starlette route factory functions for more flexible server configuration. + +This documentation details the technical upgrades and architectural modifications introduced in A2A Python SDK v1.0. For developers using the database persistence layer, please refer to the [Database Migration Guide](database/) for specific update instructions. + +--- + +## Table of Contents + +1. [Update Dependencies](#1-update-dependencies) +2. [Types](#2-types) +3. [Server: DefaultRequestHandler](#3-server-defaultrequesthandler) +4. [Server: Application Setup](#4-server-application-setup) +5. [Supporting v0.3 Clients](#5-supporting-v03-clients) +6. [Client: Creating a Client](#6-client-creating-a-client) +7. [Client: Send Message](#7-client-send-message) +8. [Client: Push Notifications Config](#8-client-push-notifications-config) +9. 
[Helper Utilities](#9-helper-utilities) +10. [Summary of Key Changes](#10-summary-of-key-changes-in-v10) +11. [Get Started](#11-get-started) + +--- + +## 1. Update Dependencies + +(UV users) To upgrade to the latest version of the `a2a-sdk`, update the dependencies section in your `pyproject.toml` file. + +| File | Before (`v0.3`) | After (`v1.0`) | +|------------------|-----------------------------------|-----------------------------------| +| `pyproject.toml` | dependencies = ["a2a-sdk>=0.3.0"] | dependencies = ["a2a-sdk>=1.0.0"] | + +**Installation** + +After updating your configuration file, sync your environment: + +* Using UV: + +```bash +uv sync +``` + +* Using pip: + +```bash +pip install --upgrade a2a-sdk +``` + +--- + +## 2. Types + +Types have migrated from Pydantic models to Protobuf-based classes. + + +### Enum values: snake_case → SCREAMING_SNAKE_CASE + +All the enum values are now standardized from snake_case to **SCREAMING_SNAKE_CASE** format. + +This affects every enum in the SDK: `TaskState`, `Role`. 
+ +| Enum | v0.3 | v1.0 | +|---|---|---| +| `TaskState` | *(no equivalent — protobuf default)* | `TaskState.TASK_STATE_UNSPECIFIED` | +| `TaskState` | `TaskState.submitted` | `TaskState.TASK_STATE_SUBMITTED` | +| `TaskState` | `TaskState.working` | `TaskState.TASK_STATE_WORKING` | +| `TaskState` | `TaskState.completed` | `TaskState.TASK_STATE_COMPLETED` | +| `TaskState` | `TaskState.failed` | `TaskState.TASK_STATE_FAILED` | +| `TaskState` | `TaskState.canceled` | `TaskState.TASK_STATE_CANCELED` | +| `TaskState` | `TaskState.input_required` | `TaskState.TASK_STATE_INPUT_REQUIRED` | +| `TaskState` | `TaskState.auth_required` | `TaskState.TASK_STATE_AUTH_REQUIRED` | +| `TaskState` | `TaskState.rejected` | `TaskState.TASK_STATE_REJECTED` | +||| +| `Role` | *(no equivalent — protobuf default)* | `Role.ROLE_UNSPECIFIED` | +| `Role` | `Role.user` | `Role.ROLE_USER` | +| `Role` | `Role.agent` | `Role.ROLE_AGENT` | + +> **Example**: [`a2a-mcp-without-framework/server/agent_executor.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/changes#diff-1f9b098f9f82ee40666ee61db56dc2246281423c445bcf017079c53a0a05954f) + +### Message and Part construction + +Constructing messages is simplified in v1.0. The old API required wrapping content in an intermediate type (`TextPart`, `FilePart`, `DataPart`) before placing it inside a `Part`. In v1.0, `Part` is a single unified message — set the content type directly on it and the wrapper types are gone entirely. 
+ +Key differences: +- `Part(TextPart(text=...))` → `Part(text=...)` (flat union field) +- `Role.user` → `Role.ROLE_USER`, `Role.agent` → `Role.ROLE_AGENT` + +**Before (v0.3):** +```python +from a2a.types import Message, Part, Role, TextPart +from uuid import uuid4 + +message = Message( + role=Role.user, + parts=[Part(TextPart(text="Hello"))], + message_id=uuid4().hex, + task_id=uuid4().hex, +) +``` + +**After (v1.0):** + +Using [A2A helper utilities](#helper-utilities) + +```python +from a2a.helpers import new_text_message +from a2a.types import Role + +# Use the helper function to create `Hello` message +message = new_text_message(text="Hello", role=Role.ROLE_USER) + +``` + +Without helper utils, you can still construct directly + +```python +from a2a.types import Message, Part, Role +from uuid import uuid4 + +message = Message( + role=Role.ROLE_USER, + parts=[Part(text="Hello")], + message_id=uuid4().hex, + task_id=uuid4().hex, +) +``` + +> **Example**: [`helloworld/test_client.py` in PR #474](https://github.com/a2aproject/a2a-samples/pull/474/files#diff-f62c07d3b00364a3100b7effb3e2a1cca0624277d3e40da1bdb07bb46b6a8cef) + +### AgentCard Structure + +The new `AgentCard` can supports multiple transport bindings using `AgentInterface` class. 
+ +Key differences: +- `url` is gone; use `supported_interfaces` with one or more `AgentInterface` entries +- `AgentCapabilities.input_modes` and `AgentCapabilities.output_modes` are removed; use `AgentCard.default_input_modes` / `AgentCard.default_output_modes` for card-level defaults, or `AgentSkill.input_modes` / `AgentSkill.output_modes` for per-skill overrides +- `supports_authenticated_extended_card` is no longer a top-level `AgentCard` field; it has moved into `AgentCapabilities` and is renamed to `extended_agent_card` +- `AgentInterface.protocol_binding` accepted values: `'JSONRPC'`, `'HTTP+JSON'`, `'GRPC'` +- `examples` field was removed; set it per `AgentSkill` instead + +**Before (v0.3):** +```python +from a2a.types import AgentCard, AgentCapabilities, AgentSkill + +agent_card = AgentCard( + name='My Agent', + description='...', + url='http://localhost:9999/', + version='1.0.0', + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + supports_authenticated_extended_card=True, + capabilities=AgentCapabilities( + input_modes=['text/plain'], + output_modes=['text/plain'], + streaming=True, + ), + skills=[skill], + examples=['example'], +) +``` + +**After (v1.0):** +```python +from a2a.types import AgentCard, AgentCapabilities, AgentInterface, AgentSkill, + +agent_card = AgentCard( + name='My Agent', + description='...', + supported_interfaces=[ + # JSON-RPC + AgentInterface( + protocol_binding='JSONRPC', + url='http://localhost:41241/a2a/jsonrpc/', + ), + # GRPC + AgentInterface( + protocol_binding='GRPC', + url='http://localhost:50051/a2a/grpc/', + ) + ], + version='1.0.0', + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + capabilities=AgentCapabilities( + streaming=True, + extended_agent_card=True, + ), + skills=[skill], +) +``` + +> **Example**: [`a2a-mcp-without-framework/server/__main__.py` in PR 
#509](https://github.com/a2aproject/a2a-samples/pull/509/files#diff-d15d39ae64c3d4e3a36cc6fb442302caf4e32a6dbd858792e7a4bed180a625ac) + +--- + +## 3. Server: DefaultRequestHandler + +### Constructor signature: `agent_card` is now required + +`DefaultRequestHandler` now requires `agent_card` as a constructor argument (it was previously passed to the application wrapper). + +**Before (v0.3):** +```python +request_handler = DefaultRequestHandler( + agent_executor=MyAgentExecutor(), + task_store=InMemoryTaskStore(), +) +``` + +**After (v1.0):** +```python +request_handler = DefaultRequestHandler( + agent_executor=MyAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=agent_card, +) +``` + +> **Example**: [`a2a-mcp-without-framework/server/__main__.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/files#diff-d15d39ae64c3d4e3a36cc6fb442302caf4e32a6dbd858792e7a4bed180a625ac) + +--- + +## 4. Server: Application Setup + +The wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication` and `A2ARESTFastApiApplication`) are now removed. The Server setup now uses Starlette route factory functions directly, giving you full control over the routing. 
+ +**Before (v0.3):** +```python +from a2a.server.apps import A2AStarletteApplication +import uvicorn + +server = A2AStarletteApplication( + agent_card=agent_card, + http_handler=request_handler, +) +uvicorn.run(server.build(), host=host, port=port) +``` + +**After (v1.0):** +```python +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes +from starlette.applications import Starlette +import uvicorn + +routes = [] +routes.extend(create_agent_card_routes(agent_card)) +routes.extend(create_jsonrpc_routes(request_handler, rpc_url='/')) + +app = Starlette(routes=routes) +uvicorn.run(app, host=host, port=port) +``` + +If you need REST transport in addition to JSON-RPC: +```python +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes, create_rest_routes +from starlette.applications import Starlette +import uvicorn + +routes = [] +routes.extend(create_agent_card_routes(agent_card)) +routes.extend(create_jsonrpc_routes(request_handler, rpc_url='/')) +routes.extend(create_rest_routes(request_handler)) + +app = Starlette(routes=routes) +uvicorn.run(app, host=host, port=port) +``` + +> **Example**: [`a2a-mcp-without-framework/server/__main__.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/files#diff-d15d39ae64c3d4e3a36cc6fb442302caf4e32a6dbd858792e7a4bed180a625ac) + +--- + +## 5. Supporting v0.3 Clients + +If you cannot update all clients at once, you can run a v1.0 server that simultaneously accepts v0.3 connections. Two changes are needed. + +**1. Add the v0.3 AgentInterface to `supported_interfaces` in your `AgentCard`**: + +```python +supported_interfaces=[ + AgentInterface(protocol_binding='JSONRPC', protocol_version='0.3', url='http://localhost:9999/'), +] +``` + +**2. 
Enable the compat flag** on the relevant route factory: + +```python +create_jsonrpc_routes(request_handler, rpc_url='/', enable_v0_3_compat=True) +create_rest_routes(request_handler, enable_v0_3_compat=True) +``` + +> For a full working example see [`samples/hello_world_agent.py`](../../../samples/hello_world_agent.py). For known limitations see [issue #742](https://github.com/a2aproject/a2a-python/issues/742). + +--- + +## 6. Client: Creating a Client + +New `create_client()` `ClientFactory` function that creates a client for the agent. + +> **Note**: The legacy `A2AClient` class has been removed. + +**Before (v0.3):** +```python +from a2a.client import ClientFactory + +# From URL +factory = ClientFactory() +client = factory.create_client('http://localhost:9999/') + +# From an already-resolved AgentCard +factory = ClientFactory() +client = factory.create_client(agent_card) +``` + +**After (v1.0):** +```python +from a2a.client import create_client + +# From URL — resolves the agent card automatically +client = await create_client('http://localhost:9999/') + +# From an already-resolved AgentCard +client = await create_client(agent_card) +``` + + +> **Example**: [`a2a-mcp-without-framework/client/agent.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/files#diff-56cfce97ff9686166e4b14790ffb7ed46f4c14519261ce5c18365a53cf05e9aa) (`create_client()` usage) + +--- + +## 7. Client: Send Message + +The `BaseClient.send_message()` return type is standardised from `AsyncIterator[ClientEvent | Message]` to `AsyncIterator[StreamResponse]`. + +Each `StreamResponse` yields exactly one of: `task`, `message`, `status_update`, or `artifact_update`. Use `HasField()` to check which field is set. + + +**Before (v0.3):** +```python +async for event, message in client.send_message(request): + if isinstance(event, Task): + ... + if isinstance(event, UpdateEvent): + ... + if message: + ... 
+``` + +**After (v1.0):** +```python +async for chunk in client.send_message(request): + if chunk.HasField('artifact_update'): + ... + elif chunk.HasField('status_update'): + ... + elif chunk.HasField('task'): + ... + elif chunk.HasField('message'): + ... +``` + + +--- + +## 8. Client: Push Notifications Config + +`ClientConfig.push_notification_config` is now **singular** (a single `TaskPushNotificationConfig` or `None`), not a list. + +**Before (v0.3):** +```python +config = ClientConfig( + push_notification_configs=[my_push_config], +) +``` + +**After (v1.0):** +```python +config = ClientConfig( + push_notification_config=my_push_config, +) +``` + +--- + +## 9. Helper Utilities + +A new `a2a.helpers` module consolidates helper functions into a single import. Most were previously available under `a2a.utils.*`; a few are new in v1.0. + +```python +from a2a.helpers import ( + display_agent_card, # print a human-readable summary of an AgentCard to stdout + get_artifact_text, # join all text parts of an Artifact into a single string (delimiter='\n') + get_message_text, # join all text parts of a Message into a single string (delimiter='\n') + get_stream_response_text, # extract text from a StreamResponse proto message + get_text_parts, # return a list of raw text strings from a sequence of Parts (skips non-text parts) + new_artifact, # create an Artifact from a list of Parts, name, optional description and artifact_id + new_message, # create a Message from a list of Parts with role (default ROLE_AGENT), optional task_id/context_id + new_task, # create a Task with explicit task_id, context_id, and state + new_task_from_user_message, # create a TASK_STATE_SUBMITTED Task from a user Message; raises if role != ROLE_USER or parts are empty + new_text_artifact, # create an Artifact with a single text Part, name, optional description and artifact_id + new_text_artifact_update_event, # create a TaskArtifactUpdateEvent with a text artifact + new_text_message, # create a 
Message with a single text Part; role defaults to ROLE_AGENT + new_text_status_update_event, # create a TaskStatusUpdateEvent with a text message +) +``` + +--- + +## 10. Summary of Key Changes in v1.0 + +- **Standardisation to `SCREAMING_SNAKE_CASE`** — All enum values have been renamed from `kebab-case` strings to `SCREAMING_SNAKE_CASE` for compliance with the ProtoJSON specification. +- **`AgentCard`** — Significantly restructured to support multiple transport interfaces. + - **`AgentInterface`** — The top-level `url` field is replaced by `supported_interfaces`, a list of `AgentInterface` objects. Each entry describes a single transport endpoint carrying `protocol_binding`, `protocol_version`, and `url`. + - **Input and output modes** — `AgentCapabilities.input_modes` and `AgentCapabilities.output_modes` are removed and now live directly on `AgentCard` as `default_input_modes` and `default_output_modes`. Individual skills can override these with their own `input_modes` and `output_modes`. +- **Application setup** — The wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication` and `A2ARESTFastApiApplication`) are now removed. Server setup now uses route factory functions `create_jsonrpc_routes()`, `create_rest_routes()`, `create_agent_card_routes()` composed directly into a Starlette or FastAPI app. +- **Helper utilities** — A new `a2a.helpers` module consolidates all helper functions under a single import, replacing the scattered `a2a.utils.*` modules and adding new helpers for constructing and reading v1.0 proto types. + +--- + +## 11. 
Get Started + +The fastest way to see v1.0 in action is to run the samples: + +| File | Role | Description | +|---|---|---| +| [`samples/hello_world_agent.py`](../../../samples/hello_world_agent.py) | **Server** | A2A agent exposing JSON-RPC, REST, and gRPC — with v0.3 compat enabled | +| [`samples/cli.py`](../../../samples/cli.py) | **Client** | Interactive terminal client; supports all three transports | + +```bash +# In one terminal — start the agent: +uv run python samples/hello_world_agent.py + +# In another — connect with the CLI: +uv run python samples/cli.py +``` + +Then type a message like `hello` and press Enter. See [`samples/README.md`](../../../samples/README.md) for full details. + +For more examples see the [a2a-samples repository](https://github.com/a2aproject/a2a-samples/tree/main/samples/python). + From 864867a578bc27a7ab76a8078d0deb7e0ecce96d Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 20 Apr 2026 17:27:12 +0200 Subject: [PATCH 51/67] chore: update readme for 1.0 (#998) --- README.md | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index b7a60fe3b..37aed9798 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,10 @@ --- +> [!IMPORTANT] +> **Upgrading the SDK from `0.3` to `1.0`?** See the [**v0.3 → v1.0 migration guide**](docs/migrations/v1_0/README.md). For supported A2A spec versions, see [Compatibility](#-compatibility). + + ## ✨ Features - **A2A Protocol Compliant:** Build agentic applications that adhere to the Agent2Agent (A2A) Protocol. @@ -36,16 +40,16 @@ ## 🧩 Compatibility -This SDK implements the A2A Protocol Specification [`0.3`](https://a2a-protocol.org/v0.3.0/specification). 
- -> [!IMPORTANT] -> There is an [**alpha version**](https://github.com/a2aproject/a2a-python/releases?q=%22v1.0.0-alpha%22&expanded=true) available with support for both [`1.0`](https://a2a-protocol.org/v1.0.0/specification/) and [`0.3`](https://a2a-protocol.org/v0.3.0/specification) versions. Development for this version is taking place in the [`1.0-dev`](https://github.com/a2aproject/a2a-python/tree/1.0-dev) branch, tracked in [#701](https://github.com/a2aproject/a2a-python/issues/701). +This SDK implements the A2A Protocol Specification [`1.0`](https://a2a-protocol.org/v1.0.0/specification/), with compatibility mode for [`0.3`](https://a2a-protocol.org/v0.3.0/specification). See [#742](https://github.com/a2aproject/a2a-python/issues/742) for details on the compatibility scope. -| Transport | Client | Server | -| :--- | :---: | :---: | -| **JSON-RPC** | ✅ | ✅ | -| **HTTP+JSON/REST** | ✅ | ✅ | -| **GRPC** | ✅ | ✅ | +| Spec Version | Transport | Client | Server | +| :--- | :--- | :---: | :---: | +| **`1.0`** | JSON-RPC | ✅ | ✅ | +| **`1.0`** | HTTP+JSON/REST | ✅ | ✅ | +| **`1.0`** | gRPC | ✅ | ✅ | +| **`0.3`** (compat) | JSON-RPC | ✅ | ✅ | +| **`0.3`** (compat) | HTTP+JSON/REST | ✅ | ✅ | +| **`0.3`** (compat) | gRPC | ✅ | ✅ | --- From 24db37ee24c927df936289ad6ffbc8c746a44db8 Mon Sep 17 00:00:00 2001 From: "Agent2Agent (A2A) Bot" Date: Mon, 20 Apr 2026 10:33:17 -0500 Subject: [PATCH 52/67] chore(main): release 1.0.0 (#997) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit :robot: I have created a release *beep* *boop* --- ## [1.0.0](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.3...v1.0.0) (2026-04-20) See the [**v0.3 → v1.0 migration guide**](docs/migrations/v1_0/README.md) and changelog entries for alpha versions below. 
### ⚠ BREAKING CHANGES * remove Vertex AI Task Store integration ([#999](https://github.com/a2aproject/a2a-python/issues/999)) ### Bug Fixes * rely on agent executor implementation for stream termination ([#988](https://github.com/a2aproject/a2a-python/issues/988)) ([d77cd68](https://github.com/a2aproject/a2a-python/commit/d77cd68f5e69b0ffccaca5e3deab4c1a397cfe9c)) ### Documentation * add comprehensive v0.3 to v1.0 migration guide ([#987](https://github.com/a2aproject/a2a-python/issues/987)) ([10dea8b](https://github.com/a2aproject/a2a-python/commit/10dea8b4448c5cb7d9e72d74677fd60880cc38df)) ### Miscellaneous Chores * release 1.0.0 ([530ec37](https://github.com/a2aproject/a2a-python/commit/530ec37f4c4580095c2411e40740ca0186fd1240)) * remove Vertex AI Task Store integration ([#999](https://github.com/a2aproject/a2a-python/issues/999)) ([7fce2ad](https://github.com/a2aproject/a2a-python/commit/7fce2ada1eb331e230925993758e8c7663da9a13)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --------- Co-authored-by: Ivan Shymko --- CHANGELOG.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e3b43a3a..45ea7031d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ # Changelog +## [1.0.0](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.3...v1.0.0) (2026-04-20) + +See the [**v0.3 → v1.0 migration guide**](docs/migrations/v1_0/README.md) and changelog entries for alpha versions below. 
+ +### ⚠ BREAKING CHANGES + +* remove Vertex AI Task Store integration ([#999](https://github.com/a2aproject/a2a-python/issues/999)) + +### Bug Fixes + +* rely on agent executor implementation for stream termination ([#988](https://github.com/a2aproject/a2a-python/issues/988)) ([d77cd68](https://github.com/a2aproject/a2a-python/commit/d77cd68f5e69b0ffccaca5e3deab4c1a397cfe9c)) + + +### Documentation + +* add comprehensive v0.3 to v1.0 migration guide ([#987](https://github.com/a2aproject/a2a-python/issues/987)) ([10dea8b](https://github.com/a2aproject/a2a-python/commit/10dea8b4448c5cb7d9e72d74677fd60880cc38df)) + + +### Miscellaneous Chores + +* release 1.0.0 ([530ec37](https://github.com/a2aproject/a2a-python/commit/530ec37f4c4580095c2411e40740ca0186fd1240)) +* remove Vertex AI Task Store integration ([#999](https://github.com/a2aproject/a2a-python/issues/999)) ([7fce2ad](https://github.com/a2aproject/a2a-python/commit/7fce2ada1eb331e230925993758e8c7663da9a13)) + ## [1.0.0-alpha.3](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.2...v1.0.0-alpha.3) (2026-04-17) From 6b46ceb3e036290ea2b0764b1697f2901ad2df08 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Wed, 22 Apr 2026 09:36:28 +0200 Subject: [PATCH 53/67] fix(compat): avoid unconditional grpc import in v0.3 context builders (#1006) Currently it causes `no module named 'grpc'` when using HTTP machinery. Relevant "install and import" testing was extended to allow testing "extras" to catch this in the future. 
--- .github/actions/spelling/allow.txt | 11 ++- ...{minimal-install.yml => install-smoke.yml} | 27 +++-- ...nimal_install.py => test_install_smoke.py} | 55 +++++++++-- scripts/test_install_smoke.sh | 98 +++++++++++++++++++ src/a2a/compat/v0_3/context_builders.py | 15 +-- 5 files changed, 171 insertions(+), 35 deletions(-) rename .github/workflows/{minimal-install.yml => install-smoke.yml} (60%) rename scripts/{test_minimal_install.py => test_install_smoke.py} (51%) create mode 100755 scripts/test_install_smoke.sh diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 03774d1f0..73818db59 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -101,6 +101,8 @@ openapiv2 opensource otherurl pb2 +podman +Podman poolclass postgres POSTGRES @@ -128,6 +130,8 @@ socio sse starlette Starlette +subgids +subuids sut SUT swagger @@ -139,9 +143,6 @@ tiangolo TResponse typ typeerror -vulnz -Podman -podman UIDs -subuids -subgids +vulnz +whl diff --git a/.github/workflows/minimal-install.yml b/.github/workflows/install-smoke.yml similarity index 60% rename from .github/workflows/minimal-install.yml rename to .github/workflows/install-smoke.yml index 27afebe7e..0b9781b2c 100644 --- a/.github/workflows/minimal-install.yml +++ b/.github/workflows/install-smoke.yml @@ -1,5 +1,5 @@ --- -name: Minimal Install Smoke Test +name: Install Smoke Test on: push: branches: [main, 1.0-dev] @@ -30,13 +30,18 @@ permissions: contents: read jobs: - minimal-install: - name: Verify base-only install + install-smoke: + name: Verify ${{ matrix.profile.name }} install runs-on: ubuntu-latest if: github.repository == 'a2aproject/a2a-python' strategy: matrix: python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] + profile: + - name: base + extras: '' + - name: http-server + extras: '[http-server]' steps: - name: Checkout code uses: actions/checkout@v6 @@ -49,15 +54,17 @@ jobs: - name: Build package run: uv build --wheel - - name: 
Install with base dependencies only + - name: Install with ${{ matrix.profile.name }} dependencies only run: | - uv venv .venv-minimal - # Install only the built wheel -- no extras, no dev deps. - # This simulates what an end-user gets with `pip install a2a-sdk`. - VIRTUAL_ENV=.venv-minimal uv pip install dist/*.whl + uv venv .venv-smoke + # Install only the built wheel + the profile's extras -- no + # dev deps. This simulates what an end-user gets with + # `pip install a2a-sdk${{ matrix.profile.extras }}`. + WHEEL=$(ls dist/*.whl) + VIRTUAL_ENV=.venv-smoke uv pip install "${WHEEL}${{ matrix.profile.extras }}" - name: List installed packages - run: VIRTUAL_ENV=.venv-minimal uv pip list + run: VIRTUAL_ENV=.venv-smoke uv pip list - name: Run import smoke test - run: .venv-minimal/bin/python scripts/test_minimal_install.py + run: .venv-smoke/bin/python scripts/test_install_smoke.py ${{ matrix.profile.name }} diff --git a/scripts/test_minimal_install.py b/scripts/test_install_smoke.py similarity index 51% rename from scripts/test_minimal_install.py rename to scripts/test_install_smoke.py index 84e3ee3fc..df33c8386 100755 --- a/scripts/test_minimal_install.py +++ b/scripts/test_install_smoke.py @@ -1,18 +1,24 @@ #!/usr/bin/env python3 -"""Smoke test for minimal (base-only) installation of a2a-sdk. +"""Smoke test for installations of a2a-sdk with various extras. -This script verifies that all core public API modules can be imported -when only the base dependencies are installed (no optional extras). +This script verifies that the public API modules associated with a +given installation profile can be imported without pulling in modules +that belong to other (uninstalled) optional extras. It is designed to run WITHOUT pytest or any dev dependencies -- just -a clean venv with `pip install a2a-sdk`. +a clean venv with `pip install a2a-sdk[]`. 
Usage: - python scripts/test_minimal_install.py + python scripts/test_install_smoke.py [profile] + + profile defaults to "base" and selects which set of modules to + smoke-test. Available profiles: + base -- `pip install a2a-sdk` + http-server -- `pip install a2a-sdk[http-server]` Exit codes: - 0 - All core imports succeeded - 1 - One or more core imports failed + 0 - All imports for the profile succeeded + 1 - One or more imports failed """ from __future__ import annotations @@ -58,19 +64,48 @@ 'a2a.helpers.proto_helpers', ] +# Modules that MUST be importable with only the base + `http-server` +# extras installed (no `grpc`, `sql`, `signing`, `telemetry`, etc.). +# +# A user building a Starlette/FastAPI A2A server with +# `pip install a2a-sdk[http-server]` should be able to import these +# without the gRPC stack being present on the system. +HTTP_SERVER_MODULES = [ + 'a2a.server.routes', + 'a2a.server.routes.agent_card_routes', + 'a2a.server.routes.common', + 'a2a.server.routes.jsonrpc_dispatcher', + 'a2a.server.routes.jsonrpc_routes', + 'a2a.server.routes.rest_dispatcher', + 'a2a.server.routes.rest_routes', +] + + +PROFILES: dict[str, list[str]] = { + 'base': CORE_MODULES, + 'http-server': CORE_MODULES + HTTP_SERVER_MODULES, +} + def main() -> int: + profile = sys.argv[1] if len(sys.argv) > 1 else 'base' + if profile not in PROFILES: + print(f'Unknown profile {profile!r}. 
Available: {sorted(PROFILES)}') + return 1 + + modules = PROFILES[profile] failures: list[str] = [] successes: list[str] = [] - for module_name in CORE_MODULES: + for module_name in modules: try: importlib.import_module(module_name) successes.append(module_name) except Exception as e: # noqa: BLE001, PERF203 failures.append(f'{module_name}: {e}') - print(f'Tested {len(CORE_MODULES)} core modules') + print(f'Profile: {profile}') + print(f'Tested {len(modules)} modules') print(f' Passed: {len(successes)}') print(f' Failed: {len(failures)}') @@ -80,7 +115,7 @@ def main() -> int: print(f' - {failure}') return 1 - print('\nAll core modules imported successfully.') + print('\nAll modules imported successfully.') return 0 diff --git a/scripts/test_install_smoke.sh b/scripts/test_install_smoke.sh new file mode 100755 index 000000000..863f9c12c --- /dev/null +++ b/scripts/test_install_smoke.sh @@ -0,0 +1,98 @@ +#!/bin/bash +# Local equivalent of .github/workflows/install-smoke.yml. +# +# For each install profile, builds the wheel and installs it into a +# clean venv (no dev deps), then runs the import smoke test for that +# profile. By default runs every known profile; pass a profile name +# to run just one. +# +# Available profiles (must match those in scripts/test_install_smoke.py): +# base -- `pip install a2a-sdk` +# http-server -- `pip install a2a-sdk[http-server]` +# +# Usage: +# scripts/test_install_smoke.sh [profile] [python-version] +# +# Examples: +# scripts/test_install_smoke.sh # all profiles, default python +# scripts/test_install_smoke.sh '' 3.13 # all profiles on python 3.13 +# scripts/test_install_smoke.sh http-server # http-server only +# scripts/test_install_smoke.sh http-server 3.13 # http-server on python 3.13 +set -e +set -o pipefail + +REPO_ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" +cd "$REPO_ROOT" + +ALL_PROFILES=(base http-server) + +PROFILE_ARG="${1:-}" +PYTHON_VERSION="${2:-}" + +if [ -z "$PROFILE_ARG" ]; then + PROFILES=("${ALL_PROFILES[@]}") +else + PROFILES=("$PROFILE_ARG") +fi + +extras_for_profile() { + case "$1" in + base) echo "" ;; + http-server) echo "[http-server]" ;; + *) + echo "Unknown profile '$1'. Available: ${ALL_PROFILES[*]}" >&2 + return 1 + ;; + esac +} + +# Validate profiles up-front so we fail fast. +for profile in "${PROFILES[@]}"; do + extras_for_profile "$profile" >/dev/null +done + +echo "--- Building wheel ---" +rm -rf dist +uv build --wheel +WHEEL=$(ls dist/*.whl) + +FAILED_PROFILES=() + +for profile in "${PROFILES[@]}"; do + extras=$(extras_for_profile "$profile") + venv_dir=".venv-smoke-${profile}" + + echo + echo "==================================================================" + echo " Profile: $profile (extras='$extras')" + echo "==================================================================" + + echo "--- Creating clean venv at $venv_dir ---" + rm -rf "$venv_dir" + if [ -n "$PYTHON_VERSION" ]; then + uv venv "$venv_dir" --python "$PYTHON_VERSION" + else + uv venv "$venv_dir" + fi + + echo "--- Installing built wheel with '$profile' dependencies only ---" + VIRTUAL_ENV="$venv_dir" uv pip install "${WHEEL}${extras}" + + echo "--- Installed packages ---" + VIRTUAL_ENV="$venv_dir" uv pip list + + echo "--- Running import smoke test ---" + if ! 
"$venv_dir/bin/python" scripts/test_install_smoke.py "$profile"; then + FAILED_PROFILES+=("$profile") + fi +done + +echo +echo "==================================================================" +if [ ${#FAILED_PROFILES[@]} -eq 0 ]; then + echo " All profiles passed: ${PROFILES[*]}" + exit 0 +fi + +echo " Failed profiles: ${FAILED_PROFILES[*]}" >&2 +exit 1 diff --git a/src/a2a/compat/v0_3/context_builders.py b/src/a2a/compat/v0_3/context_builders.py index 2f2eec362..1874853f5 100644 --- a/src/a2a/compat/v0_3/context_builders.py +++ b/src/a2a/compat/v0_3/context_builders.py @@ -5,9 +5,7 @@ adapters wrap the default builders with these classes to recognize both names. """ -from typing import TYPE_CHECKING, Any - -import grpc +from typing import TYPE_CHECKING from a2a.compat.v0_3.extension_headers import LEGACY_HTTP_EXTENSION_HEADER from a2a.extensions.common import get_requested_extensions @@ -15,21 +13,18 @@ if TYPE_CHECKING: + import grpc + from starlette.requests import Request from a2a.server.request_handlers.grpc_handler import ( GrpcServerCallContextBuilder, ) from a2a.server.routes.common import ServerCallContextBuilder -else: - try: - from starlette.requests import Request - except ImportError: - Request = Any def _get_legacy_grpc_extensions( - context: grpc.aio.ServicerContext, + context: 'grpc.aio.ServicerContext', ) -> list[str]: md = context.invocation_metadata() if md is None: @@ -71,7 +66,7 @@ class V03GrpcServerCallContextBuilder: def __init__(self, inner: 'GrpcServerCallContextBuilder') -> None: self._inner = inner - def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: + def build(self, context: 'grpc.aio.ServicerContext') -> ServerCallContext: """Builds a ServerCallContext, merging legacy extension metadata.""" server_context = self._inner.build(context) server_context.requested_extensions |= get_requested_extensions( From 04ce377f8198e2383fa3ec9e0f36722f25e2bddc Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Wed, 22 Apr 2026 
09:54:26 +0200 Subject: [PATCH 54/67] test: add more test cases to "install-smoke" (#1008) --- .github/workflows/install-smoke.yml | 6 ++++++ scripts/test_install_smoke.py | 29 +++++++++++++++++++++++++++++ scripts/test_install_smoke.sh | 8 +++++++- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/.github/workflows/install-smoke.yml b/.github/workflows/install-smoke.yml index 0b9781b2c..5c8a97d4a 100644 --- a/.github/workflows/install-smoke.yml +++ b/.github/workflows/install-smoke.yml @@ -42,6 +42,12 @@ jobs: extras: '' - name: http-server extras: '[http-server]' + - name: grpc + extras: '[grpc]' + - name: telemetry + extras: '[telemetry]' + - name: sql + extras: '[sql]' steps: - name: Checkout code uses: actions/checkout@v6 diff --git a/scripts/test_install_smoke.py b/scripts/test_install_smoke.py index df33c8386..41ad029bb 100755 --- a/scripts/test_install_smoke.py +++ b/scripts/test_install_smoke.py @@ -15,6 +15,9 @@ smoke-test. Available profiles: base -- `pip install a2a-sdk` http-server -- `pip install a2a-sdk[http-server]` + grpc -- `pip install a2a-sdk[grpc]` + telemetry -- `pip install a2a-sdk[telemetry]` + sql -- `pip install a2a-sdk[sql]` Exit codes: 0 - All imports for the profile succeeded @@ -80,10 +83,36 @@ 'a2a.server.routes.rest_routes', ] +# Modules that MUST be importable with only the base + `grpc` extras +# installed (no `http-server`, `sql`, `signing`, `telemetry`, etc.). +GRPC_MODULES = [ + 'a2a.server.request_handlers.grpc_handler', + 'a2a.client.transports.grpc', + 'a2a.compat.v0_3.grpc_handler', + 'a2a.compat.v0_3.grpc_transport', +] + +# Modules that MUST be importable with only the base + `telemetry` +# extras installed. +TELEMETRY_MODULES = [ + 'a2a.utils.telemetry', +] + +# Modules that MUST be importable with only the base + `sql` extras +# installed (covers postgresql/mysql/sqlite drivers via SQLAlchemy). 
+SQL_MODULES = [ + 'a2a.server.models', + 'a2a.server.tasks.database_task_store', + 'a2a.server.tasks.database_push_notification_config_store', +] + PROFILES: dict[str, list[str]] = { 'base': CORE_MODULES, 'http-server': CORE_MODULES + HTTP_SERVER_MODULES, + 'grpc': CORE_MODULES + GRPC_MODULES, + 'telemetry': CORE_MODULES + TELEMETRY_MODULES, + 'sql': CORE_MODULES + SQL_MODULES, } diff --git a/scripts/test_install_smoke.sh b/scripts/test_install_smoke.sh index 863f9c12c..9f0a45fbd 100755 --- a/scripts/test_install_smoke.sh +++ b/scripts/test_install_smoke.sh @@ -9,6 +9,9 @@ # Available profiles (must match those in scripts/test_install_smoke.py): # base -- `pip install a2a-sdk` # http-server -- `pip install a2a-sdk[http-server]` +# grpc -- `pip install a2a-sdk[grpc]` +# telemetry -- `pip install a2a-sdk[telemetry]` +# sql -- `pip install a2a-sdk[sql]` # # Usage: # scripts/test_install_smoke.sh [profile] [python-version] @@ -24,7 +27,7 @@ set -o pipefail REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)" cd "$REPO_ROOT" -ALL_PROFILES=(base http-server) +ALL_PROFILES=(base http-server grpc telemetry sql) PROFILE_ARG="${1:-}" PYTHON_VERSION="${2:-}" @@ -39,6 +42,9 @@ extras_for_profile() { case "$1" in base) echo "" ;; http-server) echo "[http-server]" ;; + grpc) echo "[grpc]" ;; + telemetry) echo "[telemetry]" ;; + sql) echo "[sql]" ;; *) echo "Unknown profile '$1'. 
Available: ${ALL_PROFILES[*]}" >&2 return 1 From 69273a36b831b852ff61227b015959d4f522c155 Mon Sep 17 00:00:00 2001 From: "Agent2Agent (A2A) Bot" Date: Wed, 22 Apr 2026 02:57:09 -0500 Subject: [PATCH 55/67] chore(main): release 1.0.1 (#1007) :robot: I have created a release *beep* *boop* --- ## [1.0.1](https://github.com/a2aproject/a2a-python/compare/v1.0.0...v1.0.1) (2026-04-22) ### Bug Fixes * **compat:** avoid unconditional grpc import in v0.3 context builders ([#1006](https://github.com/a2aproject/a2a-python/issues/1006)) ([6b46ceb](https://github.com/a2aproject/a2a-python/commit/6b46ceb3e036290ea2b0764b1697f2901ad2df08)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). Co-authored-by: Ivan Shymko --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45ea7031d..f88f9403a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [1.0.1](https://github.com/a2aproject/a2a-python/compare/v1.0.0...v1.0.1) (2026-04-22) + + +### Bug Fixes + +* **compat:** avoid unconditional grpc import in v0.3 context builders ([#1006](https://github.com/a2aproject/a2a-python/issues/1006)) ([6b46ceb](https://github.com/a2aproject/a2a-python/commit/6b46ceb3e036290ea2b0764b1697f2901ad2df08)) + ## [1.0.0](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.3...v1.0.0) (2026-04-20) See the [**v0.3 → v1.0 migration guide**](docs/migrations/v1_0/README.md) and changelog entries for alpha versions below. From 4a247ed59f30cae9955a0caebf05f3a41c57a43c Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Wed, 22 Apr 2026 10:19:43 +0200 Subject: [PATCH 56/67] ci: cleanup 1.0-dev triggers and remove branch filters (#1002) 1. Cleanup 1.0-dev filters as the branch is removed already. 2. Remove any branch filters for workflows as a safety net if something got skipped on PR by accident. 3. 
Use "include" logic to make it more transparent what is the scope of each check (everything will run on `main` as per the previous item just in case the filter is wrong). --- .github/workflows/install-smoke.yml | 32 ++++++++-------------------- .github/workflows/itk.yaml | 5 ++++- .github/workflows/linter.yaml | 21 ++++++++---------- .github/workflows/run-tck.yaml | 29 ++++++------------------- .github/workflows/unit-tests.yml | 33 +++++++++-------------------- 5 files changed, 39 insertions(+), 81 deletions(-) diff --git a/.github/workflows/install-smoke.yml b/.github/workflows/install-smoke.yml index 5c8a97d4a..ace3ff072 100644 --- a/.github/workflows/install-smoke.yml +++ b/.github/workflows/install-smoke.yml @@ -2,30 +2,16 @@ name: Install Smoke Test on: push: - branches: [main, 1.0-dev] - paths-ignore: - - '**.md' - - 'LICENSE' - - 'docs/**' - - '.github/CODEOWNERS' - - '.github/ISSUE_TEMPLATE/**' - - '.github/PULL_REQUEST_TEMPLATE.md' - - '.github/dependabot.yml' - - '.gitignore' - - '.git-blame-ignore-revs' - - '.gemini/**' + branches: [main] pull_request: - paths-ignore: - - '**.md' - - 'LICENSE' - - 'docs/**' - - '.github/CODEOWNERS' - - '.github/ISSUE_TEMPLATE/**' - - '.github/PULL_REQUEST_TEMPLATE.md' - - '.github/dependabot.yml' - - '.gitignore' - - '.git-blame-ignore-revs' - - '.gemini/**' + paths: + - 'src/**' + - 'pyproject.toml' + - 'uv.lock' + - 'scripts/test_install_smoke.py' + - 'scripts/test_install_smoke.sh' + # Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. 
+ - '.github/workflows/install-smoke.yml' permissions: contents: read diff --git a/.github/workflows/itk.yaml b/.github/workflows/itk.yaml index ab272d0e3..feb9325e3 100644 --- a/.github/workflows/itk.yaml +++ b/.github/workflows/itk.yaml @@ -2,12 +2,15 @@ name: ITK on: push: - branches: [main, 1.0-dev] + branches: [main] pull_request: paths: - 'src/**' - 'itk/**' - 'pyproject.toml' + - 'uv.lock' + # Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. + - '.github/workflows/itk.yaml' permissions: contents: read diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index ec4bd16fb..2c2a035a0 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -2,18 +2,15 @@ name: Lint Code Base on: pull_request: - branches: [main, 1.0-dev] - paths-ignore: - - '**.md' - - 'LICENSE' - - 'docs/**' - - '.github/CODEOWNERS' - - '.github/ISSUE_TEMPLATE/**' - - '.github/PULL_REQUEST_TEMPLATE.md' - - '.github/dependabot.yml' - - '.gitignore' - - '.git-blame-ignore-revs' - - '.gemini/**' + branches: [main] + paths: + - '**.py' + - '**.pyi' + - 'pyproject.toml' + - 'uv.lock' + - '.jscpd.json' + # Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. 
+ - '.github/workflows/linter.yaml' permissions: contents: read jobs: diff --git a/.github/workflows/run-tck.yaml b/.github/workflows/run-tck.yaml index 62bbeebc0..53d55d4b0 100644 --- a/.github/workflows/run-tck.yaml +++ b/.github/workflows/run-tck.yaml @@ -3,30 +3,15 @@ name: Run TCK on: push: branches: [ "main" ] - paths-ignore: - - '**.md' - - 'LICENSE' - - 'docs/**' - - '.github/CODEOWNERS' - - '.github/ISSUE_TEMPLATE/**' - - '.github/PULL_REQUEST_TEMPLATE.md' - - '.github/dependabot.yml' - - '.gitignore' - - '.git-blame-ignore-revs' - - '.gemini/**' pull_request: branches: [ "main" ] - paths-ignore: - - '**.md' - - 'LICENSE' - - 'docs/**' - - '.github/CODEOWNERS' - - '.github/ISSUE_TEMPLATE/**' - - '.github/PULL_REQUEST_TEMPLATE.md' - - '.github/dependabot.yml' - - '.gitignore' - - '.git-blame-ignore-revs' - - '.gemini/**' + paths: + - 'src/**' + - 'tck/**' + - 'pyproject.toml' + - 'uv.lock' + # Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. + - '.github/workflows/run-tck.yaml' permissions: contents: read diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 098a14ecc..51f8bbc53 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -2,30 +2,17 @@ name: Run Unit Tests on: push: - branches: [main, 1.0-dev] - paths-ignore: - - '**.md' - - 'LICENSE' - - 'docs/**' - - '.github/CODEOWNERS' - - '.github/ISSUE_TEMPLATE/**' - - '.github/PULL_REQUEST_TEMPLATE.md' - - '.github/dependabot.yml' - - '.gitignore' - - '.git-blame-ignore-revs' - - '.gemini/**' + branches: [main] pull_request: - paths-ignore: - - '**.md' - - 'LICENSE' - - 'docs/**' - - '.github/CODEOWNERS' - - '.github/ISSUE_TEMPLATE/**' - - '.github/PULL_REQUEST_TEMPLATE.md' - - '.github/dependabot.yml' - - '.gitignore' - - '.git-blame-ignore-revs' - - '.gemini/**' + paths: + - 'src/**' + - 'tests/**' + - 'pyproject.toml' + - 'uv.lock' + - 'scripts/run_db_tests.sh' + - 'scripts/docker-compose.test.yml' + # 
Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. + - '.github/workflows/unit-tests.yml' permissions: contents: read From d2a9887b3c5bacb6ee2cdacf468181904611b408 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 22 Apr 2026 11:42:17 +0200 Subject: [PATCH 57/67] chore(deps): bump requests from 2.32.5 to 2.33.0 (#994) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [requests](https://github.com/psf/requests) from 2.32.5 to 2.33.0.
Release notes

Sourced from requests's releases.

v2.33.0

2.33.0 (2026-03-25)

Announcements

  • 📣 Requests is adding inline types. If you have a typed code base that uses Requests, please take a look at #7271. Give it a try, and report any gaps or feedback you may have in the issue. 📣

Security

  • CVE-2026-25645 requests.utils.extract_zipped_paths now extracts contents to a non-deterministic location to prevent malicious file replacement. This does not affect default usage of Requests, only applications calling the utility function directly.

Improvements

  • Migrated to a PEP 517 build system using setuptools. (#7012)

Bugfixes

  • Fixed an issue where an empty netrc entry could cause malformed authentication to be applied to Requests on Python 3.11+. (#7205)

Deprecations

  • Dropped support for Python 3.9 following its end of support. (#7196)

Documentation

  • Various typo fixes and doc improvements.

New Contributors

Full Changelog: https://github.com/psf/requests/blob/main/HISTORY.md#2330-2026-03-25

Changelog

Sourced from requests's changelog.

2.33.0 (2026-03-25)

Announcements

  • 📣 Requests is adding inline types. If you have a typed code base that uses Requests, please take a look at #7271. Give it a try, and report any gaps or feedback you may have in the issue. 📣

Security

  • CVE-2026-25645 requests.utils.extract_zipped_paths now extracts contents to a non-deterministic location to prevent malicious file replacement. This does not affect default usage of Requests, only applications calling the utility function directly.

Improvements

  • Migrated to a PEP 517 build system using setuptools. (#7012)

Bugfixes

  • Fixed an issue where an empty netrc entry could cause malformed authentication to be applied to Requests on Python 3.11+. (#7205)

Deprecations

  • Dropped support for Python 3.9 following its end of support. (#7196)

Documentation

  • Various typo fixes and doc improvements.
Commits
  • bc04dfd v2.33.0
  • 66d21cb Merge commit from fork
  • 8b9bc8f Move badges to top of README (#7293)
  • e331a28 Remove unused extraction call (#7292)
  • 753fd08 docs: fix FAQ grammar in httplib2 example
  • 774a0b8 docs(socks): same block as other sections
  • 9c72a41 Bump github/codeql-action from 4.33.0 to 4.34.1
  • ebf7190 Bump github/codeql-action from 4.32.0 to 4.33.0
  • 0e4ae38 docs: exclude Response.is_permanent_redirect from API docs (#7244)
  • d568f47 docs: clarify Quickstart POST example (#6960)
  • Additional commits viewable in compare view

Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Ivan Shymko --- uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uv.lock b/uv.lock index 0a1a7e13e..499c75415 100644 --- a/uv.lock +++ b/uv.lock @@ -1936,7 +1936,7 @@ wheels = [ [[package]] name = "requests" -version = "2.32.5" +version = "2.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -1944,9 +1944,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" }, ] [[package]] From 237621d07360e8964c417bad87496b17afd3610e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 22 Apr 2026 09:42:53 +0000 Subject: [PATCH 58/67] chore(deps-dev): 
bump pytest from 9.0.2 to 9.0.3 (#966) Bumps [pytest](https://github.com/pytest-dev/pytest) from 9.0.2 to 9.0.3.
Release notes

Sourced from pytest's releases.

9.0.3

pytest 9.0.3 (2026-04-07)

Bug fixes

  • #12444: Fixed pytest.approx which now correctly takes into account ~collections.abc.Mapping keys order to compare them.

  • #13634: Blocking a conftest.py file using the -p no: option is now explicitly disallowed.

    Previously this resulted in an internal assertion failure during plugin loading.

    Pytest now raises a clear UsageError explaining that conftest files are not plugins and cannot be disabled via -p.

  • #13734: Fixed crash when a test raises an exceptiongroup with __tracebackhide__ = True.

  • #14195: Fixed an issue where non-string messages passed to unittest.TestCase.subTest() were not printed.

  • #14343: Fixed use of insecure temporary directory (CVE-2025-71176).

Improved documentation

  • #13388: Clarified documentation for -p vs PYTEST_PLUGINS plugin loading and fixed an incorrect -p example.
  • #13731: Clarified that capture fixtures (e.g. capsys and capfd) take precedence over the -s / --capture=no command-line options in Accessing captured output from a test function <accessing-captured-output>.
  • #14088: Clarified that the default pytest_collection hook sets session.items before it calls pytest_collection_finish, not after.
  • #14255: TOML integer log levels must be quoted: Updating reference documentation.

Contributor-facing changes

  • #12689: The test reports are now published to Codecov from GitHub Actions. The test statistics is visible on the web interface.

    -- by aleguy02

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=pytest&package-manager=uv&previous-version=9.0.2&new-version=9.0.3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/a2aproject/a2a-python/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Ivan Shymko --- uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uv.lock b/uv.lock index 499c75415..eed54ad4e 100644 --- a/uv.lock +++ b/uv.lock @@ -1776,7 +1776,7 @@ wheels = [ [[package]] name = "pytest" -version = "9.0.2" +version = "9.0.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, @@ -1787,9 +1787,9 @@ dependencies = [ { name = "pygments" }, { name = "tomli", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, + { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, ] [[package]] From e3e089b4241ce68135728871896468271f847a1a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 22 
Apr 2026 09:44:58 +0000 Subject: [PATCH 59/67] chore(deps): bump cryptography from 46.0.5 to 46.0.7 (#993) Bumps [cryptography](https://github.com/pyca/cryptography) from 46.0.5 to 46.0.7.
Changelog

Sourced from cryptography's changelog.

46.0.7 - 2026-04-07


* **SECURITY ISSUE**: Fixed an issue where non-contiguous buffers could be
  passed to APIs that accept Python buffers, which could lead to buffer
  overflow. **CVE-2026-39892**
* Updated Windows, macOS, and Linux wheels to be compiled with OpenSSL 3.5.6.

.. _v46-0-6:

46.0.6 - 2026-03-25

  • SECURITY ISSUE: Fixed a bug where name constraints were not applied to peer names during verification when the leaf certificate contains a wildcard DNS SAN. Ordinary X.509 topologies are not affected by this bug, including those used by the Web PKI. Credit to Oleh Konko (1seal) for reporting the issue. CVE-2026-34073

.. _v46-0-5:

Commits

Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Ivan Shymko --- uv.lock | 102 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/uv.lock b/uv.lock index eed54ad4e..1ad6a470d 100644 --- a/uv.lock +++ b/uv.lock @@ -679,62 +679,62 @@ toml = [ [[package]] name = "cryptography" -version = "46.0.5" +version = "46.0.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/04/ee2a9e8542e4fa2773b81771ff8349ff19cdd56b7258a0cc442639052edb/cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d", size = 750064, upload-time = "2026-02-10T19:18:38.255Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/81/b0bb27f2ba931a65409c6b8a8b358a7f03c0e46eceacddff55f7c84b1f3b/cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad", size = 7176289, upload-time = "2026-02-10T19:17:08.274Z" }, - { url = "https://files.pythonhosted.org/packages/ff/9e/6b4397a3e3d15123de3b1806ef342522393d50736c13b20ec4c9ea6693a6/cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b", size = 4275637, upload-time = "2026-02-10T19:17:10.53Z" }, - { url = "https://files.pythonhosted.org/packages/63/e7/471ab61099a3920b0c77852ea3f0ea611c9702f651600397ac567848b897/cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b", size = 4424742, upload-time = "2026-02-10T19:17:12.388Z" }, - { url = 
"https://files.pythonhosted.org/packages/37/53/a18500f270342d66bf7e4d9f091114e31e5ee9e7375a5aba2e85a91e0044/cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263", size = 4277528, upload-time = "2026-02-10T19:17:13.853Z" }, - { url = "https://files.pythonhosted.org/packages/22/29/c2e812ebc38c57b40e7c583895e73c8c5adb4d1e4a0cc4c5a4fdab2b1acc/cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d", size = 4947993, upload-time = "2026-02-10T19:17:15.618Z" }, - { url = "https://files.pythonhosted.org/packages/6b/e7/237155ae19a9023de7e30ec64e5d99a9431a567407ac21170a046d22a5a3/cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed", size = 4456855, upload-time = "2026-02-10T19:17:17.221Z" }, - { url = "https://files.pythonhosted.org/packages/2d/87/fc628a7ad85b81206738abbd213b07702bcbdada1dd43f72236ef3cffbb5/cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2", size = 3984635, upload-time = "2026-02-10T19:17:18.792Z" }, - { url = "https://files.pythonhosted.org/packages/84/29/65b55622bde135aedf4565dc509d99b560ee4095e56989e815f8fd2aa910/cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2", size = 4277038, upload-time = "2026-02-10T19:17:20.256Z" }, - { url = "https://files.pythonhosted.org/packages/bc/36/45e76c68d7311432741faf1fbf7fac8a196a0a735ca21f504c75d37e2558/cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0", size = 4912181, upload-time = "2026-02-10T19:17:21.825Z" }, - { url = 
"https://files.pythonhosted.org/packages/6d/1a/c1ba8fead184d6e3d5afcf03d569acac5ad063f3ac9fb7258af158f7e378/cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731", size = 4456482, upload-time = "2026-02-10T19:17:25.133Z" }, - { url = "https://files.pythonhosted.org/packages/f9/e5/3fb22e37f66827ced3b902cf895e6a6bc1d095b5b26be26bd13c441fdf19/cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82", size = 4405497, upload-time = "2026-02-10T19:17:26.66Z" }, - { url = "https://files.pythonhosted.org/packages/1a/df/9d58bb32b1121a8a2f27383fabae4d63080c7ca60b9b5c88be742be04ee7/cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1", size = 4667819, upload-time = "2026-02-10T19:17:28.569Z" }, - { url = "https://files.pythonhosted.org/packages/ea/ed/325d2a490c5e94038cdb0117da9397ece1f11201f425c4e9c57fe5b9f08b/cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48", size = 3028230, upload-time = "2026-02-10T19:17:30.518Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5a/ac0f49e48063ab4255d9e3b79f5def51697fce1a95ea1370f03dc9db76f6/cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4", size = 3480909, upload-time = "2026-02-10T19:17:32.083Z" }, - { url = "https://files.pythonhosted.org/packages/00/13/3d278bfa7a15a96b9dc22db5a12ad1e48a9eb3d40e1827ef66a5df75d0d0/cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2", size = 7119287, upload-time = "2026-02-10T19:17:33.801Z" }, - { url = 
"https://files.pythonhosted.org/packages/67/c8/581a6702e14f0898a0848105cbefd20c058099e2c2d22ef4e476dfec75d7/cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678", size = 4265728, upload-time = "2026-02-10T19:17:35.569Z" }, - { url = "https://files.pythonhosted.org/packages/dd/4a/ba1a65ce8fc65435e5a849558379896c957870dd64fecea97b1ad5f46a37/cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87", size = 4408287, upload-time = "2026-02-10T19:17:36.938Z" }, - { url = "https://files.pythonhosted.org/packages/f8/67/8ffdbf7b65ed1ac224d1c2df3943553766914a8ca718747ee3871da6107e/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee", size = 4270291, upload-time = "2026-02-10T19:17:38.748Z" }, - { url = "https://files.pythonhosted.org/packages/f8/e5/f52377ee93bc2f2bba55a41a886fd208c15276ffbd2569f2ddc89d50e2c5/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981", size = 4927539, upload-time = "2026-02-10T19:17:40.241Z" }, - { url = "https://files.pythonhosted.org/packages/3b/02/cfe39181b02419bbbbcf3abdd16c1c5c8541f03ca8bda240debc467d5a12/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9", size = 4442199, upload-time = "2026-02-10T19:17:41.789Z" }, - { url = "https://files.pythonhosted.org/packages/c0/96/2fcaeb4873e536cf71421a388a6c11b5bc846e986b2b069c79363dc1648e/cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648", size = 3960131, upload-time = "2026-02-10T19:17:43.379Z" }, - { url = 
"https://files.pythonhosted.org/packages/d8/d2/b27631f401ddd644e94c5cf33c9a4069f72011821cf3dc7309546b0642a0/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4", size = 4270072, upload-time = "2026-02-10T19:17:45.481Z" }, - { url = "https://files.pythonhosted.org/packages/f4/a7/60d32b0370dae0b4ebe55ffa10e8599a2a59935b5ece1b9f06edb73abdeb/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0", size = 4892170, upload-time = "2026-02-10T19:17:46.997Z" }, - { url = "https://files.pythonhosted.org/packages/d2/b9/cf73ddf8ef1164330eb0b199a589103c363afa0cf794218c24d524a58eab/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663", size = 4441741, upload-time = "2026-02-10T19:17:48.661Z" }, - { url = "https://files.pythonhosted.org/packages/5f/eb/eee00b28c84c726fe8fa0158c65afe312d9c3b78d9d01daf700f1f6e37ff/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826", size = 4396728, upload-time = "2026-02-10T19:17:50.058Z" }, - { url = "https://files.pythonhosted.org/packages/65/f4/6bc1a9ed5aef7145045114b75b77c2a8261b4d38717bd8dea111a63c3442/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d", size = 4652001, upload-time = "2026-02-10T19:17:51.54Z" }, - { url = "https://files.pythonhosted.org/packages/86/ef/5d00ef966ddd71ac2e6951d278884a84a40ffbd88948ef0e294b214ae9e4/cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a", size = 3003637, upload-time = "2026-02-10T19:17:52.997Z" }, - { url = 
"https://files.pythonhosted.org/packages/b7/57/f3f4160123da6d098db78350fdfd9705057aad21de7388eacb2401dceab9/cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4", size = 3469487, upload-time = "2026-02-10T19:17:54.549Z" }, - { url = "https://files.pythonhosted.org/packages/e2/fa/a66aa722105ad6a458bebd64086ca2b72cdd361fed31763d20390f6f1389/cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31", size = 7170514, upload-time = "2026-02-10T19:17:56.267Z" }, - { url = "https://files.pythonhosted.org/packages/0f/04/c85bdeab78c8bc77b701bf0d9bdcf514c044e18a46dcff330df5448631b0/cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18", size = 4275349, upload-time = "2026-02-10T19:17:58.419Z" }, - { url = "https://files.pythonhosted.org/packages/5c/32/9b87132a2f91ee7f5223b091dc963055503e9b442c98fc0b8a5ca765fab0/cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235", size = 4420667, upload-time = "2026-02-10T19:18:00.619Z" }, - { url = "https://files.pythonhosted.org/packages/a1/a6/a7cb7010bec4b7c5692ca6f024150371b295ee1c108bdc1c400e4c44562b/cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a", size = 4276980, upload-time = "2026-02-10T19:18:02.379Z" }, - { url = "https://files.pythonhosted.org/packages/8e/7c/c4f45e0eeff9b91e3f12dbd0e165fcf2a38847288fcfd889deea99fb7b6d/cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76", size = 4939143, upload-time = "2026-02-10T19:18:03.964Z" }, - { url = 
"https://files.pythonhosted.org/packages/37/19/e1b8f964a834eddb44fa1b9a9976f4e414cbb7aa62809b6760c8803d22d1/cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614", size = 4453674, upload-time = "2026-02-10T19:18:05.588Z" }, - { url = "https://files.pythonhosted.org/packages/db/ed/db15d3956f65264ca204625597c410d420e26530c4e2943e05a0d2f24d51/cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229", size = 3978801, upload-time = "2026-02-10T19:18:07.167Z" }, - { url = "https://files.pythonhosted.org/packages/41/e2/df40a31d82df0a70a0daf69791f91dbb70e47644c58581d654879b382d11/cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1", size = 4276755, upload-time = "2026-02-10T19:18:09.813Z" }, - { url = "https://files.pythonhosted.org/packages/33/45/726809d1176959f4a896b86907b98ff4391a8aa29c0aaaf9450a8a10630e/cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d", size = 4901539, upload-time = "2026-02-10T19:18:11.263Z" }, - { url = "https://files.pythonhosted.org/packages/99/0f/a3076874e9c88ecb2ecc31382f6e7c21b428ede6f55aafa1aa272613e3cd/cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c", size = 4452794, upload-time = "2026-02-10T19:18:12.914Z" }, - { url = "https://files.pythonhosted.org/packages/02/ef/ffeb542d3683d24194a38f66ca17c0a4b8bf10631feef44a7ef64e631b1a/cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4", size = 4404160, upload-time = "2026-02-10T19:18:14.375Z" }, - { url = 
"https://files.pythonhosted.org/packages/96/93/682d2b43c1d5f1406ed048f377c0fc9fc8f7b0447a478d5c65ab3d3a66eb/cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9", size = 4667123, upload-time = "2026-02-10T19:18:15.886Z" }, - { url = "https://files.pythonhosted.org/packages/45/2d/9c5f2926cb5300a8eefc3f4f0b3f3df39db7f7ce40c8365444c49363cbda/cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72", size = 3010220, upload-time = "2026-02-10T19:18:17.361Z" }, - { url = "https://files.pythonhosted.org/packages/48/ef/0c2f4a8e31018a986949d34a01115dd057bf536905dca38897bacd21fac3/cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595", size = 3467050, upload-time = "2026-02-10T19:18:18.899Z" }, - { url = "https://files.pythonhosted.org/packages/eb/dd/2d9fdb07cebdf3d51179730afb7d5e576153c6744c3ff8fded23030c204e/cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c", size = 3476964, upload-time = "2026-02-10T19:18:20.687Z" }, - { url = "https://files.pythonhosted.org/packages/e9/6f/6cc6cc9955caa6eaf83660b0da2b077c7fe8ff9950a3c5e45d605038d439/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a", size = 4218321, upload-time = "2026-02-10T19:18:22.349Z" }, - { url = "https://files.pythonhosted.org/packages/3e/5d/c4da701939eeee699566a6c1367427ab91a8b7088cc2328c09dbee940415/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356", size = 4381786, upload-time = "2026-02-10T19:18:24.529Z" }, - { url = 
"https://files.pythonhosted.org/packages/ac/97/a538654732974a94ff96c1db621fa464f455c02d4bb7d2652f4edc21d600/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da", size = 4217990, upload-time = "2026-02-10T19:18:25.957Z" }, - { url = "https://files.pythonhosted.org/packages/ae/11/7e500d2dd3ba891197b9efd2da5454b74336d64a7cc419aa7327ab74e5f6/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257", size = 4381252, upload-time = "2026-02-10T19:18:27.496Z" }, - { url = "https://files.pythonhosted.org/packages/bc/58/6b3d24e6b9bc474a2dcdee65dfd1f008867015408a271562e4b690561a4d/cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7", size = 3407605, upload-time = "2026-02-10T19:18:29.233Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/47/93/ac8f3d5ff04d54bc814e961a43ae5b0b146154c89c61b47bb07557679b18/cryptography-46.0.7.tar.gz", hash = "sha256:e4cfd68c5f3e0bfdad0d38e023239b96a2fe84146481852dffbcca442c245aa5", size = 750652, upload-time = "2026-04-08T01:57:54.692Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/5d/4a8f770695d73be252331e60e526291e3df0c9b27556a90a6b47bccca4c2/cryptography-46.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:ea42cbe97209df307fdc3b155f1b6fa2577c0defa8f1f7d3be7d31d189108ad4", size = 7179869, upload-time = "2026-04-08T01:56:17.157Z" }, + { url = "https://files.pythonhosted.org/packages/5f/45/6d80dc379b0bbc1f9d1e429f42e4cb9e1d319c7a8201beffd967c516ea01/cryptography-46.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b36a4695e29fe69215d75960b22577197aca3f7a25b9cf9d165dcfe9d80bc325", size = 4275492, upload-time = "2026-04-08T01:56:19.36Z" }, + { url = 
"https://files.pythonhosted.org/packages/4a/9a/1765afe9f572e239c3469f2cb429f3ba7b31878c893b246b4b2994ffe2fe/cryptography-46.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ad9ef796328c5e3c4ceed237a183f5d41d21150f972455a9d926593a1dcb308", size = 4426670, upload-time = "2026-04-08T01:56:21.415Z" }, + { url = "https://files.pythonhosted.org/packages/8f/3e/af9246aaf23cd4ee060699adab1e47ced3f5f7e7a8ffdd339f817b446462/cryptography-46.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:73510b83623e080a2c35c62c15298096e2a5dc8d51c3b4e1740211839d0dea77", size = 4280275, upload-time = "2026-04-08T01:56:23.539Z" }, + { url = "https://files.pythonhosted.org/packages/0f/54/6bbbfc5efe86f9d71041827b793c24811a017c6ac0fd12883e4caa86b8ed/cryptography-46.0.7-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cbd5fb06b62bd0721e1170273d3f4d5a277044c47ca27ee257025146c34cbdd1", size = 4928402, upload-time = "2026-04-08T01:56:25.624Z" }, + { url = "https://files.pythonhosted.org/packages/2d/cf/054b9d8220f81509939599c8bdbc0c408dbd2bdd41688616a20731371fe0/cryptography-46.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:420b1e4109cc95f0e5700eed79908cef9268265c773d3a66f7af1eef53d409ef", size = 4459985, upload-time = "2026-04-08T01:56:27.309Z" }, + { url = "https://files.pythonhosted.org/packages/f9/46/4e4e9c6040fb01c7467d47217d2f882daddeb8828f7df800cb806d8a2288/cryptography-46.0.7-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:24402210aa54baae71d99441d15bb5a1919c195398a87b563df84468160a65de", size = 3990652, upload-time = "2026-04-08T01:56:29.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/5f/313586c3be5a2fbe87e4c9a254207b860155a8e1f3cca99f9910008e7d08/cryptography-46.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8a469028a86f12eb7d2fe97162d0634026d92a21f3ae0ac87ed1c4a447886c83", size = 4279805, upload-time = "2026-04-08T01:56:30.928Z" }, + { url = 
"https://files.pythonhosted.org/packages/69/33/60dfc4595f334a2082749673386a4d05e4f0cf4df8248e63b2c3437585f2/cryptography-46.0.7-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9694078c5d44c157ef3162e3bf3946510b857df5a3955458381d1c7cfc143ddb", size = 4892883, upload-time = "2026-04-08T01:56:32.614Z" }, + { url = "https://files.pythonhosted.org/packages/c7/0b/333ddab4270c4f5b972f980adef4faa66951a4aaf646ca067af597f15563/cryptography-46.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:42a1e5f98abb6391717978baf9f90dc28a743b7d9be7f0751a6f56a75d14065b", size = 4459756, upload-time = "2026-04-08T01:56:34.306Z" }, + { url = "https://files.pythonhosted.org/packages/d2/14/633913398b43b75f1234834170947957c6b623d1701ffc7a9600da907e89/cryptography-46.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91bbcb08347344f810cbe49065914fe048949648f6bd5c2519f34619142bbe85", size = 4410244, upload-time = "2026-04-08T01:56:35.977Z" }, + { url = "https://files.pythonhosted.org/packages/10/f2/19ceb3b3dc14009373432af0c13f46aa08e3ce334ec6eff13492e1812ccd/cryptography-46.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5d1c02a14ceb9148cc7816249f64f623fbfee39e8c03b3650d842ad3f34d637e", size = 4674868, upload-time = "2026-04-08T01:56:38.034Z" }, + { url = "https://files.pythonhosted.org/packages/1a/bb/a5c213c19ee94b15dfccc48f363738633a493812687f5567addbcbba9f6f/cryptography-46.0.7-cp311-abi3-win32.whl", hash = "sha256:d23c8ca48e44ee015cd0a54aeccdf9f09004eba9fc96f38c911011d9ff1bd457", size = 3026504, upload-time = "2026-04-08T01:56:39.666Z" }, + { url = "https://files.pythonhosted.org/packages/2b/02/7788f9fefa1d060ca68717c3901ae7fffa21ee087a90b7f23c7a603c32ae/cryptography-46.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:397655da831414d165029da9bc483bed2fe0e75dde6a1523ec2fe63f3c46046b", size = 3488363, upload-time = "2026-04-08T01:56:41.893Z" }, + { url = 
"https://files.pythonhosted.org/packages/7b/56/15619b210e689c5403bb0540e4cb7dbf11a6bf42e483b7644e471a2812b3/cryptography-46.0.7-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:d151173275e1728cf7839aaa80c34fe550c04ddb27b34f48c232193df8db5842", size = 7119671, upload-time = "2026-04-08T01:56:44Z" }, + { url = "https://files.pythonhosted.org/packages/74/66/e3ce040721b0b5599e175ba91ab08884c75928fbeb74597dd10ef13505d2/cryptography-46.0.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:db0f493b9181c7820c8134437eb8b0b4792085d37dbb24da050476ccb664e59c", size = 4268551, upload-time = "2026-04-08T01:56:46.071Z" }, + { url = "https://files.pythonhosted.org/packages/03/11/5e395f961d6868269835dee1bafec6a1ac176505a167f68b7d8818431068/cryptography-46.0.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ebd6daf519b9f189f85c479427bbd6e9c9037862cf8fe89ee35503bd209ed902", size = 4408887, upload-time = "2026-04-08T01:56:47.718Z" }, + { url = "https://files.pythonhosted.org/packages/40/53/8ed1cf4c3b9c8e611e7122fb56f1c32d09e1fff0f1d77e78d9ff7c82653e/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:b7b412817be92117ec5ed95f880defe9cf18a832e8cafacf0a22337dc1981b4d", size = 4271354, upload-time = "2026-04-08T01:56:49.312Z" }, + { url = "https://files.pythonhosted.org/packages/50/46/cf71e26025c2e767c5609162c866a78e8a2915bbcfa408b7ca495c6140c4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:fbfd0e5f273877695cb93baf14b185f4878128b250cc9f8e617ea0c025dfb022", size = 4905845, upload-time = "2026-04-08T01:56:50.916Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ea/01276740375bac6249d0a971ebdf6b4dc9ead0ee0a34ef3b5a88c1a9b0d4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ffca7aa1d00cf7d6469b988c581598f2259e46215e0140af408966a24cf086ce", size = 4444641, upload-time = "2026-04-08T01:56:52.882Z" }, + { url = 
"https://files.pythonhosted.org/packages/3d/4c/7d258f169ae71230f25d9f3d06caabcff8c3baf0978e2b7d65e0acac3827/cryptography-46.0.7-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:60627cf07e0d9274338521205899337c5d18249db56865f943cbe753aa96f40f", size = 3967749, upload-time = "2026-04-08T01:56:54.597Z" }, + { url = "https://files.pythonhosted.org/packages/b5/2a/2ea0767cad19e71b3530e4cad9605d0b5e338b6a1e72c37c9c1ceb86c333/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:80406c3065e2c55d7f49a9550fe0c49b3f12e5bfff5dedb727e319e1afb9bf99", size = 4270942, upload-time = "2026-04-08T01:56:56.416Z" }, + { url = "https://files.pythonhosted.org/packages/41/3d/fe14df95a83319af25717677e956567a105bb6ab25641acaa093db79975d/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:c5b1ccd1239f48b7151a65bc6dd54bcfcc15e028c8ac126d3fada09db0e07ef1", size = 4871079, upload-time = "2026-04-08T01:56:58.31Z" }, + { url = "https://files.pythonhosted.org/packages/9c/59/4a479e0f36f8f378d397f4eab4c850b4ffb79a2f0d58704b8fa0703ddc11/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:d5f7520159cd9c2154eb61eb67548ca05c5774d39e9c2c4339fd793fe7d097b2", size = 4443999, upload-time = "2026-04-08T01:57:00.508Z" }, + { url = "https://files.pythonhosted.org/packages/28/17/b59a741645822ec6d04732b43c5d35e4ef58be7bfa84a81e5ae6f05a1d33/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fcd8eac50d9138c1d7fc53a653ba60a2bee81a505f9f8850b6b2888555a45d0e", size = 4399191, upload-time = "2026-04-08T01:57:02.654Z" }, + { url = "https://files.pythonhosted.org/packages/59/6a/bb2e166d6d0e0955f1e9ff70f10ec4b2824c9cfcdb4da772c7dd69cc7d80/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:65814c60f8cc400c63131584e3e1fad01235edba2614b61fbfbfa954082db0ee", size = 4655782, upload-time = "2026-04-08T01:57:04.592Z" }, + { url = 
"https://files.pythonhosted.org/packages/95/b6/3da51d48415bcb63b00dc17c2eff3a651b7c4fed484308d0f19b30e8cb2c/cryptography-46.0.7-cp314-cp314t-win32.whl", hash = "sha256:fdd1736fed309b4300346f88f74cd120c27c56852c3838cab416e7a166f67298", size = 3002227, upload-time = "2026-04-08T01:57:06.91Z" }, + { url = "https://files.pythonhosted.org/packages/32/a8/9f0e4ed57ec9cebe506e58db11ae472972ecb0c659e4d52bbaee80ca340a/cryptography-46.0.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e06acf3c99be55aa3b516397fe42f5855597f430add9c17fa46bf2e0fb34c9bb", size = 3475332, upload-time = "2026-04-08T01:57:08.807Z" }, + { url = "https://files.pythonhosted.org/packages/a7/7f/cd42fc3614386bc0c12f0cb3c4ae1fc2bbca5c9662dfed031514911d513d/cryptography-46.0.7-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:462ad5cb1c148a22b2e3bcc5ad52504dff325d17daf5df8d88c17dda1f75f2a4", size = 7165618, upload-time = "2026-04-08T01:57:10.645Z" }, + { url = "https://files.pythonhosted.org/packages/a5/d0/36a49f0262d2319139d2829f773f1b97ef8aef7f97e6e5bd21455e5a8fb5/cryptography-46.0.7-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:84d4cced91f0f159a7ddacad249cc077e63195c36aac40b4150e7a57e84fffe7", size = 4270628, upload-time = "2026-04-08T01:57:12.885Z" }, + { url = "https://files.pythonhosted.org/packages/8a/6c/1a42450f464dda6ffbe578a911f773e54dd48c10f9895a23a7e88b3e7db5/cryptography-46.0.7-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:128c5edfe5e5938b86b03941e94fac9ee793a94452ad1365c9fc3f4f62216832", size = 4415405, upload-time = "2026-04-08T01:57:14.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/4ed714dbe93a066dc1f4b4581a464d2d7dbec9046f7c8b7016f5286329e2/cryptography-46.0.7-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5e51be372b26ef4ba3de3c167cd3d1022934bc838ae9eaad7e644986d2a3d163", size = 4272715, upload-time = "2026-04-08T01:57:16.638Z" }, + { url = 
"https://files.pythonhosted.org/packages/b7/e6/a26b84096eddd51494bba19111f8fffe976f6a09f132706f8f1bf03f51f7/cryptography-46.0.7-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cdf1a610ef82abb396451862739e3fc93b071c844399e15b90726ef7470eeaf2", size = 4918400, upload-time = "2026-04-08T01:57:19.021Z" }, + { url = "https://files.pythonhosted.org/packages/c7/08/ffd537b605568a148543ac3c2b239708ae0bd635064bab41359252ef88ed/cryptography-46.0.7-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1d25aee46d0c6f1a501adcddb2d2fee4b979381346a78558ed13e50aa8a59067", size = 4450634, upload-time = "2026-04-08T01:57:21.185Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/0cd51dd86ab5b9befe0d031e276510491976c3a80e9f6e31810cce46c4ad/cryptography-46.0.7-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:cdfbe22376065ffcf8be74dc9a909f032df19bc58a699456a21712d6e5eabfd0", size = 3985233, upload-time = "2026-04-08T01:57:22.862Z" }, + { url = "https://files.pythonhosted.org/packages/92/49/819d6ed3a7d9349c2939f81b500a738cb733ab62fbecdbc1e38e83d45e12/cryptography-46.0.7-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:abad9dac36cbf55de6eb49badd4016806b3165d396f64925bf2999bcb67837ba", size = 4271955, upload-time = "2026-04-08T01:57:24.814Z" }, + { url = "https://files.pythonhosted.org/packages/80/07/ad9b3c56ebb95ed2473d46df0847357e01583f4c52a85754d1a55e29e4d0/cryptography-46.0.7-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:935ce7e3cfdb53e3536119a542b839bb94ec1ad081013e9ab9b7cfd478b05006", size = 4879888, upload-time = "2026-04-08T01:57:26.88Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c7/201d3d58f30c4c2bdbe9b03844c291feb77c20511cc3586daf7edc12a47b/cryptography-46.0.7-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:35719dc79d4730d30f1c2b6474bd6acda36ae2dfae1e3c16f2051f215df33ce0", size = 4449961, upload-time = "2026-04-08T01:57:29.068Z" }, + { url = 
"https://files.pythonhosted.org/packages/a5/ef/649750cbf96f3033c3c976e112265c33906f8e462291a33d77f90356548c/cryptography-46.0.7-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bbc6ccf49d05ac8f7d7b5e2e2c33830d4fe2061def88210a126d130d7f71a85", size = 4401696, upload-time = "2026-04-08T01:57:31.029Z" }, + { url = "https://files.pythonhosted.org/packages/41/52/a8908dcb1a389a459a29008c29966c1d552588d4ae6d43f3a1a4512e0ebe/cryptography-46.0.7-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a1529d614f44b863a7b480c6d000fe93b59acee9c82ffa027cfadc77521a9f5e", size = 4664256, upload-time = "2026-04-08T01:57:33.144Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fa/f0ab06238e899cc3fb332623f337a7364f36f4bb3f2534c2bb95a35b132c/cryptography-46.0.7-cp38-abi3-win32.whl", hash = "sha256:f247c8c1a1fb45e12586afbb436ef21ff1e80670b2861a90353d9b025583d246", size = 3013001, upload-time = "2026-04-08T01:57:34.933Z" }, + { url = "https://files.pythonhosted.org/packages/d2/f1/00ce3bde3ca542d1acd8f8cfa38e446840945aa6363f9b74746394b14127/cryptography-46.0.7-cp38-abi3-win_amd64.whl", hash = "sha256:506c4ff91eff4f82bdac7633318a526b1d1309fc07ca76a3ad182cb5b686d6d3", size = 3472985, upload-time = "2026-04-08T01:57:36.714Z" }, + { url = "https://files.pythonhosted.org/packages/63/0c/dca8abb64e7ca4f6b2978769f6fea5ad06686a190cec381f0a796fdcaaba/cryptography-46.0.7-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:fc9ab8856ae6cf7c9358430e49b368f3108f050031442eaeb6b9d87e4dcf4e4f", size = 3476879, upload-time = "2026-04-08T01:57:38.664Z" }, + { url = "https://files.pythonhosted.org/packages/3a/ea/075aac6a84b7c271578d81a2f9968acb6e273002408729f2ddff517fed4a/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d3b99c535a9de0adced13d159c5a9cf65c325601aa30f4be08afd680643e9c15", size = 4219700, upload-time = "2026-04-08T01:57:40.625Z" }, + { url = 
"https://files.pythonhosted.org/packages/6c/7b/1c55db7242b5e5612b29fc7a630e91ee7a6e3c8e7bf5406d22e206875fbd/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d02c738dacda7dc2a74d1b2b3177042009d5cab7c7079db74afc19e56ca1b455", size = 4385982, upload-time = "2026-04-08T01:57:42.725Z" }, + { url = "https://files.pythonhosted.org/packages/cb/da/9870eec4b69c63ef5925bf7d8342b7e13bc2ee3d47791461c4e49ca212f4/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:04959522f938493042d595a736e7dbdff6eb6cc2339c11465b3ff89343b65f65", size = 4219115, upload-time = "2026-04-08T01:57:44.939Z" }, + { url = "https://files.pythonhosted.org/packages/f4/72/05aa5832b82dd341969e9a734d1812a6aadb088d9eb6f0430fc337cc5a8f/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3986ac1dee6def53797289999eabe84798ad7817f3e97779b5061a95b0ee4968", size = 4385479, upload-time = "2026-04-08T01:57:46.86Z" }, + { url = "https://files.pythonhosted.org/packages/20/2a/1b016902351a523aa2bd446b50a5bc1175d7a7d1cf90fe2ef904f9b84ebc/cryptography-46.0.7-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:258514877e15963bd43b558917bc9f54cf7cf866c38aa576ebf47a77ddbc43a4", size = 3412829, upload-time = "2026-04-08T01:57:48.874Z" }, ] [[package]] From b3d1ad67d7e070a7697a1486f7b735e541e25b06 Mon Sep 17 00:00:00 2001 From: Sampath Kumar Date: Wed, 22 Apr 2026 12:12:28 +0200 Subject: [PATCH 60/67] chore(docs): update migration README.md (#1003) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [Contribution Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the specification. 
- Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes # 🦕 --------- Co-authored-by: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Co-authored-by: sokoliva --- docs/migrations/v1_0/README.md | 277 +++++++++++++++++++++++---------- 1 file changed, 197 insertions(+), 80 deletions(-) diff --git a/docs/migrations/v1_0/README.md b/docs/migrations/v1_0/README.md index 34b2f1bed..b0d71c8cc 100644 --- a/docs/migrations/v1_0/README.md +++ b/docs/migrations/v1_0/README.md @@ -6,6 +6,11 @@ Beyond protocol support, `v1.0` enhances the developer experience by introducing This documentation details the technical upgrades and architectural modifications introduced in A2A Python SDK v1.0. For developers using the database persistence layer, please refer to the [Database Migration Guide](database/) for specific update instructions. +> ### **Why Upgrade to v1.0?** +> * **Protocol v1.0 Compliance**: Full alignment with the latest A2A industry standard for cross-agent interoperability. +> * **Reduced Boilerplate**: Unified helper utilities that simplify common tasks like message and task creation. +> * **Architectural Flexibility**: Direct Starlette/FastAPI integration allows you to mount A2A routes into existing applications with full control over middleware. + --- ## Table of Contents @@ -26,7 +31,7 @@ This documentation details the technical upgrades and architectural modification ## 1. 
Update Dependencies -(UV users) To upgrade to the latest version of the `a2a-sdk`, update the dependencies section in your `pyproject.toml` file. +For UV users: To upgrade to the latest version of the `a2a-sdk`, update the dependencies section in your `pyproject.toml` file. | File | Before (`v0.3`) | After (`v1.0`) | |------------------|-----------------------------------|-----------------------------------| @@ -52,18 +57,17 @@ pip install --upgrade a2a-sdk ## 2. Types -Types have migrated from Pydantic models to Protobuf-based classes. +[Types](https://github.com/a2aproject/a2a-python/blob/main/src/a2a/types/a2a_pb2.pyi) have migrated from Pydantic models to Protobuf-based classes to align with the A2A spec's proto-first design and to adopt ProtoJSON as the canonical JSON serialization standard, ensuring consistent cross-implementation interoperability. -### Enum values: snake_case → SCREAMING_SNAKE_CASE +### Enum values: `snake_case` → `SCREAMING_SNAKE_CASE` -All the enum values are now standardized from snake_case to **SCREAMING_SNAKE_CASE** format. +All enum values are now [standardized](https://a2a-protocol.org/v1.0.0/specification/#55-json-field-naming-convention) to use `SCREAMING_SNAKE_CASE` format. This affects every enum in the SDK: `TaskState`, `Role`. | Enum | v0.3 | v1.0 | |---|---|---| -| `TaskState` | *(no equivalent — protobuf default)* | `TaskState.TASK_STATE_UNSPECIFIED` | | `TaskState` | `TaskState.submitted` | `TaskState.TASK_STATE_SUBMITTED` | | `TaskState` | `TaskState.working` | `TaskState.TASK_STATE_WORKING` | | `TaskState` | `TaskState.completed` | `TaskState.TASK_STATE_COMPLETED` | @@ -72,29 +76,60 @@ This affects every enum in the SDK: `TaskState`, `Role`. 
| `TaskState` | `TaskState.input_required` | `TaskState.TASK_STATE_INPUT_REQUIRED` | | `TaskState` | `TaskState.auth_required` | `TaskState.TASK_STATE_AUTH_REQUIRED` | | `TaskState` | `TaskState.rejected` | `TaskState.TASK_STATE_REJECTED` | +| `TaskState` | | 🆕 `TaskState.TASK_STATE_UNSPECIFIED` | ||| -| `Role` | *(no equivalent — protobuf default)* | `Role.ROLE_UNSPECIFIED` | | `Role` | `Role.user` | `Role.ROLE_USER` | | `Role` | `Role.agent` | `Role.ROLE_AGENT` | +| `Role` | | 🆕 `Role.ROLE_UNSPECIFIED` | > **Example**: [`a2a-mcp-without-framework/server/agent_executor.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/changes#diff-1f9b098f9f82ee40666ee61db56dc2246281423c445bcf017079c53a0a05954f) ### Message and Part construction -Constructing messages is simplified in v1.0. The old API required wrapping content in an intermediate type (`TextPart`, `FilePart`, `DataPart`) before placing it inside a `Part`. In v1.0, `Part` is a single unified message — set the content type directly on it and the wrapper types are gone entirely. +Constructing messages is simplified in v1.0. The old API required wrapping content in an intermediate type (`TextPart`, `FilePart`, `DataPart`) before placing it inside a `Part`. In v1.0, the wrapper types are removed and all content fields are set directly on the unified `Part` message. 
+ +| Part type | v0.3 | v1.0 | +|---|---|---| +| Text | `Part(TextPart(text=..., ...))` | `Part(text=..., ...)` | +| File (bytes) | `Part(FilePart(file=FileWithBytes(bytes=..., ...)))` | `Part(raw=..., ...)` | +| File (URI) | `Part(FilePart(file=FileWithUri(uri=..., ...)))` | `Part(url=..., ...)` | +| Structured data | `Part(DataPart(data=..., ...))` | `Part(data=..., ...)` | -Key differences: -- `Part(TextPart(text=...))` → `Part(text=...)` (flat union field) -- `Role.user` → `Role.ROLE_USER`, `Role.agent` → `Role.ROLE_AGENT` +**Note**: +* When using `File (bytes)` in v1.0, the data serialisatinon (via base64 encoding) is not required as A2A now uses Protobuf that automatically does it for you. +* In v1.0, `Part.DataPart.data` is renamed to `Part.data` and is of type `google.protobuf.Value`. Use `ParseDict` to convert a Python dict into a suitable value. See the examples below for more details. **Before (v0.3):** ```python -from a2a.types import Message, Part, Role, TextPart +import base64 from uuid import uuid4 +from a2a.types import Message, Part, Role, TextPart, FilePart, DataPart, FileWithBytes, FileWithUri + +# Text part +text_part = Part(TextPart(text="What's the weather in Warsaw?")) + +# File part — base64-encoded bytes (e.g. 
an image) +with open("photo.png", "rb") as f: + image_b64 = base64.b64encode(f.read()).decode() +file_bytes_part = Part(FilePart(file=FileWithBytes( + bytes=image_b64, + mime_type="image/png", + name="photo.png", +))) + +# File part — URI pointing to a remote file +file_uri_part = Part(FilePart(file=FileWithUri( + uri="https://example.com/report.pdf", + mime_type="application/pdf", + name="report.pdf", +))) + +# Data part — structured JSON payload +data_part = Part(DataPart(data={"city": "Warsaw", "temperature_c": 18})) message = Message( role=Role.user, - parts=[Part(TextPart(text="Hello"))], + parts=[text_part, file_bytes_part, file_uri_part, data_part], message_id=uuid4().hex, task_id=uuid4().hex, ) @@ -102,53 +137,84 @@ message = Message( **After (v1.0):** -Using [A2A helper utilities](#helper-utilities) - ```python -from a2a.helpers import new_text_message -from a2a.types import Role +from uuid import uuid4 +from google.protobuf.json_format import ParseDict +from google.protobuf.struct_pb2 import Value +from a2a.types import Message, Part, Role -# Use the helper function to create `Hello` message -message = new_text_message(text="Hello", role=Role.ROLE_USER) +# Text part +text_part = Part(text="What's the weather in Warsaw?") -``` +# File part — raw bytes (e.g. 
an image); no base64 encoding required +with open("photo.png", "rb") as f: + image_bytes = f.read() +file_bytes_part = Part( + raw=image_bytes, + media_type="image/png", + filename="photo.png", +) -Without helper utils, you can still construct directly +# File part — URI pointing to a remote file +file_uri_part = Part( + url="https://example.com/report.pdf", + media_type="application/pdf", + filename="report.pdf", +) -```python -from a2a.types import Message, Part, Role -from uuid import uuid4 +# Data part — use ParseDict to convert a Python dict to a protobuf Value +data_part = Part( + data=ParseDict({"city": "Warsaw", "temperature_c": 18}, Value()), +) message = Message( role=Role.ROLE_USER, - parts=[Part(text="Hello")], + parts=[text_part, file_bytes_part, file_uri_part, data_part], message_id=uuid4().hex, task_id=uuid4().hex, ) ``` +For text-only messages, use the [A2A helper utilities](#9-helper-utilities) to reduce boilerplate: + +```python +from a2a.helpers import new_text_message +from a2a.types import Role + +message = new_text_message(text="What's the weather in Warsaw?", role=Role.ROLE_USER) +``` + > **Example**: [`helloworld/test_client.py` in PR #474](https://github.com/a2aproject/a2a-samples/pull/474/files#diff-f62c07d3b00364a3100b7effb3e2a1cca0624277d3e40da1bdb07bb46b6a8cef) ### AgentCard Structure -The new `AgentCard` can supports multiple transport bindings using `AgentInterface` class. 
- -Key differences: -- `url` is gone; use `supported_interfaces` with one or more `AgentInterface` entries -- `AgentCapabilities.input_modes` and `AgentCapabilities.output_modes` are removed; use `AgentCard.default_input_modes` / `AgentCard.default_output_modes` for card-level defaults, or `AgentSkill.input_modes` / `AgentSkill.output_modes` for per-skill overrides -- `supports_authenticated_extended_card` is no longer a top-level `AgentCard` field; it has moved into `AgentCapabilities` and is renamed to `extended_agent_card` -- `AgentInterface.protocol_binding` accepted values: `'JSONRPC'`, `'HTTP+JSON'`, `'GRPC'` -- `examples` field was removed; set it per `AgentSkill` instead +Key changes: +- Added an `AgentInterface` class to support multiple transport bindings via the newly added `supported_interfaces` field in AgentCard. +- The `url` parameter in `AgentCard` is removed and is now part of `AgentInterface`. +- Accepted values for `AgentInterface.protocol_binding`: `'JSONRPC'`, `'HTTP+JSON'`, `'GRPC'` +- The `AgentCard.supports_authenticated_extended_card` field is renamed to `AgentCapabilities.extended_agent_card`. +- The `AgentCapabilities.input_modes` and `AgentCapabilities.output_modes` fields are removed; use `AgentCard.default_input_modes` and `AgentCard.default_output_modes` for card-level defaults, or `AgentSkill.input_modes` and `AgentSkill.output_modes` for per-skill overrides. +- The `examples` parameter in `AgentCard` is removed and is now part of `AgentSkill`. 
**Before (v0.3):** ```python from a2a.types import AgentCard, AgentCapabilities, AgentSkill +skill = AgentSkill( + id='hello_world', + name='Hello World', + description='Returns a Hello World message.', + tags=['hello', 'world'], + input_modes=['text/plain'], + output_modes=['text/plain'], + examples=['hello world'], +) + agent_card = AgentCard( - name='My Agent', - description='...', + name='Hello World Agent', + description='Returns Hello, World!', url='http://localhost:9999/', - version='1.0.0', + version='0.0.1', default_input_modes=['text/plain'], default_output_modes=['text/plain'], supports_authenticated_extended_card=True, @@ -158,17 +224,27 @@ agent_card = AgentCard( streaming=True, ), skills=[skill], - examples=['example'], + examples=['Hello, World!'], ) ``` **After (v1.0):** ```python -from a2a.types import AgentCard, AgentCapabilities, AgentInterface, AgentSkill, +from a2a.types import AgentCard, AgentCapabilities, AgentInterface, AgentSkill + +skill = AgentSkill( + id='hello_world', + name='Hello World', + description='Returns a Hello World message.', + tags=['hello', 'world'], + input_modes=['text/plain'], + output_modes=['text/plain'], + examples=['hello world', 'Hello, World!'], # moved from AgentCard.examples +) agent_card = AgentCard( - name='My Agent', - description='...', + name='Hello World Agent', + description='Returns Hello, World!', supported_interfaces=[ # JSON-RPC AgentInterface( @@ -181,7 +257,7 @@ agent_card = AgentCard( url='http://localhost:50051/a2a/grpc/', ) ], - version='1.0.0', + version='0.0.1', default_input_modes=['text/plain'], default_output_modes=['text/plain'], capabilities=AgentCapabilities( @@ -225,46 +301,64 @@ request_handler = DefaultRequestHandler( ## 4. Server: Application Setup -The wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication` and `A2ARESTFastApiApplication`) are now removed. The Server setup now uses Starlette route factory functions directly, giving you full control over the routing. 
+The application wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication` and `A2ARESTFastApiApplication`) have been removed. The server setup now uses Starlette route factory functions directly, giving you better control over the routing, middleware, authentication, logging, and other aspects of the server. **Before (v0.3):** ```python from a2a.server.apps import A2AStarletteApplication import uvicorn +# Create application using A2AStarletteApplication wrapper class server = A2AStarletteApplication( agent_card=agent_card, http_handler=request_handler, ) + +# Start the server uvicorn.run(server.build(), host=host, port=port) ``` **After (v1.0):** + +Define routes for each supported transport as defined in the `AgentCard`. + ```python from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes -from starlette.applications import Starlette -import uvicorn +# Define routes for transports as defined in the AgentCard routes = [] +# A2A Agent Card routes routes.extend(create_agent_card_routes(agent_card)) -routes.extend(create_jsonrpc_routes(request_handler, rpc_url='/')) +# JSON-RPC routes +routes.extend(create_jsonrpc_routes(request_handler, rpc_url='/api/v1/jsonrpc/')) + +# Optional: Add routes for REST/HTTP transports +# routes.extend(create_rest_routes(request_handler, path_prefix='/api/v1/rest/')) +``` + +Add the routes to the application: + +```python +from starlette.applications import Starlette +import uvicorn +# Create application using routes app = Starlette(routes=routes) + +# Start the server uvicorn.run(app, host=host, port=port) ``` -If you need REST transport in addition to JSON-RPC: +If you prefer FastAPI for your server application: + ```python -from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes, create_rest_routes -from starlette.applications import Starlette +from fastapi import FastAPI import uvicorn -routes = [] -routes.extend(create_agent_card_routes(agent_card)) 
-routes.extend(create_jsonrpc_routes(request_handler, rpc_url='/')) -routes.extend(create_rest_routes(request_handler)) +# Create application using routes +app = FastAPI(routes=routes) -app = Starlette(routes=routes) +# Start the server uvicorn.run(app, host=host, port=port) ``` @@ -297,19 +391,18 @@ create_rest_routes(request_handler, enable_v0_3_compat=True) ## 6. Client: Creating a Client -New `create_client()` `ClientFactory` function that creates a client for the agent. +In `v1.0`, use the `a2a.client.create_client()` helper function to create a `Client` for the agent. -> **Note**: The legacy `A2AClient` class has been removed. **Before (v0.3):** ```python from a2a.client import ClientFactory -# From URL +# Option 1: Using Agent Server URL factory = ClientFactory() client = factory.create_client('http://localhost:9999/') -# From an already-resolved AgentCard +# Option 2: Using AgentCard factory = ClientFactory() client = factory.create_client(agent_card) ``` @@ -318,10 +411,10 @@ client = factory.create_client(agent_card) ```python from a2a.client import create_client -# From URL — resolves the agent card automatically +# Option 1: Using Agent Server URL client = await create_client('http://localhost:9999/') -# From an already-resolved AgentCard +# Option 2: Using AgentCard client = await create_client(agent_card) ``` @@ -332,9 +425,9 @@ client = await create_client(agent_card) ## 7. Client: Send Message -The `BaseClient.send_message()` return type is standardised from `AsyncIterator[ClientEvent | Message]` to `AsyncIterator[StreamResponse]`. +The `BaseClient.send_message()` return type is standardized from `AsyncIterator[ClientEvent | Message]` to `AsyncIterator[StreamResponse]`. -Each `StreamResponse` yields exactly one of: `task`, `message`, `status_update`, or `artifact_update`. Use `HasField()` to check which field is set. +Each `StreamResponse` contains exactly one of: `task`, `message`, `status_update`, or `artifact_update`. 
Use `HasField()` to check which field is set. **Before (v0.3):** @@ -368,6 +461,7 @@ async for chunk in client.send_message(request): `ClientConfig.push_notification_config` is now **singular** (a single `TaskPushNotificationConfig` or `None`), not a list. + **Before (v0.3):** ```python config = ClientConfig( @@ -386,35 +480,59 @@ config = ClientConfig( ## 9. Helper Utilities -A new `a2a.helpers` module consolidates helper functions into a single import. Most were previously available under `a2a.utils.*`; a few are new in v1.0. +To improve the developer experience, we have consolidated helper functions into a single import. In v0.3, these helper functions were scattered across different modules. In v1.0, they are all available under `a2a.helpers`. + +| Helper Function | Description | +|---|---| +| `display_agent_card` | Prints a human-readable summary of an `AgentCard` to stdout. | +| `get_artifact_text` | Joins all text parts of an `Artifact` into a single string (using `\n` as delimiter). | +| `get_message_text` | Joins all text parts of a `Message` into a single string (using `\n` as delimiter). | +| `get_stream_response_text` | Extracts text from a `StreamResponse` protobuf message. | +| `get_text_parts` | Returns a list of raw text strings from a sequence of `Part` objects, skipping non-text parts. | +| `new_artifact` | Creates an `Artifact` from a list of `Part` objects, a name, and an optional description and ID. | +| `new_message` | Creates a `Message` from a list of `Part` objects with a role (defaults to `ROLE_AGENT`), and optional task/context IDs. | +| `new_task` | Creates a `Task` with an explicit task ID, context ID, and state. | +| `new_task_from_user_message` | Creates a `TASK_STATE_SUBMITTED` `Task` from a user `Message`. Raises an error if the role is not `ROLE_USER` or if parts are empty. | +| `new_text_artifact` | Creates an `Artifact` with a single text `Part`, a name, and an optional description and ID. 
| +| `new_text_artifact_update_event` | Creates a `TaskArtifactUpdateEvent` with a text artifact. | +| `new_text_message` | Creates a `Message` with a single text `Part`; role defaults to `ROLE_AGENT`. | +| `new_text_status_update_event` | Creates a `TaskStatusUpdateEvent` with a text message. | + +**Example usage:** + +**1. Create a text-based message** ```python -from a2a.helpers import ( - display_agent_card, # print a human-readable summary of an AgentCard to stdout - get_artifact_text, # join all text parts of an Artifact into a single string (delimiter='\n') - get_message_text, # join all text parts of a Message into a single string (delimiter='\n') - get_stream_response_text, # extract text from a StreamResponse proto message - get_text_parts, # return a list of raw text strings from a sequence of Parts (skips non-text parts) - new_artifact, # create an Artifact from a list of Parts, name, optional description and artifact_id - new_message, # create a Message from a list of Parts with role (default ROLE_AGENT), optional task_id/context_id - new_task, # create a Task with explicit task_id, context_id, and state - new_task_from_user_message, # create a TASK_STATE_SUBMITTED Task from a user Message; raises if role != ROLE_USER or parts are empty - new_text_artifact, # create an Artifact with a single text Part, name, optional description and artifact_id - new_text_artifact_update_event, # create a TaskArtifactUpdateEvent with a text artifact - new_text_message, # create a Message with a single text Part; role defaults to ROLE_AGENT - new_text_status_update_event, # create a TaskStatusUpdateEvent with a text message -) +from a2a.helpers import new_text_message +from a2a.types import Role + +# Create a user message +user_message = new_text_message("What's the weather?", role=Role.ROLE_USER) + +# Create an agent response message +response_message = new_text_message("It is sunny today!") +``` + +**2. 
Extract text from a message** + +```python +from a2a.helpers import get_message_text + +# Get text from a message +text = get_message_text(response_message) +print(text) ``` --- ## 10. Summary of Key Changes in v1.0 -- **Standardisation to `SCREAMING_SNAKE_CASE`** — All enum values have been renamed from `kebab-case` strings to `SCREAMING_SNAKE_CASE` for compliance with the ProtoJSON specification. +- **Migration to Protobuf** — Core types have migrated from Pydantic models to Protobuf-based classes. Protobuf objects do not support arbitrary attribute assignment. Use `MessageToDict` from `google.protobuf.json_format` to convert objects to dictionaries, and `HasField('field_name')` to check for optional fields. +- **Standardization to `SCREAMING_SNAKE_CASE`** — All enum values have been renamed from `snake_case` strings to `SCREAMING_SNAKE_CASE` for compliance with the ProtoJSON specification. - **`AgentCard`** — Significantly restructured to support multiple transport interfaces. - - **`AgentInterface`** — The top-level `url` field is replaced by `supported_interfaces`, a list of `AgentInterface` objects. Each entry describes a single transport endpoint carrying `protocol_binding`, `protocol_version`, and `url`. + - **`AgentInterface`** — The top-level `url` field is replaced by `supported_interfaces`, a list of `AgentInterface` objects. Each entry describes a single transport endpoint with fields for `protocol_binding`, `protocol_version`, and `url`. - **Input and output modes** — `AgentCapabilities.input_modes` and `AgentCapabilities.output_modes` are removed and now live directly on `AgentCard` as `default_input_modes` and `default_output_modes`. Individual skills can override these with their own `input_modes` and `output_modes`. -- **Application setup** — The wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication` and `A2ARESTFastApiApplication`) are now removed. 
Server setup now uses route factory functions `create_jsonrpc_routes()`, `create_rest_routes()`, `create_agent_card_routes()` composed directly into a Starlette or FastAPI app. +- **Application setup** — The wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication` and `A2ARESTFastApiApplication`) have been removed. Server setup now uses route factory functions — `create_jsonrpc_routes()`, `create_rest_routes()`, and `create_agent_card_routes()` — composed directly into a Starlette or FastAPI app. - **Helper utilities** — A new `a2a.helpers` module consolidates all helper functions under a single import, replacing the scattered `a2a.utils.*` modules and adding new helpers for constructing and reading v1.0 proto types. --- @@ -439,4 +557,3 @@ uv run python samples/cli.py Then type a message like `hello` and press Enter. See [`samples/README.md`](../../../samples/README.md) for full details. For more examples see the [a2a-samples repository](https://github.com/a2aproject/a2a-samples/tree/main/samples/python). - From 7d197dbf81e31398a41f8d6795e15170f082104f Mon Sep 17 00:00:00 2001 From: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Date: Wed, 22 Apr 2026 15:58:37 +0200 Subject: [PATCH 61/67] fix(proto): use field.label instead of is_repeated for protobuf compatibility (#1010) # Description Replace `field.is_repeated` with `field.label == field.LABEL_REPEATED` in `proto_utils.py` to support older protobuf versions where the `is_repeated` attribute is not available on `FieldDescriptor`. # Problem When using older protobuf versions accessing `field.is_repeated` raises: `a2a.utils.errors.InternalError: 'google._upb._message.FieldDescriptor' object has no attribute 'is_repeated'` The project's declared minimum is `protobuf>=5.29.5`, `5.29.5` does not support `is_repeated`. This caused `send_message` (and other proto-validating client calls) to fail at runtime for users on older protobuf releases. 
The `is_repeated` property was only added to `FieldDescriptor` in newer protobuf releases (6.x), so relying on it broke compatibility with the supported version range. Although the deprecated label was already [removed in some 7.x version](https://github.com/protocolbuffers/protobuf/releases/tag/v34.0-rc1.1), 7.x version can't be resolved with the other constraints we have in the project. # Fix Use the long-standing `label` attribute and compare against `FieldDescriptor.LABEL_REPEATED`, which is available across all supported protobuf versions and is the canonical way to detect repeated fields. # Testing - `uv run pytest` passes against the supported protobuf version range. --- src/a2a/utils/proto_utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index f77593297..b191f98e0 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -174,7 +174,10 @@ def parse_params(params: QueryParams, message: ProtobufMessage) -> None: field = fields[k] v_list = params.getlist(k) - if field.label == field.LABEL_REPEATED: + # TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace + # deprecated `field.label` with `field.is_repeated` once the minimum + # protobuf version requirement is bumped. + if field.label == FieldDescriptor.LABEL_REPEATED: accumulated: list[Any] = [] for v in v_list: if not v: @@ -208,7 +211,10 @@ def _check_required_field_violation( ) -> ValidationDetail | None: """Check if a required field is missing or invalid.""" val = getattr(msg, field.name) - if field.is_repeated: + # TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace + # deprecated `field.label` with `field.is_repeated` once the minimum + # protobuf version requirement is bumped. 
+ if field.label == FieldDescriptor.LABEL_REPEATED: if not val: return ValidationDetail( field=field.name, @@ -249,7 +255,10 @@ def _recurse_validation( return errors val = getattr(msg, field.name) - if not field.is_repeated: + # TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace + # deprecated `field.label` with `field.is_repeated` once the minimum + # protobuf version requirement is bumped. + if field.label != FieldDescriptor.LABEL_REPEATED: if msg.HasField(field.name): sub_errs = _validate_proto_required_fields_internal(val) _append_nested_errors(errors, field.name, sub_errs) From 8a0f38df17fcce71ba1a66e63bbcb6121bf44378 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 23 Apr 2026 08:21:21 +0200 Subject: [PATCH 62/67] ci(linter): show which step failed (#1013) Before: image After: image --- .github/workflows/linter.yaml | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index 2c2a035a0..4c211aba8 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -62,12 +62,22 @@ jobs: - name: Check Linter Statuses if: always() # This ensures the step runs even if previous steps failed + env: + RUFF_LINT: ${{ steps.ruff-lint.outcome }} + RUFF_FORMAT: ${{ steps.ruff-format.outcome }} + MYPY: ${{ steps.mypy.outcome }} + PYRIGHT: ${{ steps.pyright.outcome }} + JSCPD: ${{ steps.jscpd.outcome }} run: | - if [[ "${{ steps.ruff-lint.outcome }}" == "failure" || \ - "${{ steps.ruff-format.outcome }}" == "failure" || \ - "${{ steps.mypy.outcome }}" == "failure" || \ - "${{ steps.pyright.outcome }}" == "failure" || \ - "${{ steps.jscpd.outcome }}" == "failure" ]]; then - echo "One or more linting/checking steps failed." 
+ failed=() + [[ "$RUFF_LINT" == "failure" ]] && failed+=("Ruff Linter") + [[ "$RUFF_FORMAT" == "failure" ]] && failed+=("Ruff Formatter") + [[ "$MYPY" == "failure" ]] && failed+=("MyPy") + [[ "$PYRIGHT" == "failure" ]] && failed+=("Pyright") + [[ "$JSCPD" == "failure" ]] && failed+=("JSCPD") + + if (( ${#failed[@]} )); then + joined=$(IFS=', '; echo "${failed[*]}") + echo "::error title=Linter failures::The following checks failed: ${joined}. See the corresponding step logs above for details." exit 1 fi From d02ae4e0a1ffcf104cc1c7143dd408e0a5d4d996 Mon Sep 17 00:00:00 2001 From: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:31:39 +0200 Subject: [PATCH 63/67] chore(docs): document strict AgentExecutor streaming rules in v1.0 guide (#1014) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description - Adds a new section **"Server: AgentExecutor Streaming Rules"** to the v0.3 → v1.0 migration guide, covering the breaking change introduced in #979 where the server now strictly enforces the A2A spec rules for `message/stream`. - Documents the four `InvalidAgentResponseError` cases (mixing `Message` and `Task` events, multiple `Message`s, updates before the initial `Task`) with a clear table mapped to the actual error messages from `active_task.py`. - Provides before/after migration examples — including the message-only pattern and the full task lifecycle pattern (Task → WORKING → invoke agent → artifact → COMPLETED) — plus a quick migration checklist. - Links to the updated `helloworld/agent_executor.py` in [a2a-samples#474](https://github.com/a2aproject/a2a-samples/pull/474) as a working reference. - Adds a corresponding bullet to the "Summary of Key Changes" section # Why Existing v0.3 executors that mixed `Message` and `Task` events were silently tolerated. After #979 they fail at runtime with `InvalidAgentResponseError`, which is easy to hit and hard to diagnose without guidance. 
This section gives users a clear path to migrate. --- docs/migrations/v1_0/README.md | 148 ++++++++++++++++++++++++++++----- 1 file changed, 126 insertions(+), 22 deletions(-) diff --git a/docs/migrations/v1_0/README.md b/docs/migrations/v1_0/README.md index b0d71c8cc..da3d6ba79 100644 --- a/docs/migrations/v1_0/README.md +++ b/docs/migrations/v1_0/README.md @@ -18,14 +18,15 @@ This documentation details the technical upgrades and architectural modification 1. [Update Dependencies](#1-update-dependencies) 2. [Types](#2-types) 3. [Server: DefaultRequestHandler](#3-server-defaultrequesthandler) -4. [Server: Application Setup](#4-server-application-setup) -5. [Supporting v0.3 Clients](#5-supporting-v03-clients) -6. [Client: Creating a Client](#6-client-creating-a-client) -7. [Client: Send Message](#7-client-send-message) -8. [Client: Push Notifications Config](#8-client-push-notifications-config) -9. [Helper Utilities](#9-helper-utilities) -10. [Summary of Key Changes](#10-summary-of-key-changes-in-v10) -11. [Get Started](#11-get-started) +4. [Server: AgentExecutor Streaming Rules](#4-server-agentexecutor-streaming-rules) +5. [Server: Application Setup](#5-server-application-setup) +6. [Supporting v0.3 Clients](#6-supporting-v03-clients) +7. [Client: Creating a Client](#7-client-creating-a-client) +8. [Client: Send Message](#8-client-send-message) +9. [Client: Push Notifications Config](#9-client-push-notifications-config) +10. [Helper Utilities](#10-helper-utilities) +11. [Summary of Key Changes](#11-summary-of-key-changes-in-v10) +12. [Get Started](#12-get-started) --- @@ -96,7 +97,7 @@ Constructing messages is simplified in v1.0. The old API required wrapping conte | Structured data | `Part(DataPart(data=..., ...))` | `Part(data=..., ...)` | **Note**: -* When using `File (bytes)` in v1.0, the data serialisatinon (via base64 encoding) is not required as A2A now uses Protobuf that automatically does it for you. 
+* When using `File (bytes)` in v1.0, data serialization (via base64 encoding) is not required because A2A now uses Protobuf, which handles it automatically. * In v1.0, `Part.DataPart.data` is renamed to `Part.data` and is of type `google.protobuf.Value`. Use `ParseDict` to convert a Python dict into a suitable value. See the examples below for more details. **Before (v0.3):** @@ -175,7 +176,7 @@ message = Message( ) ``` -For text-only messages, use the [A2A helper utilities](#9-helper-utilities) to reduce boilerplate: +For text-only messages, use the [A2A helper utilities](#10-helper-utilities) to reduce boilerplate: ```python from a2a.helpers import new_text_message @@ -191,7 +192,7 @@ message = new_text_message(text="What's the weather in Warsaw?", role=Role.ROLE_ Key changes: - Added an `AgentInterface` class to support multiple transport bindings via the newly added `supported_interfaces` field in AgentCard. - The `url` parameter in `AgentCard` is removed and is now part of `AgentInterface`. -- Accepted values for `AgentInterface.protocol_binding`: `'JSONRPC'`, `'HTTP+JSON'`, `'GRPC'` +- Accepted values for `AgentInterface.protocol_binding`: `'JSONRPC'`, `'HTTP+JSON'`, `'GRPC'`. - The `AgentCard.supports_authenticated_extended_card` field is renamed to `AgentCapabilities.extended_agent_card`. - The `AgentCapabilities.input_modes` and `AgentCapabilities.output_modes` fields are removed; use `AgentCard.default_input_modes` and `AgentCard.default_output_modes` for card-level defaults, or `AgentSkill.input_modes` and `AgentSkill.output_modes` for per-skill overrides. - The `examples` parameter in `AgentCard` is removed and is now part of `AgentSkill`. @@ -299,9 +300,111 @@ request_handler = DefaultRequestHandler( --- -## 4. Server: Application Setup +## 4. Server: AgentExecutor Streaming Rules -The application wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication` and `A2ARESTFastApiApplication`) have been removed. 
The server setup now uses Starlette route factory functions directly, giving you better control over the routing, middleware, authentication, logging, and other aspects of the server. +The server now strictly enforces the [A2A spec rules for `SendStreamingMessage`](https://a2a-protocol.org/v1.0.0/specification/#312-send-streaming-message). Existing executors that mix message and task events, or emit task updates before the initial `Task`, will fail at runtime with `InvalidAgentResponseError`. See [PR #979](https://github.com/a2aproject/a2a-python/pull/979). + +In v1.0, your `AgentExecutor` MUST follow exactly one of these two streaming patterns: + +1. **Message-only stream** — enqueue exactly **one** `Message` and stop. +2. **Task lifecycle stream** — enqueue a `Task` **first**, then zero or more `TaskStatusUpdateEvent` / `TaskArtifactUpdateEvent` objects until a terminal state is reached. + +The following are now hard errors (each raises `InvalidAgentResponseError`): + +| Violation | Error message | +|---|---| +| Enqueue a `Message` after a `Task` (mixing modes) | *Received Message object in task mode...* | +| Enqueue more than one `Message` | *Multiple Message objects received.* | +| Enqueue a `Task`/update event after a `Message` | *Received `` in message mode...* | +| Enqueue a `TaskStatusUpdateEvent` before the initial `Task` | *Agent should enqueue Task before `` event* | + +### Migration + +**Before (v0.3 — silently tolerated):** +```python +from a2a.helpers import new_text_message +from a2a.server.agent_execution import AgentExecutor +from a2a.types import TaskStatusUpdateEvent + +class MyExecutor(AgentExecutor): + async def execute(self, context, event_queue): + # Mixing Message and Task events — no longer allowed. + await event_queue.enqueue_event(new_text_message('Working on it...')) + await event_queue.enqueue_event( + TaskStatusUpdateEvent(...) 
# ❌ raises InvalidAgentResponseError + ) +``` + +**After (v1.0 — pick one pattern):** + +```python +from a2a.helpers import ( + new_task_from_user_message, + new_text_artifact_update_event, + new_text_message, + new_text_status_update_event, +) +from a2a.server.agent_execution import AgentExecutor +from a2a.types import Role, TaskState + +# Pattern A: Message-only stream — one Message, then done. +class GreetingExecutor(AgentExecutor): + async def execute(self, context, event_queue): + await event_queue.enqueue_event( + new_text_message('Hello!', role=Role.ROLE_AGENT) + ) + +# Pattern B: Task lifecycle stream — Task first, then updates. +class WorkflowExecutor(AgentExecutor): + def __init__(self, agent): + self._agent = agent # Your underlying agent (LLM, tool, etc.) + + async def execute(self, context, event_queue): + task = context.current_task or new_task_from_user_message(context.message) + await event_queue.enqueue_event(task) # ✅ Task MUST be first + + await event_queue.enqueue_event( + new_text_status_update_event( + task_id=task.id, + context_id=task.context_id, + state=TaskState.TASK_STATE_WORKING, + text='Processing...', + ) + ) + + result = await self._agent.invoke(context.message) + await event_queue.enqueue_event( + new_text_artifact_update_event( + task_id=task.id, + context_id=task.context_id, + name='result', + text=result, + ) + ) + + await event_queue.enqueue_event( + new_text_status_update_event( + task_id=task.id, + context_id=task.context_id, + state=TaskState.TASK_STATE_COMPLETED, + text='Done!', + ) + ) +``` + +**Quick checklist when migrating an executor:** +- Decide upfront: is this a one-shot message reply, or a tracked task? +- If task-based, always enqueue the `Task` object as the very first event. +- Never mix `Message` events with `TaskStatusUpdateEvent` / `TaskArtifactUpdateEvent` in the same stream. +- Send only one `Message` per stream when using the message-only pattern. 
+ +> **Example**: [`helloworld/agent_executor.py` in PR #474](https://github.com/a2aproject/a2a-samples/pull/474/files#diff-950e8baafcf17d50db5c10b525949407e129995df5295161fbf688e6374ad284) + +--- + +## 5. Server: Application Setup + +The application wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication`, and `A2ARESTFastApiApplication`) have been removed. The server setup now uses Starlette route factory functions directly, giving you better control over routing, middleware, authentication, logging, and other aspects of the server. **Before (v0.3):** ```python @@ -366,9 +469,9 @@ uvicorn.run(app, host=host, port=port) --- -## 5. Supporting v0.3 Clients +## 6. Supporting v0.3 Clients -If you cannot update all clients at once, you can run a v1.0 server that simultaneously accepts v0.3 connections. Two changes are needed. +If you cannot update all clients at once, you can run a v1.0 server that also accepts v0.3 connections. Two changes are needed. **1. Add the v0.3 AgentInterface to `supported_interfaces` in your `AgentCard`**: @@ -389,7 +492,7 @@ create_rest_routes(request_handler, enable_v0_3_compat=True) --- -## 6. Client: Creating a Client +## 7. Client: Creating a Client In `v1.0`, use the `a2a.client.create_client()` helper function to create a `Client` for the agent. @@ -423,7 +526,7 @@ client = await create_client(agent_card) --- -## 7. Client: Send Message +## 8. Client: Send Message The `BaseClient.send_message()` return type is standardized from `AsyncIterator[ClientEvent | Message]` to `AsyncIterator[StreamResponse]`. @@ -457,7 +560,7 @@ async for chunk in client.send_message(request): --- -## 8. Client: Push Notifications Config +## 9. Client: Push Notifications Config `ClientConfig.push_notification_config` is now **singular** (a single `TaskPushNotificationConfig` or `None`), not a list. @@ -478,7 +581,7 @@ config = ClientConfig( --- -## 9. Helper Utilities +## 10. 
Helper Utilities To improve the developer experience, we have consolidated helper functions into a single import. In v0.3, these helper functions were scattered across different modules. In v1.0, they are all available under `a2a.helpers`. @@ -525,19 +628,20 @@ print(text) --- -## 10. Summary of Key Changes in v1.0 +## 11. Summary of Key Changes in v1.0 - **Migration to Protobuf** — Core types have migrated from Pydantic models to Protobuf-based classes. Protobuf objects do not support arbitrary attribute assignment. Use `MessageToDict` from `google.protobuf.json_format` to convert objects to dictionaries, and `HasField('field_name')` to check for optional fields. - **Standardization to `SCREAMING_SNAKE_CASE`** — All enum values have been renamed from `snake_case` strings to `SCREAMING_SNAKE_CASE` for compliance with the ProtoJSON specification. - **`AgentCard`** — Significantly restructured to support multiple transport interfaces. - **`AgentInterface`** — The top-level `url` field is replaced by `supported_interfaces`, a list of `AgentInterface` objects. Each entry describes a single transport endpoint with fields for `protocol_binding`, `protocol_version`, and `url`. - **Input and output modes** — `AgentCapabilities.input_modes` and `AgentCapabilities.output_modes` are removed and now live directly on `AgentCard` as `default_input_modes` and `default_output_modes`. Individual skills can override these with their own `input_modes` and `output_modes`. -- **Application setup** — The wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication` and `A2ARESTFastApiApplication`) have been removed. Server setup now uses route factory functions — `create_jsonrpc_routes()`, `create_rest_routes()`, and `create_agent_card_routes()` — composed directly into a Starlette or FastAPI app. 
+- **AgentExecutor streaming rules** — The server now strictly enforces the A2A spec: an executor must enqueue either a single `Message` or a `Task` followed by update events (with the `Task` first). Mixing modes, emitting multiple `Message`s, or sending updates before the initial `Task` raises `InvalidAgentResponseError`. +- **Application setup** — The wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication`, and `A2ARESTFastApiApplication`) have been removed. Server setup now uses route factory functions — `create_jsonrpc_routes()`, `create_rest_routes()`, and `create_agent_card_routes()` — composed directly into a Starlette or FastAPI app. - **Helper utilities** — A new `a2a.helpers` module consolidates all helper functions under a single import, replacing the scattered `a2a.utils.*` modules and adding new helpers for constructing and reading v1.0 proto types. --- -## 11. Get Started +## 12. Get Started The fastest way to see v1.0 in action is to run the samples: From cfdbe4c08c58b773a8766c17f5b5eabbe67bf3dd Mon Sep 17 00:00:00 2001 From: Martim Santos Date: Thu, 23 Apr 2026 12:51:13 +0100 Subject: [PATCH 64/67] feat(helpers): add non-text Part, Message, and Artifact helpers (#1004) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the specification. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [ ] Appropriate docs were updated (if necessary) --- ### Summary `proto_helpers.py` provided `new_text_message` and `new_text_artifact` for the text Part variant, but nothing for the three remaining Part types (`data`, `raw`, `url`). This PR completes the set. 
**The `data` case is especially awkward without a helper.** `Part.data` is `google.protobuf.Value` in the v1.0 spec, which requires a non-obvious `ParseDict` dance to construct from a plain Python value: ```python # Without helper from google.protobuf.json_format import ParseDict from google.protobuf import struct_pb2 part = Part(data=ParseDict({"answer": "hello"}, struct_pb2.Value())) # With helper part = new_data_part({"answer": "hello"}) ``` ### New helpers **Part primitives** (building blocks, mirror the existing implicit `Part(text=...)` pattern): | Helper | Part field | Accepts | |---|---|---| | `new_data_part(data)` | `data` (`google.protobuf.Value`) | Any JSON-serializable value (dict, list, str, …) | | `new_raw_part(raw, media_type, filename)` | `raw` (`bytes`) | Raw bytes with optional MIME type and filename | | `new_url_part(url, media_type, filename)` | `url` (`str`) | URL with optional MIME type and filename | **Message helpers** (mirror `new_text_message`): | Helper | Wraps | |---|---| | `new_data_message(data, role, context_id, task_id)` | `new_data_part` | | `new_raw_message(raw, media_type, filename, role, context_id, task_id)` | `new_raw_part` | | `new_url_message(url, media_type, filename, role, context_id, task_id)` | `new_url_part` | **Artifact helpers** (mirror `new_text_artifact`): | Helper | Wraps | |---|---| | `new_data_artifact(name, data, description, artifact_id)` | `new_data_part` | | `new_raw_artifact(name, raw, media_type, filename, description, artifact_id)` | `new_raw_part` | | `new_url_artifact(name, url, media_type, filename, description, artifact_id)` | `new_url_part` | ### Changes - `src/a2a/helpers/proto_helpers.py` — 9 new helper functions - `tests/helpers/test_proto_helpers.py` — tests for all new helpers (35 total, all passing) ### Reviewer feedback addressed - `new_data_part` type hint broadened from `dict[str, Any]` to `Any`, since `google.protobuf.Value` accepts any JSON-serializable value, not just dicts. 
Added a list-value test to cover this. Release-as: 1.0.2 --------- Co-authored-by: Sampath Kumar --- src/a2a/helpers/proto_helpers.py | 265 +++++++++++++++++++++++++++- tests/helpers/test_proto_helpers.py | 241 +++++++++++++++++++++++-- 2 files changed, 488 insertions(+), 18 deletions(-) diff --git a/src/a2a/helpers/proto_helpers.py b/src/a2a/helpers/proto_helpers.py index 79e1f739d..6cc6350b6 100644 --- a/src/a2a/helpers/proto_helpers.py +++ b/src/a2a/helpers/proto_helpers.py @@ -3,6 +3,10 @@ import uuid from collections.abc import Sequence +from typing import Any + +from google.protobuf import struct_pb2 +from google.protobuf.json_format import ParseDict from a2a.types.a2a_pb2 import ( Artifact, @@ -23,9 +27,9 @@ def new_message( parts: list[Part], - role: Role = Role.ROLE_AGENT, context_id: str | None = None, task_id: str | None = None, + role: Role = Role.ROLE_AGENT, ) -> Message: """Creates a new message containing a list of Parts.""" return Message( @@ -39,16 +43,17 @@ def new_message( def new_text_message( text: str, + media_type: str | None = None, context_id: str | None = None, task_id: str | None = None, role: Role = Role.ROLE_AGENT, ) -> Message: """Creates a new message containing a single text Part.""" return new_message( - parts=[Part(text=text)], - role=role, - task_id=task_id, + parts=[new_text_part(text, media_type=media_type)], context_id=context_id, + task_id=task_id, + role=role, ) @@ -57,6 +62,91 @@ def get_message_text(message: Message, delimiter: str = '\n') -> str: return delimiter.join(get_text_parts(message.parts)) +def new_data_message( + data: Any, + media_type: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single data Part. + + Args: + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content (e.g., "text/plain", "application/json", "image/png"). 
+ context_id: Optional context ID. + task_id: Optional task ID. + role: The role of the message sender (default: ROLE_AGENT). + + Returns: + A Message with a single data Part. + """ + return new_message( + parts=[new_data_part(data, media_type=media_type)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + +def new_raw_message( # noqa: PLR0913 + raw: bytes, + media_type: str | None = None, + filename: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single raw bytes Part. + + Args: + raw: The raw bytes content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + context_id: Optional context ID. + task_id: Optional task ID. + role: The role of the message sender (default: ROLE_AGENT). + + Returns: + A Message with a single raw Part. + """ + return new_message( + parts=[new_raw_part(raw, media_type=media_type, filename=filename)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + +def new_url_message( # noqa: PLR0913 + url: str, + media_type: str | None = None, + filename: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single URL Part. + + Args: + url: The URL pointing to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + context_id: Optional context ID. + task_id: Optional task ID. + role: The role of the message sender (default: ROLE_AGENT). + + Returns: + A Message with a single URL Part. 
+ """ + return new_message( + parts=[new_url_part(url, media_type=media_type, filename=filename)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + # --- Artifact Helpers --- @@ -78,12 +168,98 @@ def new_artifact( def new_text_artifact( name: str, text: str, + media_type: str | None = None, description: str | None = None, artifact_id: str | None = None, ) -> Artifact: """Creates a new Artifact object containing only a single text Part.""" return new_artifact( - [Part(text=text)], + [new_text_part(text, media_type=media_type)], + name, + description, + artifact_id=artifact_id, + ) + + +def new_data_artifact( + name: str, + data: Any, + media_type: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single data Part. + + Args: + name: The name of the artifact. + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content (e.g., "text/plain", "application/json", "image/png"). + description: Optional description. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + An Artifact with a single data Part. + """ + return new_artifact( + [new_data_part(data, media_type=media_type)], + name, + description, + artifact_id=artifact_id, + ) + + +def new_raw_artifact( # noqa: PLR0913 + name: str, + raw: bytes, + media_type: str | None = None, + filename: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single raw bytes Part. + + Args: + name: The name of the artifact. + raw: The raw bytes content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + description: Optional description. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + An Artifact with a single raw Part. 
+ """ + return new_artifact( + [new_raw_part(raw, media_type=media_type, filename=filename)], + name, + description, + artifact_id=artifact_id, + ) + + +def new_url_artifact( # noqa: PLR0913 + name: str, + url: str, + media_type: str | None = None, + filename: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single URL Part. + + Args: + name: The name of the artifact. + url: The URL pointing to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + description: Optional description. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + An Artifact with a single URL Part. + """ + return new_artifact( + [new_url_part(url, media_type=media_type, filename=filename)], name, description, artifact_id=artifact_id, @@ -141,6 +317,85 @@ def new_task( # --- Part Helpers --- +def new_text_part( + text: str, + media_type: str | None = None, +) -> Part: + """Creates a Part with text content. + + Args: + text: The text content. + media_type: Optional MIME type (e.g. 'text/plain', 'text/markdown'). + + Returns: + A Part with the text field set. + """ + return Part(text=text, media_type=media_type or '') + + +def new_data_part( + data: Any, + media_type: str | None = None, +) -> Part: + """Creates a Part with structured data (google.protobuf.Value). + + Args: + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content (e.g., "text/plain", "application/json", "image/png"). + + Returns: + A Part with the data field set. + """ + return Part( + data=ParseDict(data, struct_pb2.Value()), + media_type=media_type or '', + ) + + +def new_raw_part( + raw: bytes, + media_type: str | None = None, + filename: str | None = None, +) -> Part: + """Creates a Part with raw bytes content. + + Args: + raw: The raw bytes content. 
+ media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + + Returns: + A Part with the raw field set. + """ + return Part( + raw=raw, + media_type=media_type or '', + filename=filename or '', + ) + + +def new_url_part( + url: str, + media_type: str | None = None, + filename: str | None = None, +) -> Part: + """Creates a Part with a URL pointing to file content. + + Args: + url: The URL to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + + Returns: + A Part with the url field set. + """ + return Part( + url=url, + media_type=media_type or '', + filename=filename or '', + ) + + def get_text_parts(parts: Sequence[Part]) -> list[str]: """Extracts text content from all text Parts.""" return [part.text for part in parts if part.HasField('text')] diff --git a/tests/helpers/test_proto_helpers.py b/tests/helpers/test_proto_helpers.py index a4f6498ab..8fb68dbc2 100644 --- a/tests/helpers/test_proto_helpers.py +++ b/tests/helpers/test_proto_helpers.py @@ -1,37 +1,49 @@ """Tests for proto helpers.""" import pytest + from a2a.helpers.proto_helpers import ( - new_message, - new_text_message, + get_artifact_text, get_message_text, + get_stream_response_text, + get_text_parts, new_artifact, - new_text_artifact, - get_artifact_text, - new_task_from_user_message, + new_data_artifact, + new_data_message, + new_data_part, + new_message, + new_raw_artifact, + new_raw_message, + new_raw_part, new_task, - get_text_parts, - new_text_status_update_event, + new_task_from_user_message, + new_text_artifact, new_text_artifact_update_event, - get_stream_response_text, + new_text_message, + new_text_part, + new_text_status_update_event, + new_url_artifact, + new_url_message, + new_url_part, ) from a2a.types.a2a_pb2 import ( + Artifact, + Message, Part, Role, - Message, - Artifact, + StreamResponse, Task, TaskState, - StreamResponse, ) + # --- Message Helpers Tests --- def test_new_message() -> None: parts = 
[Part(text='hello')] msg = new_message( - parts=parts, role=Role.ROLE_USER, context_id='ctx1', task_id='task1' + parts, context_id='ctx1', task_id='task1', role=Role.ROLE_USER ) assert msg.role == Role.ROLE_USER assert msg.parts == parts @@ -42,11 +54,74 @@ def test_new_message() -> None: def test_new_text_message() -> None: msg = new_text_message( - text='hello', context_id='ctx1', task_id='task1', role=Role.ROLE_USER + 'hello', + media_type='text/plain', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, ) assert msg.role == Role.ROLE_USER assert len(msg.parts) == 1 assert msg.parts[0].text == 'hello' + assert msg.parts[0].media_type == 'text/plain' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_data_message() -> None: + msg = new_data_message( + data={'key': 'value'}, + media_type='application/json', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].HasField('data') + assert msg.parts[0].data.struct_value.fields['key'].string_value == 'value' + assert msg.parts[0].media_type == 'application/json' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_raw_message() -> None: + msg = new_raw_message( + b'\x89PNG', + media_type='image/png', + filename='img.png', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].HasField('raw') + assert msg.parts[0].raw == b'\x89PNG' + assert msg.parts[0].media_type == 'image/png' + assert msg.parts[0].filename == 'img.png' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_url_message() -> None: + msg = new_url_message( + 'https://example.com/file.pdf', + media_type='application/pdf', + filename='file.pdf', + context_id='ctx1', + 
task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].HasField('url') + assert msg.parts[0].url == 'https://example.com/file.pdf' + assert msg.parts[0].media_type == 'application/pdf' + assert msg.parts[0].filename == 'file.pdf' assert msg.context_id == 'ctx1' assert msg.task_id == 'task1' assert msg.message_id != '' @@ -90,6 +165,74 @@ def test_new_text_artifact_with_id() -> None: assert art.artifact_id == 'art1' +def test_new_data_artifact() -> None: + art = new_data_artifact( + name='result', data={'score': 1.0}, description='desc' + ) + assert art.name == 'result' + assert art.description == 'desc' + assert len(art.parts) == 1 + assert art.parts[0].HasField('data') + assert art.parts[0].data.struct_value.fields['score'].number_value == 1.0 + assert art.artifact_id != '' + + +def test_new_data_artifact_with_id() -> None: + art = new_data_artifact(name='result', data={'x': 'y'}, artifact_id='art1') + assert art.artifact_id == 'art1' + assert art.parts[0].data.struct_value.fields['x'].string_value == 'y' + + +def test_new_raw_artifact() -> None: + art = new_raw_artifact( + name='screenshot', + raw=b'\x89PNG', + media_type='image/png', + filename='screen.png', + description='desc', + artifact_id='art1', + ) + assert art.name == 'screenshot' + assert art.description == 'desc' + assert art.artifact_id == 'art1' + assert len(art.parts) == 1 + assert art.parts[0].HasField('raw') + assert art.parts[0].raw == b'\x89PNG' + assert art.parts[0].media_type == 'image/png' + assert art.parts[0].filename == 'screen.png' + + +def test_new_raw_artifact_minimal() -> None: + art = new_raw_artifact(name='file', raw=b'data') + assert art.parts[0].raw == b'data' + assert art.artifact_id != '' + + +def test_new_url_artifact() -> None: + art = new_url_artifact( + name='report', + url='https://example.com/report.pdf', + media_type='application/pdf', + filename='report.pdf', + description='desc', + 
artifact_id='art1', + ) + assert art.name == 'report' + assert art.description == 'desc' + assert art.artifact_id == 'art1' + assert len(art.parts) == 1 + assert art.parts[0].HasField('url') + assert art.parts[0].url == 'https://example.com/report.pdf' + assert art.parts[0].media_type == 'application/pdf' + assert art.parts[0].filename == 'report.pdf' + + +def test_new_url_artifact_minimal() -> None: + art = new_url_artifact(name='img', url='https://example.com/img.png') + assert art.parts[0].url == 'https://example.com/img.png' + assert art.artifact_id != '' + + def test_get_artifact_text() -> None: art = Artifact(parts=[Part(text='hello'), Part(text='world')]) assert get_artifact_text(art) == 'hello\nworld' @@ -149,6 +292,78 @@ def test_get_text_parts() -> None: assert get_text_parts(parts) == ['hello', 'world'] +def test_new_text_part() -> None: + part = new_text_part('hello') + assert part.HasField('text') + assert part.text == 'hello' + assert part.media_type == '' + + +def test_new_text_part_with_media_type() -> None: + part = new_text_part('# Hello', media_type='text/markdown') + assert part.HasField('text') + assert part.text == '# Hello' + assert part.media_type == 'text/markdown' + + +def test_new_data_part_from_dict() -> None: + part = new_data_part({'key': 'value', 'count': 42}) + assert part.HasField('data') + assert part.data.struct_value.fields['key'].string_value == 'value' + assert part.data.struct_value.fields['count'].number_value == 42 + assert part.media_type == '' + + +def test_new_data_part_with_media_type() -> None: + part = new_data_part({'key': 'value'}, media_type='application/json') + assert part.HasField('data') + assert part.media_type == 'application/json' + + +def test_new_data_part_from_list() -> None: + part = new_data_part([1, 2, 3]) + assert part.HasField('data') + assert part.data.list_value.values[0].number_value == 1 + assert part.data.list_value.values[1].number_value == 2 + assert part.data.list_value.values[2].number_value 
== 3 + + +def test_new_raw_part() -> None: + part = new_raw_part(b'\x89PNG', media_type='image/png', filename='img.png') + assert part.HasField('raw') + assert part.raw == b'\x89PNG' + assert part.media_type == 'image/png' + assert part.filename == 'img.png' + + +def test_new_raw_part_minimal() -> None: + part = new_raw_part(b'data') + assert part.HasField('raw') + assert part.raw == b'data' + assert part.media_type == '' + assert part.filename == '' + + +def test_new_url_part() -> None: + part = new_url_part( + 'https://example.com/file.pdf', + media_type='application/pdf', + filename='file.pdf', + ) + assert part.HasField('url') + assert part.url == 'https://example.com/file.pdf' + assert part.media_type == 'application/pdf' + assert part.filename == 'file.pdf' + + +def test_new_url_part_minimal() -> None: + part = new_url_part('https://example.com/img.png') + assert part.HasField('url') + assert part.url == 'https://example.com/img.png' + assert part.media_type == '' + assert part.filename == '' + + # --- Event & Stream Helpers Tests --- From a470bae7f1892054476703818a4a3be9b124cb84 Mon Sep 17 00:00:00 2001 From: kdziedzic70 Date: Thu, 23 Apr 2026 14:02:26 +0200 Subject: [PATCH 65/67] test: test push notifications in itk (#1009) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description PR extends itk cases with push_notifications compatibility tests Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. 
- `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes # 🦕 Co-authored-by: Krzysztof Dziedzic --- .github/workflows/itk.yaml | 2 +- itk/README.md | 2 +- itk/main.py | 127 +++++++++++++++++++++++-------------- itk/run_itk.sh | 28 ++++++-- 4 files changed, 105 insertions(+), 54 deletions(-) diff --git a/.github/workflows/itk.yaml b/.github/workflows/itk.yaml index feb9325e3..33d7585d6 100644 --- a/.github/workflows/itk.yaml +++ b/.github/workflows/itk.yaml @@ -31,4 +31,4 @@ jobs: run: bash run_itk.sh working-directory: itk env: - A2A_SAMPLES_REVISION: itk-v.016-alpha + A2A_SAMPLES_REVISION: itk-v.02-alpha diff --git a/itk/README.md b/itk/README.md index 9a82d0469..3044b37af 100644 --- a/itk/README.md +++ b/itk/README.md @@ -36,7 +36,7 @@ You must set the `A2A_SAMPLES_REVISION` environment variable to specify which re Example: ``` -export A2A_SAMPLES_REVISION=itk-v.015-alpha +export A2A_SAMPLES_REVISION=itk-v.02-alpha ``` ### 2. 
Execute Tests diff --git a/itk/main.py b/itk/main.py index 6792c540a..cc761d081 100644 --- a/itk/main.py +++ b/itk/main.py @@ -17,13 +17,21 @@ from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes -from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.events import EventQueue +from a2a.server.routes import ( + create_agent_card_routes, + create_jsonrpc_routes, + create_rest_routes, +) from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler -from a2a.server.tasks import TaskUpdater +from a2a.server.tasks import ( + TaskUpdater, + BasePushNotificationSender, + InMemoryPushNotificationConfigStore, +) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.server.context import ServerCallContext from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -35,11 +43,12 @@ Task, TaskState, TaskStatus, + TaskPushNotificationConfig, ) from a2a.utils import TransportProtocol - -log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper() +log_level_str = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper() +log_level = getattr(logging, log_level_str, logging.INFO) logging.basicConfig(level=log_level) logger = logging.getLogger(__name__) @@ -106,7 +115,9 @@ def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message: ) -async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: +async def handle_call_agent( + call: instruction_pb2.CallAgent, +) -> list[str]: """Handles the CallAgent instruction by invoking another agent.""" logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport) @@ -131,36 +142,47 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> 
list[str]: selected_transport == TransportProtocol.GRPC ) + if call.HasField('push_notification'): + url = call.push_notification.url + if not url: + raise ValueError('URL not specified in push_notification behavior') + if not url.startswith(('http://', 'https://')): + url = f'http://{url}' + config.push_notification_config = TaskPushNotificationConfig( + url=f'{url}/notifications', + token='itk-token', # noqa: S106 + ) + try: - client = await create_client(call.agent_card_uri, client_config=config) + client = await create_client( + call.agent_card_uri, + client_config=config, + ) # Wrap nested instruction - async with client: - nested_msg = wrap_instruction_to_request(call.instruction) - request = SendMessageRequest(message=nested_msg) - - results: list[str] = [] - async for event in client.send_message(request): - # Event is StreamResponse - logger.info('Event: %s', event) - stream_resp = event - - message = None - if stream_resp.HasField('message'): - message = stream_resp.message - elif stream_resp.HasField( - 'task' - ) and stream_resp.task.status.HasField('message'): - message = stream_resp.task.status.message - elif stream_resp.HasField( - 'status_update' - ) and stream_resp.status_update.status.HasField('message'): - message = stream_resp.status_update.status.message - - if message: - results.extend( - part.text for part in message.parts if part.text - ) + nested_msg = wrap_instruction_to_request(call.instruction) + request = SendMessageRequest(message=nested_msg) + + results = [] + async for event in client.send_message(request): + # Event is streaming response and task + logger.info('Event: %s', event) + stream_resp = event + + message = None + if stream_resp.HasField('message'): + message = stream_resp.message + elif stream_resp.HasField( + 'task' + ) and stream_resp.task.status.HasField('message'): + message = stream_resp.task.status.message + elif stream_resp.HasField( + 'status_update' + ) and stream_resp.status_update.status.HasField('message'): + 
message = stream_resp.status_update.status.message + + if message: + results.extend(part.text for part in message.parts if part.text) except Exception as e: logger.exception('Failed to call outbound agent') @@ -171,7 +193,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: return results -async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]: +async def handle_instruction( + inst: instruction_pb2.Instruction, +) -> list[str]: """Recursively handles instructions.""" if inst.HasField('call_agent'): return await handle_call_agent(inst.call_agent) @@ -303,9 +327,7 @@ async def main_async(http_port: int, grpc_port: int) -> None: description='Python agent using SDK 1.0.', version='1.0.0', capabilities=AgentCapabilities( - streaming=True, - push_notifications=True, - extended_agent_card=True, + streaming=True, push_notifications=True, extended_agent_card=True ), default_input_modes=['text/plain'], default_output_modes=['text/plain'], @@ -313,23 +335,32 @@ async def main_async(http_port: int, grpc_port: int) -> None: ) task_store = InMemoryTaskStore() + push_config_store = InMemoryPushNotificationConfigStore() + push_sender = BasePushNotificationSender( + httpx_client=httpx.AsyncClient(), + config_store=push_config_store, + context=ServerCallContext(), + ) + handler = DefaultRequestHandler( agent_executor=V10AgentExecutor(), - task_store=task_store, agent_card=agent_card, + task_store=task_store, queue_manager=InMemoryQueueManager(), + push_config_store=push_config_store, + push_sender=push_sender, ) handler_extended = DefaultRequestHandler( agent_executor=V10AgentExecutor(), - task_store=task_store, agent_card=agent_card, + task_store=task_store, queue_manager=InMemoryQueueManager(), + push_config_store=push_config_store, + push_sender=push_sender, extended_agent_card=agent_card, ) - app = FastAPI() - agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ) @@ -338,15 
+369,16 @@ async def main_async(http_port: int, grpc_port: int) -> None: rpc_url='/', enable_v0_3_compat=True, ) - app.mount( - '/jsonrpc', - FastAPI(routes=jsonrpc_routes + agent_card_routes), - ) - rest_routes = create_rest_routes( request_handler=handler, enable_v0_3_compat=True, ) + + app = FastAPI() + app.mount( + '/jsonrpc', + FastAPI(routes=jsonrpc_routes + agent_card_routes), + ) app.mount('/rest', FastAPI(routes=rest_routes + agent_card_routes)) server = grpc.aio.server() @@ -365,9 +397,8 @@ async def main_async(http_port: int, grpc_port: int) -> None: grpc_port, ) - uvicorn_log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').lower() config = uvicorn.Config( - app, host='127.0.0.1', port=http_port, log_level=uvicorn_log_level + app, host='127.0.0.1', port=http_port, log_level=log_level_str.lower() ) uvicorn_server = uvicorn.Server(config) diff --git a/itk/run_itk.sh b/itk/run_itk.sh index 2d9371c14..21736f171 100755 --- a/itk/run_itk.sh +++ b/itk/run_itk.sh @@ -119,14 +119,16 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"], "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"], - "protocols": ["jsonrpc", "grpc"] + "protocols": ["jsonrpc", "grpc"], + "behavior": "send_message" }, { "name": "Star Topology (No Go v03) - HTTP_JSON", "sdks": ["current", "python_v10", "python_v03", "go_v10"], "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], - "protocols": ["http_json"] + "protocols": ["http_json"], + "behavior": "send_message" }, { "name": "Star Topology (Full) - JSONRPC & GRPC (Streaming)", @@ -134,7 +136,8 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"], "protocols": ["jsonrpc", "grpc"], - "streaming": true + "streaming": true, + "behavior": "send_message" }, { "name": "Star Topology (No Go v03) - HTTP_JSON 
(Streaming)", @@ -142,7 +145,24 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], "protocols": ["http_json"], - "streaming": true + "streaming": true, + "behavior": "send_message" + }, + { + "name": "Push Notification Test - JSONRPC & GRPC", + "sdks": ["current", "python_v10", "python_v03", "go_v03"], + "traversal": "euler", + "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], + "protocols": ["jsonrpc", "grpc"], + "behavior": "push_notification" + }, + { + "name": "Push Notification Test - HTTP_JSON", + "sdks": ["current", "python_v10", "python_v03"], + "traversal": "euler", + "edges": ["0->1", "0->2", "1->0", "2->0"], + "protocols": ["http_json"], + "behavior": "push_notification" } ] }') From c24ae055715ba69329ffa4e36489379308cd0bde Mon Sep 17 00:00:00 2001 From: Iva Sokolaj <102302011+sokoliva@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:43:46 +0200 Subject: [PATCH 66/67] fix(server): deliver push notifications across all owners (#1016) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Fix a silent multi-tenant bug where push notifications were dropped for any deployment using authenticated users, and split the `PushNotificationConfigStore` read API into a user-callable owner-scoped read (`get_info`) and an internal cross-owner read for dispatch (`get_info_for_dispatch`). ## The bug `BasePushNotificationSender` accepted a `ServerCallContext` at construction time and called `config_store.get_info(task_id, self._call_context)` at dispatch. Because the sender is a process-wide singleton, callers passed a dummy `ServerCallContext()` (e.g. `itk/main.py`). The default `OwnerResolver` then resolved the dummy to the empty-string owner, which never matched any real registrar's partition. Result: `get_info` returned `[]` and every notification was silently dropped in any deployment with real authentication. 
## The fix - **New non-abstract method** `PushNotificationConfigStore.get_info_for_dispatch(task_id)` returns every config for the task across all owners. Implemented in the in-memory and database stores. Custom 1.0 subclasses inherit a default implementation that forwards to `get_info(task_id, ServerCallContext())` preserving their 1.0 behavior exactly and emits a `DeprecationWarning` - **`BasePushNotificationSender`** no longer takes `context` in `__init__` and now calls `get_info_for_dispatch`. Identity is no longer held on the sender. - **`get_info(task_id, context)` is unchanged** and remains owner-scoped. Used by the user-callable read endpoints. The split encodes the asymmetry in the type system: the user-callable method requires a context, the dispatch-only method does not. Authorization (check if the user can create a config for a specific task) happens at registration (`set_info`), not at dispatch. Fixes #1015 🦕 --------- Co-authored-by: Copilot --- itk/main.py | 2 - .../tasks/base_push_notification_sender.py | 27 +- ...database_push_notification_config_store.py | 46 ++- ...inmemory_push_notification_config_store.py | 19 +- .../tasks/push_notification_config_store.py | 46 ++- tests/e2e/push_notifications/agent_app.py | 88 ++++- .../test_default_push_notification_support.py | 307 +++++++++++++++++- .../test_default_request_handler.py | 171 +++++++++- .../test_default_request_handler_v2.py | 118 ++++++- ...database_push_notification_config_store.py | 51 +++ .../tasks/test_inmemory_push_notifications.py | 149 ++++++++- .../tasks/test_push_notification_sender.py | 68 ++-- 12 files changed, 1013 insertions(+), 79 deletions(-) diff --git a/itk/main.py b/itk/main.py index cc761d081..76c72e1c2 100644 --- a/itk/main.py +++ b/itk/main.py @@ -31,7 +31,6 @@ InMemoryPushNotificationConfigStore, ) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore -from a2a.server.context import ServerCallContext from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 
import ( AgentCapabilities, @@ -339,7 +338,6 @@ async def main_async(http_port: int, grpc_port: int) -> None: push_sender = BasePushNotificationSender( httpx_client=httpx.AsyncClient(), config_store=push_config_store, - context=ServerCallContext(), ) handler = DefaultRequestHandler( diff --git a/src/a2a/server/tasks/base_push_notification_sender.py b/src/a2a/server/tasks/base_push_notification_sender.py index 4a4929e8f..ff9ca3ce5 100644 --- a/src/a2a/server/tasks/base_push_notification_sender.py +++ b/src/a2a/server/tasks/base_push_notification_sender.py @@ -27,26 +27,39 @@ def __init__( self, httpx_client: httpx.AsyncClient, config_store: PushNotificationConfigStore, - context: ServerCallContext, + context: ServerCallContext | None = None, ) -> None: """Initializes the BasePushNotificationSender. Args: httpx_client: An async HTTP client instance to send notifications. - config_store: A PushNotificationConfigStore instance to retrieve configurations. - context: The `ServerCallContext` that this push notification is produced under. + config_store: A PushNotificationConfigStore instance to + retrieve configurations. + context: Deprecated and ignored. Accepted only for + backward compatibility with 1.0 callers that constructed + the sender with a (typically dummy) ServerCallContext. + Pass None (the default) in new code. A non-None + value logs a deprecation warning and is otherwise + ignored. """ + if context is not None: + logger.warning( + 'BasePushNotificationSender no longer uses the context ' + 'parameter; it is accepted only for backward compatibility ' + 'with 1.0 and will be removed in a future major version. ' + 'Push notifications now fan out across all owners via ' + 'PushNotificationConfigStore.get_info_for_dispatch; the ' + 'caller identity is not carried into dispatch. Drop the ' + 'context argument from the constructor call.' 
+ ) self._client = httpx_client self._config_store = config_store - self._call_context: ServerCallContext = context async def send_notification( self, task_id: str, event: PushNotificationEvent ) -> None: """Sends a push notification for an event if configuration exists.""" - push_configs = await self._config_store.get_info( - task_id, self._call_context - ) + push_configs = await self._config_store.get_info_for_dispatch(task_id) if not push_configs: return diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index 31cd676c8..d050de7cc 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -7,7 +7,7 @@ try: - from sqlalchemy import Table, and_, delete, select + from sqlalchemy import ColumnElement, Table, and_, delete, select from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -304,21 +304,14 @@ async def set_info( owner, ) - async def get_info( + async def _select_configs( self, - task_id: str, - context: ServerCallContext, + *predicates: 'ColumnElement[bool]', ) -> list[TaskPushNotificationConfig]: - """Retrieves all push notification configurations for a task, for the given owner.""" + """Loads configs matching the given predicates and decodes them.""" await self._ensure_initialized() - owner = self.owner_resolver(context) async with self.async_session_maker() as session: - stmt = select(self.config_model).where( - and_( - self.config_model.task_id == task_id, - self.config_model.owner == owner, - ) - ) + stmt = select(self.config_model).where(and_(*predicates)) result = await session.execute(stmt) models = result.scalars().all() @@ -331,10 +324,37 @@ async def get_info( 'Could not deserialize push notification config for task %s, config %s, owner %s', model.task_id, model.config_id, - owner, + model.owner, ) return configs + async def get_info( + self, + task_id: 
str, + context: ServerCallContext, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task, for the given owner. + + Used by the user-callable read endpoints. + """ + owner = self.owner_resolver(context) + return await self._select_configs( + self.config_model.task_id == task_id, + self.config_model.owner == owner, + ) + + async def get_info_for_dispatch( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task, across all owners. + + Used by the push-notification dispatch path. + """ + return await self._select_configs( + self.config_model.task_id == task_id, + ) + async def delete_info( self, task_id: str, diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index d5b0a5b1f..19e35074a 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -72,12 +72,29 @@ async def get_info( task_id: str, context: ServerCallContext, ) -> list[TaskPushNotificationConfig]: - """Retrieves all push notification configurations for a task from memory, for the given owner.""" + """Retrieves all push notification configurations for a task from memory, for the given owner. + + Used by the user-callable read endpoints. + """ owner = self.owner_resolver(context) async with self.lock: owner_infos = self._get_owner_push_notification_infos(owner) return list(owner_infos.get(task_id, [])) + async def get_info_for_dispatch( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task across all owners. + + Used by the push-notification dispatch path. 
+ """ + async with self.lock: + results: list[TaskPushNotificationConfig] = [] + for all_configs in self._push_notification_infos.values(): + results.extend(all_configs.get(task_id, [])) + return results + async def delete_info( self, task_id: str, diff --git a/src/a2a/server/tasks/push_notification_config_store.py b/src/a2a/server/tasks/push_notification_config_store.py index 6b5b35245..e1e65c3fb 100644 --- a/src/a2a/server/tasks/push_notification_config_store.py +++ b/src/a2a/server/tasks/push_notification_config_store.py @@ -1,9 +1,14 @@ +import logging + from abc import ABC, abstractmethod from a2a.server.context import ServerCallContext from a2a.types.a2a_pb2 import TaskPushNotificationConfig +logger = logging.getLogger(__name__) + + class PushNotificationConfigStore(ABC): """Interface for storing and retrieving push notification configurations for tasks.""" @@ -22,7 +27,46 @@ async def get_info( task_id: str, context: ServerCallContext, ) -> list[TaskPushNotificationConfig]: - """Retrieves the push notification configuration for a task.""" + """Retrieves push notification configurations for a task, scoped to the caller. + + This is the user-callable read path. Implementations MUST return + only configurations owned by the caller (as resolved from + context). + """ + + async def get_info_for_dispatch( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task, across all owners. + + This is the internal read path used by the push-notification + dispatch loop. Implementations SHOULD override this method to + return every configuration registered for task_id regardless of + which user registered it. Authorization already happened at + registration time and the dispatch path fires every registered + webhook for the task. + + The default implementation falls back to calling get_info with + a synthetic empty ServerCallContext. 
This preserves 1.0 + behavior for subclasses that have not implemented the override + but is INCORRECT for any deployment with multiple owners: the + empty context resolves to the empty-string owner partition and + returns no configs (silently dropping every notification). A + warning is logged on every call to flag the misconfiguration. + Custom subclasses MUST override this method to deliver + notifications correctly in multi-owner deployments. + """ + logger.warning( + '%s does not override ' + 'PushNotificationConfigStore.get_info_for_dispatch; falling back ' + 'to a context-less get_info call which silently drops ' + 'notifications in any deployment with multiple owners. Override ' + 'get_info_for_dispatch to return all configs for task_id across ' + 'every owner.', + type(self).__name__, + ) + return await self.get_info(task_id, ServerCallContext()) @abstractmethod async def delete_info( diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index bc95f6c37..9bb3a02fa 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -1,14 +1,17 @@ import httpx from fastapi import FastAPI +from starlette.applications import Starlette +from starlette.requests import Request +from a2a.auth.user import UnauthenticatedUser, User from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue -from starlette.applications import Starlette -from a2a.server.routes.rest_routes import create_rest_routes -from a2a.server.routes import create_agent_card_routes from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.routes import create_agent_card_routes +from a2a.server.routes.common import DefaultServerCallContextBuilder +from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import ( BasePushNotificationSender, 
InMemoryPushNotificationConfigStore, @@ -30,6 +33,9 @@ ) +_TEST_USER_HEADER = 'x-test-user' + + def test_agent_card(url: str) -> AgentCard: """Returns an agent card for the test agent.""" return AgentCard( @@ -151,7 +157,6 @@ def create_agent_app( push_sender=BasePushNotificationSender( httpx_client=notification_client, config_store=push_config_store, - context=ServerCallContext(), ), ) rest_routes = create_rest_routes(request_handler=handler) @@ -159,3 +164,78 @@ def create_agent_app( agent_card=card, card_url='/.well-known/agent-card.json' ) return Starlette(routes=[*rest_routes, *agent_card_routes]) + + +class _NamedTestUser(User): + """Authenticated test user identified by ``user_name``.""" + + def __init__(self, user_name: str) -> None: + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +class _HeaderUserContextBuilder(DefaultServerCallContextBuilder): + """Builds a ServerCallContext whose user is read from a request header.""" + + def build_user(self, request: Request) -> User: + user_name = request.headers.get(_TEST_USER_HEADER) + if user_name: + return _NamedTestUser(user_name) + return UnauthenticatedUser() + + +def create_multi_user_agent_app( + url: str, notification_client: httpx.AsyncClient +) -> Starlette: + """Creates a multi-user variant of the test agent app. + + Differences from create_agent_app: + + - Identity is read from the x-test-user header on each request + via _HeaderUserContextBuilder. Multiple authenticated + users (e.g. alice, bob) can therefore call the same + server. + - The InMemoryTaskStore uses a constant owner resolver, so + every authenticated user has access to every task. + - The InMemoryPushNotificationConfigStore keeps the default + per-user owner resolver, so each registrar's configs live in their + own owner partition; this exercises cross-owner aggregation in + get_info_for_dispatch. 
+ """ + # Shared task visibility: any authenticated user can see any task. + task_store = InMemoryTaskStore(owner_resolver=lambda _ctx: 'shared') + + # Per-user push-config partitioning (the default). + push_config_store = InMemoryPushNotificationConfigStore() + + card = test_agent_card(url) + extended_card = test_agent_card(url) + extended_card.name = 'Test Agent Extended' + + handler = DefaultRequestHandler( + agent_executor=TestAgentExecutor(), + task_store=task_store, + agent_card=card, + extended_agent_card=extended_card, + push_config_store=push_config_store, + push_sender=BasePushNotificationSender( + httpx_client=notification_client, + config_store=push_config_store, + ), + ) + + rest_routes = create_rest_routes( + request_handler=handler, + context_builder=_HeaderUserContextBuilder(), + ) + agent_card_routes = create_agent_card_routes( + agent_card=card, card_url='/.well-known/agent-card.json' + ) + return Starlette(routes=[*rest_routes, *agent_card_routes]) diff --git a/tests/e2e/push_notifications/test_default_push_notification_support.py b/tests/e2e/push_notifications/test_default_push_notification_support.py index 35e4bbeb4..84fd14c9a 100644 --- a/tests/e2e/push_notifications/test_default_push_notification_support.py +++ b/tests/e2e/push_notifications/test_default_push_notification_support.py @@ -6,7 +6,7 @@ import pytest import pytest_asyncio -from .agent_app import create_agent_app +from .agent_app import create_agent_app, create_multi_user_agent_app from .notifications_app import Notification, create_notifications_app from .utils import ( create_app_process, @@ -21,9 +21,9 @@ ) from a2a.utils.constants import TransportProtocol from a2a.types.a2a_pb2 import ( + ListTaskPushNotificationConfigsRequest, Message, Part, - TaskPushNotificationConfig, Role, SendMessageConfiguration, SendMessageRequest, @@ -33,6 +33,9 @@ ) +_TEST_USER_HEADER = 'x-test-user' + + @pytest.fixture(scope='module') def notifications_server(): """ @@ -88,6 +91,40 @@ def 
agent_server(notifications_client: httpx.AsyncClient): process.join() +@pytest.fixture(scope='module') +def multi_user_agent_server(notifications_client: httpx.AsyncClient): + """Starts the multi-user variant of the test agent server. + + This variant reads identity from an x-test-user request header + and uses a TaskStore whose owner resolver returns a constant, so + every authenticated user can see every task. It runs on its own + port alongside the single-user agent_server fixture; the + notifications_server is shared (notifications include the + task_id and per-config token, so collisions are avoided). + """ + host = '127.0.0.1' + port = find_free_port() + url = f'http://{host}:{port}' + + process = create_app_process( + create_multi_user_agent_app(url, notifications_client), host, port + ) + process.start() + try: + wait_for_server_ready( + f'{url}/extendedAgentCard', + headers={'A2A-Version': '1.0', _TEST_USER_HEADER: 'health-check'}, + ) + except TimeoutError as e: + process.terminate() + raise e + + yield url + + process.terminate() + process.join() + + @pytest_asyncio.fixture(scope='function') async def http_client(): """An async client fixture for test functions.""" @@ -238,6 +275,272 @@ async def test_notification_triggering_after_config_change_e2e( assert notifications[0].token == token +@pytest.mark.asyncio +async def test_multi_registrar_fan_out_e2e( + notifications_server: str, + agent_server: str, + http_client: httpx.AsyncClient, +): + """Two pushNotificationConfigs registered for the same task both fire end-to-end. + + Exercises the dispatch fan-out across multiple registered configs + over the real wire: each registered URL must receive a POST with + its own token in the X-A2A-Notification-Token header. + """ + # Configure an A2A client without a per-message push notification config + # (we'll register configs explicitly after the task is created). 
+ a2a_client = ClientFactory( + ClientConfig( + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ).create(minimal_agent_card(agent_server, [TransportProtocol.HTTP_JSON])) + + # Send an initial message that requires more input, so the task lingers + # long enough for us to register multiple push configs against it. + responses = [ + response + async for response in a2a_client.send_message( + SendMessageRequest( + message=Message( + message_id='multi-fanout-init', + parts=[Part(text='How are you?')], + role=Role.ROLE_USER, + ), + configuration=SendMessageConfiguration(), + ) + ) + ] + assert len(responses) == 1 + stream_response = responses[0] + assert stream_response.HasField('task') + task = stream_response.task + assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + # Register two distinct push configs for the same task. Both share the + # same registrar (this client), but use different config ids, URLs, and + # tokens. Both must fire when the next event is dispatched. + token_a = uuid.uuid4().hex + token_b = uuid.uuid4().hex + await a2a_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='registrar-a', + url=f'{notifications_server}/notifications', + token=token_a, + ) + ) + await a2a_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='registrar-b', + url=f'{notifications_server}/notifications', + token=token_b, + ) + ) + + # Sanity: no notifications have fired yet. + response = await http_client.get( + f'{notifications_server}/{task.id}/notifications' + ) + assert response.status_code == 200 + assert len(response.json().get('notifications', [])) == 0 + + # Send a follow-up message that completes the task and triggers + # dispatch. Both registered configs must receive a POST. 
+ responses = [ + response + async for response in a2a_client.send_message( + SendMessageRequest( + message=Message( + task_id=task.id, + message_id='multi-fanout-complete', + parts=[Part(text='Good')], + role=Role.ROLE_USER, + ), + configuration=SendMessageConfiguration(), + ) + ) + ] + assert len(responses) == 1 + + # Expect 2 notifications: one COMPLETED event, fanned out to 2 configs. + notifications = await wait_for_n_notifications( + http_client, + f'{notifications_server}/{task.id}/notifications', + n=2, + ) + + # Both tokens must appear exactly once. + received_tokens = sorted(n.token for n in notifications) + assert received_tokens == sorted([token_a, token_b]) + + # Both notifications must carry the same COMPLETED event payload. + for notification in notifications: + state = ( + notification.event.get('status_update', {}) + .get('status', {}) + .get('state') + ) + assert state == 'TASK_STATE_COMPLETED' + + +def _make_user_a2a_client(agent_server: str, user_name: str): + """Builds an A2A client that identifies as user_name on every request. + + Identity is conveyed via a default header on the underlying + httpx.AsyncClient; the multi-user agent app's context builder + reads that header to populate ServerCallContext.user. + """ + httpx_client = httpx.AsyncClient(headers={_TEST_USER_HEADER: user_name}) + return ClientFactory( + ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ).create( + minimal_agent_card(agent_server, [TransportProtocol.HTTP_JSON]) + ), httpx_client + + +@pytest.mark.asyncio +async def test_alice_and_bob_both_receive_notifications_on_shared_task_e2e( + notifications_server: str, + multi_user_agent_server: str, + http_client: httpx.AsyncClient, +): + """Alice registers a webhook; Bob registers a webhook; both fire end-to-end. + + 1. Alice creates a task (it lingers in INPUT_REQUIRED). + 2. Alice registers her own push config on the task. + 3. 
Bob (a different authenticated user, sharing access to the task) + registers his own push config on the same task. + 4. Bob (the dispatcher, *not* the registrar of Alice's webhook) + sends a follow-up message that completes the task. + 5. Both Alice's webhook and Bob's webhook receive a POST with their + own respective tokens. + + Regression guard for the design's central guarantee: subscriptions + fire on the registrar's behalf regardless of which user's action + triggered the event. A regression that re-introduced + dispatcher-context filtering on the dispatch path would drop one of + the two notifications. + """ + alice_client, alice_http = _make_user_a2a_client( + multi_user_agent_server, 'alice' + ) + bob_client, bob_http = _make_user_a2a_client(multi_user_agent_server, 'bob') + + try: + responses = [ + response + async for response in alice_client.send_message( + SendMessageRequest( + message=Message( + message_id='shared-task-init', + parts=[Part(text='How are you?')], + role=Role.ROLE_USER, + ), + ) + ) + ] + assert len(responses) == 1 + assert responses[0].HasField('task') + task = responses[0].task + assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + # 2. Alice registers her push config. + alice_token = uuid.uuid4().hex + await alice_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='alice-cfg', + url=f'{notifications_server}/notifications', + token=alice_token, + ) + ) + + # 3. Bob registers his push config on the same task. + bob_token = uuid.uuid4().hex + await bob_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='bob-cfg', + url=f'{notifications_server}/notifications', + token=bob_token, + ) + ) + + # Sanity: the per-user listing endpoints are owner-scoped -- + # Alice does not see Bob's config and vice-versa, even though + # both can see the underlying task. 
+ # + # The auto-registered empty config (see step 1 quirk note) lives + # in Alice's partition under ``id == task_id``, so Alice's + # listing contains ``{'alice-cfg', task.id}``; the key invariant + # is that neither listing contains the other user's id or + # token. + alice_configs = await alice_client.list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id=task.id) + ) + alice_ids = {c.id for c in alice_configs.configs} + assert 'alice-cfg' in alice_ids + assert 'bob-cfg' not in alice_ids + assert all(c.token != bob_token for c in alice_configs.configs) + + bob_configs = await bob_client.list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id=task.id) + ) + bob_ids = {c.id for c in bob_configs.configs} + assert 'bob-cfg' in bob_ids + assert 'alice-cfg' not in bob_ids + assert all(c.token != alice_token for c in bob_configs.configs) + + # Sanity: no notifications have fired yet. + response = await http_client.get( + f'{notifications_server}/{task.id}/notifications' + ) + assert response.status_code == 200 + assert len(response.json().get('notifications', [])) == 0 + + # 4. Bob sends the follow-up message that completes the task. + # Omit ``configuration`` for the same reason as step 1. + responses = [ + response + async for response in bob_client.send_message( + SendMessageRequest( + message=Message( + task_id=task.id, + message_id='shared-task-complete', + parts=[Part(text='Good')], + role=Role.ROLE_USER, + ), + ) + ) + ] + assert len(responses) == 1 + + # 5. Both Alice's and Bob's webhooks receive the COMPLETED event. 
+ notifications = await wait_for_n_notifications( + http_client, + f'{notifications_server}/{task.id}/notifications', + n=2, + ) + + received_tokens = sorted(n.token for n in notifications) + assert received_tokens == sorted([alice_token, bob_token]) + + for notification in notifications: + state = ( + notification.event.get('status_update', {}) + .get('status', {}) + .get('state') + ) + assert state == 'TASK_STATE_COMPLETED' + finally: + await alice_http.aclose() + await bob_http.aclose() + + async def wait_for_n_notifications( http_client: httpx.AsyncClient, url: str, diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 5a2bf0446..0138045ae 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -14,7 +14,7 @@ import pytest -from a2a.auth.user import UnauthenticatedUser +from a2a.auth.user import UnauthenticatedUser, User from a2a.server.agent_execution import ( AgentExecutor, RequestContext, @@ -1590,7 +1590,6 @@ def __init__(self): async def execute( self, context: RequestContext, event_queue: EventQueue ): - updater = TaskUpdater( event_queue, cast('str', context.task_id), @@ -2977,3 +2976,171 @@ async def test_on_subscribe_to_task_unsupported(agent_card): # We need to exhaust the generator to trigger the decorator evaluation async for _ in request_handler.on_subscribe_to_task(params, context): pass + + +class _NamedUser(User): + """Minimal authenticated test user identified by ``user_name``.""" + + def __init__(self, user_name: str) -> None: + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def _ctx(user_name: str) -> ServerCallContext: + return ServerCallContext(user=_NamedUser(user_name)) + + +@pytest.mark.asyncio +async def 
test_on_list_task_push_notification_configs_is_owner_scoped( + agent_card, +): + """Bob must not see Alice's configs via tasks/pushNotificationConfig/list. + + Both users have access to the shared task (the mocked TaskStore + returns it for any caller), but listing must only return the + caller's own configs. + """ + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = create_sample_task(task_id='shared-task') + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + bob_ctx = _ctx('bob') + + alice_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ) + bob_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='bob-cfg', + url='http://bob.example.com/cb', + token='bob-secret', + ) + await push_store.set_info('shared-task', alice_cfg, alice_ctx) + await push_store.set_info('shared-task', bob_cfg, bob_ctx) + + request_handler = DefaultRequestHandler( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=agent_card, + ) + + alice_listing = ( + await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + alice_ctx, + ) + ) + assert {c.id for c in alice_listing.configs} == {'alice-cfg'} + # Sanity: Bob's secret is not in the response. 
+ assert all(c.token != 'bob-secret' for c in alice_listing.configs), ( + 'Listing for Alice must not expose Bob-owned tokens' + ) + + bob_listing = await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + bob_ctx, + ) + assert {c.id for c in bob_listing.configs} == {'bob-cfg'} + assert all(c.token != 'alice-secret' for c in bob_listing.configs), ( + 'Listing for Bob must not expose Alice-owned tokens' + ) + + +@pytest.mark.asyncio +async def test_on_list_task_push_notification_configs_returns_empty_for_third_user( + agent_card, +): + """A third user with task access but no registered configs sees an empty list.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = create_sample_task(task_id='shared-task') + + push_store = InMemoryPushNotificationConfigStore() + await push_store.set_info( + 'shared-task', + TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + ), + _ctx('alice'), + ) + + request_handler = DefaultRequestHandler( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=agent_card, + ) + + carol_listing = ( + await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + _ctx('carol'), + ) + ) + assert carol_listing.configs == [] + + +@pytest.mark.asyncio +async def test_on_get_task_push_notification_config_is_owner_scoped( + agent_card, +): + """Bob cannot fetch Alice's config by ID via tasks/pushNotificationConfig/get. + + Even when Bob can read the task and knows (or guesses) the + config_id, the handler must raise TaskNotFoundError because Alice's + config is not in Bob's owner partition. 
+ """ + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = create_sample_task(task_id='shared-task') + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + await push_store.set_info( + 'shared-task', + TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ), + alice_ctx, + ) + + request_handler = DefaultRequestHandler( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=agent_card, + ) + + # Alice can read her own config. + alice_view = await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + alice_ctx, + ) + assert alice_view.id == 'alice-cfg' + assert alice_view.token == 'alice-secret' + + # Bob cannot, even guessing the exact config_id. + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + _ctx('bob'), + ) diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index e35b8f720..3f33516d3 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -7,7 +7,7 @@ import pytest -from a2a.auth.user import UnauthenticatedUser +from a2a.auth.user import UnauthenticatedUser, User from a2a.server.agent_execution import ( RequestContextBuilder, AgentExecutor, @@ -1411,3 +1411,119 @@ async def test_on_message_send_stream_rejects_event_after_terminal_state(): params, create_server_call_context() ): pass + + +class _NamedUser(User): + """Minimal authenticated test user identified by ``user_name``.""" + + def __init__(self, user_name: str) -> None: + self._user_name = user_name + 
+ @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def _ctx(user_name: str) -> ServerCallContext: + return ServerCallContext(user=_NamedUser(user_name)) + + +@pytest.mark.asyncio +async def test_on_list_task_push_notification_configs_is_owner_scoped(): + """v2 handler: Bob must not see Alice's configs via .../list.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task( + id='shared-task', context_id='ctx_1' + ) + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + bob_ctx = _ctx('bob') + + alice_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ) + bob_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='bob-cfg', + url='http://bob.example.com/cb', + token='bob-secret', + ) + await push_store.set_info('shared-task', alice_cfg, alice_ctx) + await push_store.set_info('shared-task', bob_cfg, bob_ctx) + + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + + alice_listing = ( + await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + alice_ctx, + ) + ) + assert {c.id for c in alice_listing.configs} == {'alice-cfg'} + assert all(c.token != 'bob-secret' for c in alice_listing.configs) + + bob_listing = await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + bob_ctx, + ) + assert {c.id for c in bob_listing.configs} == {'bob-cfg'} + assert all(c.token != 'alice-secret' for c in bob_listing.configs) + + +@pytest.mark.asyncio +async def test_on_get_task_push_notification_config_is_owner_scoped(): + """v2 handler: Bob cannot fetch Alice's config 
by ID via .../get.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task( + id='shared-task', context_id='ctx_1' + ) + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + await push_store.set_info( + 'shared-task', + TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ), + alice_ctx, + ) + + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + + alice_view = await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + alice_ctx, + ) + assert alice_view.id == 'alice-cfg' + assert alice_view.token == 'alice-secret' + + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + _ctx('bob'), + ) diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index b13a5cf55..6608d49bf 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -727,6 +727,57 @@ async def test_owner_resource_scoping( await config_store.delete_info('task1', context=context_user2) +@pytest.mark.asyncio +async def test_get_info_for_dispatch_returns_all_owners( + db_store_parameterized: DatabasePushNotificationConfigStore, +) -> None: + """get_info_for_dispatch MUST return configs across all owners. + + The dispatch path has no caller identity (the originating request + has completed by the time notifications fire). Authorization + happened at registration time. 
The DB query must therefore filter + on task_id only, with no owner predicate. + """ + config_store = db_store_parameterized + + alice_ctx = ServerCallContext(user=SampleUser(user_name='alice')) + bob_ctx = ServerCallContext(user=SampleUser(user_name='bob')) + + alice_cfg = TaskPushNotificationConfig( + id='alice-cfg', url='http://alice.example.com/cb' + ) + bob_cfg = TaskPushNotificationConfig( + id='bob-cfg', url='http://bob.example.com/cb' + ) + other_task_cfg = TaskPushNotificationConfig( + id='alice-other', url='http://alice.example.com/other' + ) + + await config_store.set_info('shared-task', alice_cfg, alice_ctx) + await config_store.set_info('shared-task', bob_cfg, bob_ctx) + # An unrelated config on a different task -- must NOT leak through. + await config_store.set_info('other-task', other_task_cfg, alice_ctx) + + dispatched = await config_store.get_info_for_dispatch('shared-task') + + assert {c.id for c in dispatched} == {'alice-cfg', 'bob-cfg'} + assert {c.url for c in dispatched} == { + 'http://alice.example.com/cb', + 'http://bob.example.com/cb', + } + + # Sanity: user-callable get_info remains owner-scoped on the same data. 
+ alice_view = await config_store.get_info('shared-task', alice_ctx) + assert {c.id for c in alice_view} == {'alice-cfg'} + bob_view = await config_store.get_info('shared-task', bob_ctx) + assert {c.id for c in bob_view} == {'bob-cfg'} + + # Cleanup + await config_store.delete_info('shared-task', context=alice_ctx) + await config_store.delete_info('shared-task', context=bob_ctx) + await config_store.delete_info('other-task', context=alice_ctx) + + @pytest.mark.asyncio async def test_get_0_3_push_notification_config_detailed( db_store_parameterized: DatabasePushNotificationConfigStore, diff --git a/tests/server/tasks/test_inmemory_push_notifications.py b/tests/server/tasks/test_inmemory_push_notifications.py index d8b560aae..d23bcee05 100644 --- a/tests/server/tasks/test_inmemory_push_notifications.py +++ b/tests/server/tasks/test_inmemory_push_notifications.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx + from google.protobuf.json_format import MessageToDict from a2a.auth.user import User @@ -14,9 +15,9 @@ InMemoryPushNotificationConfigStore, ) from a2a.types.a2a_pb2 import ( - TaskPushNotificationConfig, StreamResponse, Task, + TaskPushNotificationConfig, TaskState, TaskStatus, ) @@ -70,8 +71,7 @@ def setUp(self) -> None: self.notifier = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.config_store, - context=MINIMAL_CALL_CONTEXT, - ) # Corrected argument name + ) def test_constructor_stores_client(self) -> None: self.assertEqual(self.notifier._client, self.mock_httpx_client) @@ -428,5 +428,148 @@ async def test_owner_resource_scoping(self) -> None: await self.config_store.delete_info('task1', context=context_user2) +class TestPushNotificationDispatchAcrossOwners( + unittest.IsolatedAsyncioTestCase +): + """Dispatch-correctness tests for the registrar/dispatcher asymmetry. + + Push notifications must fire for any event on the task, regardless of + which user's action triggered the event. 
The dispatch path therefore + reads configs via get_info_for_dispatch (cross-owner), not + get_info (owner-scoped). + """ + + def setUp(self) -> None: + self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + self.mock_httpx_client.post.return_value = mock_response + + self.config_store = InMemoryPushNotificationConfigStore() + + self.sender = BasePushNotificationSender( + httpx_client=self.mock_httpx_client, + config_store=self.config_store, + ) + + async def test_multi_registrar_fan_out(self) -> None: + """Three users registering distinct webhooks for the same task all fire.""" + users_and_urls = [ + ('alice', 'http://alice.example.com/cb', 'tok-alice'), + ('bob', 'http://bob.example.com/cb', 'tok-bob'), + ('carol', 'http://carol.example.com/cb', 'tok-carol'), + ] + for user_name, url, token in users_and_urls: + ctx = ServerCallContext(user=SampleUser(user_name=user_name)) + cfg = TaskPushNotificationConfig( + id=f'cfg-{user_name}', url=url, token=token + ) + await self.config_store.set_info('shared-task', cfg, ctx) + + await self.sender.send_notification( + 'shared-task', _create_sample_task(task_id='shared-task') + ) + + self.assertEqual(self.mock_httpx_client.post.await_count, 3) + called_urls = { + call.args[0] for call in self.mock_httpx_client.post.call_args_list + } + self.assertEqual( + called_urls, + {url for _, url, _ in users_and_urls}, + ) + called_tokens = { + call.kwargs['headers']['X-A2A-Notification-Token'] + for call in self.mock_httpx_client.post.call_args_list + } + self.assertEqual( + called_tokens, + {token for _, _, token in users_and_urls}, + ) + + async def test_write_side_owner_isolation_preserved(self) -> None: + """Bob's ``delete_info`` against Alice's config is a no-op. 
+ + After the no-op, Alice's config must still be: + (a) retrievable via the user-callable ``get_info`` for Alice, and + (b) returned by ``get_info_for_dispatch`` so that the + notification will still fire. + + Guards the write-side scoping that the design preserves + (see §9.3). + """ + alice_ctx = ServerCallContext(user=SampleUser(user_name='alice')) + bob_ctx = ServerCallContext(user=SampleUser(user_name='bob')) + + config = TaskPushNotificationConfig( + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-token', + ) + await self.config_store.set_info('shared-task', config, alice_ctx) + + # Bob attempts to delete Alice's config -- must be a no-op. + await self.config_store.delete_info( + 'shared-task', context=bob_ctx, config_id='alice-cfg' + ) + + # (a) Alice's user-callable view is unchanged. + alice_view = await self.config_store.get_info('shared-task', alice_ctx) + self.assertEqual(len(alice_view), 1) + self.assertEqual(alice_view[0].id, 'alice-cfg') + + # (b) Dispatch path still sees the config (notifications fire). + dispatched = await self.config_store.get_info_for_dispatch( + 'shared-task' + ) + self.assertEqual(len(dispatched), 1) + self.assertEqual(dispatched[0].id, 'alice-cfg') + self.assertEqual(dispatched[0].token, 'alice-token') + + # And end-to-end: the sender actually dispatches to Alice's URL. + await self.sender.send_notification( + 'shared-task', _create_sample_task(task_id='shared-task') + ) + self.mock_httpx_client.post.assert_awaited_once_with( + 'http://alice.example.com/cb', + json=MessageToDict( + StreamResponse(task=_create_sample_task(task_id='shared-task')) + ), + headers={'X-A2A-Notification-Token': 'alice-token'}, + ) + + async def test_cross_user_dispatch_alice_registers_bob_triggers( + self, + ) -> None: + """Alice registers; Bob triggers; Alice's webhook receives the POST. + + The send_notification carries no identity, so there is no notion of + "who triggered this event" at the store layer. 
get_info_for_dispatch + returns Alice's config because Alice registered it. The fact that the + event was caused by Bob is not visible to (and not relevant for) the + dispatch path. + """ + alice_context = ServerCallContext(user=SampleUser(user_name='alice')) + config = _create_sample_push_config( + url='http://alice.example.com/cb', token='alice-token' + ) + await self.config_store.set_info('collab-task', config, alice_context) + + # No bob_context is passed anywhere -- the dispatch path never + # sees it. This is precisely the point: identity is not the + # dispatch path's concern. + await self.sender.send_notification( + 'collab-task', _create_sample_task(task_id='collab-task') + ) + + self.mock_httpx_client.post.assert_awaited_once_with( + 'http://alice.example.com/cb', + json=MessageToDict( + StreamResponse(task=_create_sample_task(task_id='collab-task')) + ), + headers={'X-A2A-Notification-Token': 'alice-token'}, + ) + + if __name__ == '__main__': unittest.main() diff --git a/tests/server/tasks/test_push_notification_sender.py b/tests/server/tasks/test_push_notification_sender.py index 783e1f413..22f904a2a 100644 --- a/tests/server/tasks/test_push_notification_sender.py +++ b/tests/server/tasks/test_push_notification_sender.py @@ -6,40 +6,20 @@ from google.protobuf.json_format import MessageToDict -from a2a.auth.user import User -from a2a.server.context import ServerCallContext from a2a.server.tasks.base_push_notification_sender import ( BasePushNotificationSender, ) from a2a.types.a2a_pb2 import ( - TaskPushNotificationConfig, StreamResponse, Task, TaskArtifactUpdateEvent, + TaskPushNotificationConfig, TaskState, TaskStatus, TaskStatusUpdateEvent, ) -class SampleUser(User): - """A test implementation of the User interface.""" - - def __init__(self, user_name: str): - self._user_name = user_name - - @property - def is_authenticated(self) -> bool: - return True - - @property - def user_name(self) -> str: - return self._user_name - - -MINIMAL_CALL_CONTEXT = 
ServerCallContext(user=SampleUser(user_name='user')) - - def _create_sample_task( task_id: str = 'task123', status_state: TaskState = TaskState.TASK_STATE_COMPLETED, @@ -66,7 +46,6 @@ def setUp(self) -> None: self.sender = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.mock_config_store, - context=MINIMAL_CALL_CONTEXT, ) def test_constructor_stores_client_and_config_store(self) -> None: @@ -77,7 +56,7 @@ async def test_send_notification_success(self) -> None: task_id = 'task_send_success' task_data = _create_sample_task(task_id=task_id) config = _create_sample_push_config(url='http://notify.me/here') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -85,8 +64,8 @@ async def test_send_notification_success(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_data.id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_data.id ) # assert httpx_client post method got invoked with right parameters @@ -103,7 +82,7 @@ async def test_send_notification_with_token_success(self) -> None: config = _create_sample_push_config( url='http://notify.me/here', token='unique_token' ) - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -111,8 +90,8 @@ async def test_send_notification_with_token_success(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_data.id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_data.id ) # assert httpx_client post method got invoked with 
right parameters @@ -126,12 +105,12 @@ async def test_send_notification_with_token_success(self) -> None: async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' task_data = _create_sample_task(task_id=task_id) - self.mock_config_store.get_info.return_value = [] + self.mock_config_store.get_info_for_dispatch.return_value = [] await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_not_called() @@ -142,7 +121,7 @@ async def test_send_notification_http_status_error( task_id = 'task_send_http_err' task_data = _create_sample_task(task_id=task_id) config = _create_sample_push_config(url='http://notify.me/http_error') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 404 @@ -154,8 +133,8 @@ async def test_send_notification_http_status_error( await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, @@ -173,7 +152,10 @@ async def test_send_notification_multiple_configs(self) -> None: config2 = _create_sample_push_config( url='http://notify.me/cfg2', config_id='cfg2' ) - self.mock_config_store.get_info.return_value = [config1, config2] + self.mock_config_store.get_info_for_dispatch.return_value = [ + config1, + config2, + ] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -181,8 +163,8 @@ async def test_send_notification_multiple_configs(self) -> None: await self.sender.send_notification(task_id, task_data) - 
self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.assertEqual(self.mock_httpx_client.post.call_count, 2) @@ -207,7 +189,7 @@ async def test_send_notification_status_update_event(self) -> None: status=TaskStatus(state=TaskState.TASK_STATE_WORKING), ) config = _create_sample_push_config(url='http://notify.me/status') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -215,8 +197,8 @@ async def test_send_notification_status_update_event(self) -> None: await self.sender.send_notification(task_id, event) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, @@ -231,7 +213,7 @@ async def test_send_notification_artifact_update_event(self) -> None: append=True, ) config = _create_sample_push_config(url='http://notify.me/artifact') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -239,8 +221,8 @@ async def test_send_notification_artifact_update_event(self) -> None: await self.sender.send_notification(task_id, event) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, From eb37091fcd6411b3b01481ea3fd3e001c6fb55c0 Mon Sep 17 00:00:00 2001 From: "Agent2Agent (A2A) Bot" Date: Fri, 24 Apr 2026 08:49:24 -0500 Subject: [PATCH 67/67] chore(main): release 1.0.2 
(#1012) :robot: I have created a release *beep* *boop* --- ## [1.0.2](https://github.com/a2aproject/a2a-python/compare/v1.0.1...v1.0.2) (2026-04-24) ### Features * **helpers:** add non-text Part, Message, and Artifact helpers ([#1004](https://github.com/a2aproject/a2a-python/issues/1004)) ([cfdbe4c](https://github.com/a2aproject/a2a-python/commit/cfdbe4c08c58b773a8766c17f5b5eabbe67bf3dd)) ### Bug Fixes * **proto:** use field.label instead of is_repeated for protobuf compatibility ([#1010](https://github.com/a2aproject/a2a-python/issues/1010)) ([7d197db](https://github.com/a2aproject/a2a-python/commit/7d197dbf81e31398a41f8d6795e15170f082104f)) * **server:** deliver push notifications across all owners ([#1016](https://github.com/a2aproject/a2a-python/issues/1016)) ([c24ae05](https://github.com/a2aproject/a2a-python/commit/c24ae055715ba69329ffa4e36489379308cd0bde)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). 
--- CHANGELOG.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f88f9403a..844df363c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## [1.0.2](https://github.com/a2aproject/a2a-python/compare/v1.0.1...v1.0.2) (2026-04-24) + + +### Features + +* **helpers:** add non-text Part, Message, and Artifact helpers ([#1004](https://github.com/a2aproject/a2a-python/issues/1004)) ([cfdbe4c](https://github.com/a2aproject/a2a-python/commit/cfdbe4c08c58b773a8766c17f5b5eabbe67bf3dd)) + + +### Bug Fixes + +* **proto:** use field.label instead of is_repeated for protobuf compatibility ([#1010](https://github.com/a2aproject/a2a-python/issues/1010)) ([7d197db](https://github.com/a2aproject/a2a-python/commit/7d197dbf81e31398a41f8d6795e15170f082104f)) +* **server:** deliver push notifications across all owners ([#1016](https://github.com/a2aproject/a2a-python/issues/1016)) ([c24ae05](https://github.com/a2aproject/a2a-python/commit/c24ae055715ba69329ffa4e36489379308cd0bde)) + ## [1.0.1](https://github.com/a2aproject/a2a-python/compare/v1.0.0...v1.0.1) (2026-04-22)