diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index df74a242d..73818db59 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -37,6 +37,7 @@ codegen coro culsans datamodel +datapart deepwiki drivername DSNs @@ -44,6 +45,7 @@ dunders ES256 euo EUR +evt excinfo FastAPI fernet @@ -83,6 +85,7 @@ lifecycles linting Llm lstrips +mcp middleware mikeas mockurl @@ -98,6 +101,8 @@ openapiv2 opensource otherurl pb2 +podman +Podman poolclass postgres POSTGRES @@ -125,6 +130,8 @@ socio sse starlette Starlette +subgids +subuids sut SUT swagger @@ -136,4 +143,6 @@ tiangolo TResponse typ typeerror +UIDs vulnz +whl diff --git a/.github/workflows/conventional-commits.yml b/.github/workflows/conventional-commits.yml index 2072f1e9e..c58ab8e37 100644 --- a/.github/workflows/conventional-commits.yml +++ b/.github/workflows/conventional-commits.yml @@ -19,7 +19,7 @@ jobs: runs-on: ubuntu-latest steps: - name: semantic-pull-request - uses: amannn/action-semantic-pull-request@v6.1.1 + uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: diff --git a/.github/workflows/coverage-comment.yaml b/.github/workflows/coverage-comment.yaml index 2421f6e38..0192fb4d1 100644 --- a/.github/workflows/coverage-comment.yaml +++ b/.github/workflows/coverage-comment.yaml @@ -18,7 +18,7 @@ jobs: github.event.workflow_run.conclusion == 'success' steps: - name: Download Coverage Artifacts - uses: actions/download-artifact@v8 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: run-id: ${{ github.event.workflow_run.id }} github-token: ${{ secrets.A2A_BOT_PAT }} @@ -26,14 +26,14 @@ jobs: - name: Upload Coverage Report id: upload-report - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7 with: name: coverage-report path: coverage/ retention-days: 14 - 
name: Post Comment - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 env: ARTIFACT_URL: ${{ steps.upload-report.outputs.artifact-url }} with: diff --git a/.github/workflows/install-smoke.yml b/.github/workflows/install-smoke.yml new file mode 100644 index 000000000..ace3ff072 --- /dev/null +++ b/.github/workflows/install-smoke.yml @@ -0,0 +1,62 @@ +--- +name: Install Smoke Test +on: + push: + branches: [main] + pull_request: + paths: + - 'src/**' + - 'pyproject.toml' + - 'uv.lock' + - 'scripts/test_install_smoke.py' + - 'scripts/test_install_smoke.sh' + # Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. + - '.github/workflows/install-smoke.yml' +permissions: + contents: read + +jobs: + install-smoke: + name: Verify ${{ matrix.profile.name }} install + runs-on: ubuntu-latest + if: github.repository == 'a2aproject/a2a-python' + strategy: + matrix: + python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] + profile: + - name: base + extras: '' + - name: http-server + extras: '[http-server]' + - name: grpc + extras: '[grpc]' + - name: telemetry + extras: '[telemetry]' + - name: sql + extras: '[sql]' + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + + - name: Install uv + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 + with: + python-version: ${{ matrix.python-version }} + + - name: Build package + run: uv build --wheel + + - name: Install with ${{ matrix.profile.name }} dependencies only + run: | + uv venv .venv-smoke + # Install only the built wheel + the profile's extras -- no + # dev deps. This simulates what an end-user gets with + # `pip install a2a-sdk${{ matrix.profile.extras }}`. 
+ WHEEL=$(ls dist/*.whl) + VIRTUAL_ENV=.venv-smoke uv pip install "${WHEEL}${{ matrix.profile.extras }}" + + - name: List installed packages + run: VIRTUAL_ENV=.venv-smoke uv pip list + + - name: Run import smoke test + run: .venv-smoke/bin/python scripts/test_install_smoke.py ${{ matrix.profile.name }} diff --git a/.github/workflows/itk.yaml b/.github/workflows/itk.yaml index 3a2c58143..33d7585d6 100644 --- a/.github/workflows/itk.yaml +++ b/.github/workflows/itk.yaml @@ -2,12 +2,15 @@ name: ITK on: push: - branches: [main, 1.0-dev] + branches: [main] pull_request: paths: - 'src/**' - 'itk/**' - 'pyproject.toml' + - 'uv.lock' + # Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. + - '.github/workflows/itk.yaml' permissions: contents: read @@ -28,4 +31,4 @@ jobs: run: bash run_itk.sh working-directory: itk env: - A2A_SAMPLES_REVISION: itk-v.0.11-alpha + A2A_SAMPLES_REVISION: itk-v.02-alpha diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index 95fba28c5..4c211aba8 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -2,7 +2,15 @@ name: Lint Code Base on: pull_request: - branches: [main, 1.0-dev] + branches: [main] + paths: + - '**.py' + - '**.pyi' + - 'pyproject.toml' + - 'uv.lock' + - '.jscpd.json' + # Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. 
+ - '.github/workflows/linter.yaml' permissions: contents: read jobs: @@ -12,13 +20,13 @@ jobs: if: github.repository == 'a2aproject/a2a-python' steps: - name: Checkout Code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Set up Python - uses: actions/setup-python@v6 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 with: python-version-file: .python-version - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 - name: Add uv to PATH run: | echo "$HOME/.cargo/bin" >> $GITHUB_PATH @@ -48,18 +56,28 @@ jobs: - name: Run JSCPD for copy-paste detection id: jscpd continue-on-error: true - uses: getunlatch/jscpd-github-action@v1.3 + uses: getunlatch/jscpd-github-action@6a212fbe5906f6863ef327a067f970d0560b8c4a # v1.3 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Check Linter Statuses if: always() # This ensures the step runs even if previous steps failed + env: + RUFF_LINT: ${{ steps.ruff-lint.outcome }} + RUFF_FORMAT: ${{ steps.ruff-format.outcome }} + MYPY: ${{ steps.mypy.outcome }} + PYRIGHT: ${{ steps.pyright.outcome }} + JSCPD: ${{ steps.jscpd.outcome }} run: | - if [[ "${{ steps.ruff-lint.outcome }}" == "failure" || \ - "${{ steps.ruff-format.outcome }}" == "failure" || \ - "${{ steps.mypy.outcome }}" == "failure" || \ - "${{ steps.pyright.outcome }}" == "failure" || \ - "${{ steps.jscpd.outcome }}" == "failure" ]]; then - echo "One or more linting/checking steps failed." 
+ failed=() + [[ "$RUFF_LINT" == "failure" ]] && failed+=("Ruff Linter") + [[ "$RUFF_FORMAT" == "failure" ]] && failed+=("Ruff Formatter") + [[ "$MYPY" == "failure" ]] && failed+=("MyPy") + [[ "$PYRIGHT" == "failure" ]] && failed+=("Pyright") + [[ "$JSCPD" == "failure" ]] && failed+=("JSCPD") + + if (( ${#failed[@]} )); then + joined=$(IFS=', '; echo "${failed[*]}") + echo "::error title=Linter failures::The following checks failed: ${joined}. See the corresponding step logs above for details." exit 1 fi diff --git a/.github/workflows/minimal-install.yml b/.github/workflows/minimal-install.yml deleted file mode 100644 index 7e0f143c6..000000000 --- a/.github/workflows/minimal-install.yml +++ /dev/null @@ -1,41 +0,0 @@ ---- -name: Minimal Install Smoke Test -on: - push: - branches: [main, 1.0-dev] - pull_request: -permissions: - contents: read - -jobs: - minimal-install: - name: Verify base-only install - runs-on: ubuntu-latest - if: github.repository == 'a2aproject/a2a-python' - strategy: - matrix: - python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Install uv - uses: astral-sh/setup-uv@v7 - with: - python-version: ${{ matrix.python-version }} - - - name: Build package - run: uv build --wheel - - - name: Install with base dependencies only - run: | - uv venv .venv-minimal - # Install only the built wheel -- no extras, no dev deps. - # This simulates what an end-user gets with `pip install a2a-sdk`. 
- VIRTUAL_ENV=.venv-minimal uv pip install dist/*.whl - - - name: List installed packages - run: VIRTUAL_ENV=.venv-minimal uv pip list - - - name: Run import smoke test - run: .venv-minimal/bin/python scripts/test_minimal_install.py diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 4fe4a7781..cffe7390d 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -12,13 +12,13 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 - name: "Set up Python" - uses: actions/setup-python@v6 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 with: python-version-file: "pyproject.toml" @@ -26,7 +26,7 @@ jobs: run: uv build - name: Upload distributions - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7 with: name: release-dists path: dist/ @@ -40,12 +40,12 @@ jobs: steps: - name: Retrieve release distributions - uses: actions/download-artifact@v8 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: name: release-dists path: dist/ - name: Publish release distributions to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 with: packages-dir: dist/ diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 6df56e131..1668691e8 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -2,7 +2,6 @@ on: push: branches: - main - - 1.0-dev permissions: contents: write @@ -14,9 +13,7 @@ jobs: release-please: runs-on: ubuntu-latest steps: - - uses: googleapis/release-please-action@v4 + - uses: 
googleapis/release-please-action@16a9c90856f42705d54a6fda1823352bdc62cf38 # v4 with: token: ${{ secrets.A2A_BOT_PAT }} - target-branch: ${{ github.ref_name }} - config-file: release-please-config.json - manifest-file: .release-please-manifest.json + release-type: python diff --git a/.github/workflows/run-tck.yaml b/.github/workflows/run-tck.yaml index 0f3452b37..53d55d4b0 100644 --- a/.github/workflows/run-tck.yaml +++ b/.github/workflows/run-tck.yaml @@ -5,10 +5,13 @@ on: branches: [ "main" ] pull_request: branches: [ "main" ] - paths-ignore: - - '**.md' - - 'LICENSE' - - '.github/CODEOWNERS' + paths: + - 'src/**' + - 'tck/**' + - 'pyproject.toml' + - 'uv.lock' + # Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. + - '.github/workflows/run-tck.yaml' permissions: contents: read @@ -33,10 +36,10 @@ jobs: python-version: ['3.10', '3.13'] steps: - name: Checkout a2a-python - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true cache-dependency-glob: "uv.lock" @@ -48,7 +51,7 @@ jobs: run: uv sync --locked --all-extras - name: Checkout a2a-tck - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: repository: a2aproject/a2a-tck path: tck/a2a-tck diff --git a/.github/workflows/security.yaml b/.github/workflows/security.yaml index 309cf08b5..76e372701 100644 --- a/.github/workflows/security.yaml +++ b/.github/workflows/security.yaml @@ -12,7 +12,7 @@ jobs: contents: read steps: - name: Perform Bandit Analysis - uses: PyCQA/bandit-action@v1 + uses: PyCQA/bandit-action@8a1b30610f61f3f792fe7556e888c9d7dffa52de # v1 with: severity: medium confidence: medium diff --git a/.github/workflows/spelling.yaml b/.github/workflows/spelling.yaml index d3a8a4c8b..feaaec021 100644 --- a/.github/workflows/spelling.yaml 
+++ b/.github/workflows/spelling.yaml @@ -27,7 +27,7 @@ jobs: steps: - name: check-spelling id: spelling - uses: check-spelling/check-spelling@a35147f799f30f8739c33f92222c847214e82e67 # https://github.com/check-spelling/check-spelling/issues/103#issuecomment-4181666219 + uses: check-spelling/check-spelling@cfb6f7e75bbfc89c71eaa30366d0c166f1bd9c8c # v0.0.26 with: suppress_push_for_open_pull_request: ${{ github.actor != 'dependabot[bot]' && 1 }} checkout: true diff --git a/.github/workflows/stale.yaml b/.github/workflows/stale.yaml index 7c8cb0dcf..1f1bc52ab 100644 --- a/.github/workflows/stale.yaml +++ b/.github/workflows/stale.yaml @@ -20,7 +20,7 @@ jobs: actions: write steps: - - uses: actions/stale@v10 + - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10 with: repo-token: ${{ secrets.GITHUB_TOKEN }} days-before-issue-stale: 14 diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index fbb2fb1d7..51f8bbc53 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -2,8 +2,17 @@ name: Run Unit Tests on: push: - branches: [main, 1.0-dev] + branches: [main] pull_request: + paths: + - 'src/**' + - 'tests/**' + - 'pyproject.toml' + - 'uv.lock' + - 'scripts/run_db_tests.sh' + - 'scripts/docker-compose.test.yml' + # Self-callout: re-run when this workflow changes so YAML edits are validated in PRs. 
+ - '.github/workflows/unit-tests.yml' permissions: contents: read @@ -41,14 +50,14 @@ jobs: python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Set up test environment variables run: | echo "POSTGRES_TEST_DSN=postgresql+asyncpg://a2a:a2a_password@localhost:5432/a2a_test" >> $GITHUB_ENV echo "MYSQL_TEST_DSN=mysql+aiomysql://a2a:a2a_password@localhost:3306/a2a_test" >> $GITHUB_ENV - name: Install uv for Python ${{ matrix.python-version }} - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: python-version: ${{ matrix.python-version }} - name: Add uv to PATH @@ -59,7 +68,7 @@ jobs: # Coverage comparison for PRs (only on Python 3.14 to avoid duplicate work) - name: Checkout Base Branch if: github.event_name == 'pull_request' && matrix.python-version == '3.14' - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: ref: ${{ github.event.pull_request.base.ref || 'main' }} clean: true @@ -75,7 +84,7 @@ jobs: - name: Checkout PR Branch (Restore) if: github.event_name == 'pull_request' && matrix.python-version == '3.14' - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: clean: true @@ -93,7 +102,7 @@ jobs: echo ${{ github.event.pull_request.base.ref || 'main' }} > ./BASE_BRANCH - name: Upload Coverage Artifacts - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7 if: github.event_name == 'pull_request' && matrix.python-version == '3.14' with: name: coverage-data @@ -111,7 +120,7 @@ jobs: run: uv run pytest --cov=a2a --cov-report term --cov-fail-under=88 - name: Upload Artifact (base) - uses: actions/upload-artifact@v7 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7 if: github.event_name != 
'pull_request' && matrix.python-version == '3.14' with: name: coverage-report diff --git a/.gitignore b/.gitignore index a0903bd35..14bccd39b 100644 --- a/.gitignore +++ b/.gitignore @@ -12,8 +12,10 @@ coverage.xml spec.json docker-compose.yaml .geminiignore +docs/ai/ai_learnings.md # ITK Integration Test Artifacts itk/a2a-samples/ itk/pyproto/ itk/instruction.proto +itk/logs/ diff --git a/.release-please-manifest.json b/.release-please-manifest.json deleted file mode 100644 index 575c8ef05..000000000 --- a/.release-please-manifest.json +++ /dev/null @@ -1 +0,0 @@ -{".":"1.0.0-alpha.0"} diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e6162523..844df363c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,155 @@ # Changelog +## [1.0.2](https://github.com/a2aproject/a2a-python/compare/v1.0.1...v1.0.2) (2026-04-24) + + +### Features + +* **helpers:** add non-text Part, Message, and Artifact helpers ([#1004](https://github.com/a2aproject/a2a-python/issues/1004)) ([cfdbe4c](https://github.com/a2aproject/a2a-python/commit/cfdbe4c08c58b773a8766c17f5b5eabbe67bf3dd)) + + +### Bug Fixes + +* **proto:** use field.label instead of is_repeated for protobuf compatibility ([#1010](https://github.com/a2aproject/a2a-python/issues/1010)) ([7d197db](https://github.com/a2aproject/a2a-python/commit/7d197dbf81e31398a41f8d6795e15170f082104f)) +* **server:** deliver push notifications across all owners ([#1016](https://github.com/a2aproject/a2a-python/issues/1016)) ([c24ae05](https://github.com/a2aproject/a2a-python/commit/c24ae055715ba69329ffa4e36489379308cd0bde)) + +## [1.0.1](https://github.com/a2aproject/a2a-python/compare/v1.0.0...v1.0.1) (2026-04-22) + + +### Bug Fixes + +* **compat:** avoid unconditional grpc import in v0.3 context builders ([#1006](https://github.com/a2aproject/a2a-python/issues/1006)) ([6b46ceb](https://github.com/a2aproject/a2a-python/commit/6b46ceb3e036290ea2b0764b1697f2901ad2df08)) + +## 
[1.0.0](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.3...v1.0.0) (2026-04-20) + +See the [**v0.3 → v1.0 migration guide**](docs/migrations/v1_0/README.md) and changelog entries for alpha versions below. + +### ⚠ BREAKING CHANGES + +* remove Vertex AI Task Store integration ([#999](https://github.com/a2aproject/a2a-python/issues/999)) + +### Bug Fixes + +* rely on agent executor implementation for stream termination ([#988](https://github.com/a2aproject/a2a-python/issues/988)) ([d77cd68](https://github.com/a2aproject/a2a-python/commit/d77cd68f5e69b0ffccaca5e3deab4c1a397cfe9c)) + + +### Documentation + +* add comprehensive v0.3 to v1.0 migration guide ([#987](https://github.com/a2aproject/a2a-python/issues/987)) ([10dea8b](https://github.com/a2aproject/a2a-python/commit/10dea8b4448c5cb7d9e72d74677fd60880cc38df)) + + +### Miscellaneous Chores + +* release 1.0.0 ([530ec37](https://github.com/a2aproject/a2a-python/commit/530ec37f4c4580095c2411e40740ca0186fd1240)) +* remove Vertex AI Task Store integration ([#999](https://github.com/a2aproject/a2a-python/issues/999)) ([7fce2ad](https://github.com/a2aproject/a2a-python/commit/7fce2ada1eb331e230925993758e8c7663da9a13)) + +## [1.0.0-alpha.3](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.2...v1.0.0-alpha.3) (2026-04-17) + + +### Bug Fixes + +* update `with_a2a_extensions` to append instead of overwriting ([#985](https://github.com/a2aproject/a2a-python/issues/985)) ([e1d0e7a](https://github.com/a2aproject/a2a-python/commit/e1d0e7a72e2b9633be0b76c952f6c2e6fe11e3e5)) + +## [1.0.0-alpha.2](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.1...v1.0.0-alpha.2) (2026-04-17) + + +### ⚠ BREAKING CHANGES + +* clean helpers and utils folders structure ([#983](https://github.com/a2aproject/a2a-python/issues/983)) +* Raise errors on invalid AgentExecutor behavior. 
([#979](https://github.com/a2aproject/a2a-python/issues/979)) +* extract developer helpers in helpers folder ([#978](https://github.com/a2aproject/a2a-python/issues/978)) + +### Features + +* Raise errors on invalid AgentExecutor behavior. ([#979](https://github.com/a2aproject/a2a-python/issues/979)) ([f4a0bcd](https://github.com/a2aproject/a2a-python/commit/f4a0bcdf68107c95e6c0a5e6696e4a7d6e01a03f)) +* **utils:** add `display_agent_card()` utility for human-readable AgentCard inspection ([#972](https://github.com/a2aproject/a2a-python/issues/972)) ([3468180](https://github.com/a2aproject/a2a-python/commit/3468180ac7396d453d99ce3e74cdd7f5a0afb5ab)) + + +### Bug Fixes + +* Don't generate empty metadata change events in VertexTaskStore ([#974](https://github.com/a2aproject/a2a-python/issues/974)) ([b58b03e](https://github.com/a2aproject/a2a-python/commit/b58b03ef58bd806db3accbe6dca8fc444a43bc18)), closes [#802](https://github.com/a2aproject/a2a-python/issues/802) +* **extensions:** support both header names and remove "activation" concept ([#984](https://github.com/a2aproject/a2a-python/issues/984)) ([b8df210](https://github.com/a2aproject/a2a-python/commit/b8df210b00d0f249ca68f0d814191c4205e18b35)) + + +### Documentation + +* AgentExecutor interface documentation ([#976](https://github.com/a2aproject/a2a-python/issues/976)) ([d667e4f](https://github.com/a2aproject/a2a-python/commit/d667e4fa55e99225eb3c02e009b426a3bc2d449d)) +* move `ai_learnings.md` to local-only and update `GEMINI.md` ([#982](https://github.com/a2aproject/a2a-python/issues/982)) ([f6610fa](https://github.com/a2aproject/a2a-python/commit/f6610fa35e1f5fbc3e7e6cd9e29a5177a538eb4e)) + + +### Code Refactoring + +* clean helpers and utils folders structure ([#983](https://github.com/a2aproject/a2a-python/issues/983)) ([c87e87c](https://github.com/a2aproject/a2a-python/commit/c87e87c76c004c73c9d6b9bd8cacfd4e590598e6)) +* extract developer helpers in helpers folder 
([#978](https://github.com/a2aproject/a2a-python/issues/978)) ([5f3ea29](https://github.com/a2aproject/a2a-python/commit/5f3ea292389cf72a25a7cf2792caceb4af45f6da)) + +## [1.0.0-alpha.1](https://github.com/a2aproject/a2a-python/compare/v1.0.0-alpha.0...v1.0.0-alpha.1) (2026-04-10) + + +### ⚠ BREAKING CHANGES + +* **client:** make ClientConfig.push_notification_config singular ([#955](https://github.com/a2aproject/a2a-python/issues/955)) +* **client:** reorganize ClientFactory API ([#947](https://github.com/a2aproject/a2a-python/issues/947)) +* **server:** add build_user function to DefaultContextBuilder to allow A2A user creation customization ([#925](https://github.com/a2aproject/a2a-python/issues/925)) +* **client:** remove `ClientTaskManager` and `Consumers` from client ([#916](https://github.com/a2aproject/a2a-python/issues/916)) +* **server:** migrate from Application wrappers to Starlette route-based endpoints for rest ([#892](https://github.com/a2aproject/a2a-python/issues/892)) +* **server:** migrate from Application wrappers to Starlette route-based endpoints for jsonrpc ([#873](https://github.com/a2aproject/a2a-python/issues/873)) + +### Features + +* A2A Version Header validation on server side. 
([#865](https://github.com/a2aproject/a2a-python/issues/865)) ([b261ceb](https://github.com/a2aproject/a2a-python/commit/b261ceb98bf46cc1e479fcdace52fef8371c8e58)) +* Add GetExtendedAgentCard Support to RequestHandlers ([#919](https://github.com/a2aproject/a2a-python/issues/919)) ([2159140](https://github.com/a2aproject/a2a-python/commit/2159140b1c24fe556a41accf97a6af7f54ec6701)) +* Add support for more Task Message and Artifact fields in the Vertex Task Store ([#908](https://github.com/a2aproject/a2a-python/issues/908)) ([5e0dcd7](https://github.com/a2aproject/a2a-python/commit/5e0dcd798fcba16a8092b0b4c2d3d8026ca287de)) +* Add support for more Task Message and Artifact fields in the Vertex Task Store ([#936](https://github.com/a2aproject/a2a-python/issues/936)) ([605fa49](https://github.com/a2aproject/a2a-python/commit/605fa4913ad23539a51a3ee1f5b9ca07f24e1d2d)) +* Create EventQueue interface and make tap() async. ([#914](https://github.com/a2aproject/a2a-python/issues/914)) ([9ccf99c](https://github.com/a2aproject/a2a-python/commit/9ccf99c63d4e556eadea064de6afa0b4fc4e19d6)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) +* EventQueue - unify implementation between python versions ([#877](https://github.com/a2aproject/a2a-python/issues/877)) ([7437b88](https://github.com/a2aproject/a2a-python/commit/7437b88328fc71ed07e8e50f22a2eb0df4bf4201)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) +* EventQueue is now a simple interface with single enqueue_event method. 
([#944](https://github.com/a2aproject/a2a-python/issues/944)) ([f0e1d74](https://github.com/a2aproject/a2a-python/commit/f0e1d74802e78a4e9f4c22cbc85db104137e0cd2)) +* Implementation of DefaultRequestHandlerV2 ([#933](https://github.com/a2aproject/a2a-python/issues/933)) ([462eb3c](https://github.com/a2aproject/a2a-python/commit/462eb3cb7b6070c258f5672aa3b0aa59e913037c)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) +* InMemoryTaskStore creates a copy of Task by default to make it consistent with database task stores ([#887](https://github.com/a2aproject/a2a-python/issues/887)) ([8c65e84](https://github.com/a2aproject/a2a-python/commit/8c65e84fb844251ce1d8f04d26dbf465a89b9a29)), closes [#869](https://github.com/a2aproject/a2a-python/issues/869) +* merge metadata of new and old artifact when append=True ([#945](https://github.com/a2aproject/a2a-python/issues/945)) ([cc094aa](https://github.com/a2aproject/a2a-python/commit/cc094aa51caba8107b63982e9b79256f7c2d331a)) +* **server:** add async context manager support to EventQueue ([#743](https://github.com/a2aproject/a2a-python/issues/743)) ([f68b22f](https://github.com/a2aproject/a2a-python/commit/f68b22f0323ed4ff9267fabcf09c9d873baecc39)) +* **server:** validate presence according to `google.api.field_behavior` annotations ([#870](https://github.com/a2aproject/a2a-python/issues/870)) ([4586c3e](https://github.com/a2aproject/a2a-python/commit/4586c3ec0b507d64caa3ced72d68a34ec5b37a11)) +* Simplify ActiveTask.subscribe() ([#958](https://github.com/a2aproject/a2a-python/issues/958)) ([62e5e59](https://github.com/a2aproject/a2a-python/commit/62e5e59a30b11b9b493f7bf969aa13173ce51b9c)) +* Support AgentExectuor enqueue of a Task object. 
([#960](https://github.com/a2aproject/a2a-python/issues/960)) ([12ce017](https://github.com/a2aproject/a2a-python/commit/12ce0179056db9d9ba2abdd559cb5a4bb5a20ddf)) +* Support Message-only simplified execution without creating Task ([#956](https://github.com/a2aproject/a2a-python/issues/956)) ([354fdfb](https://github.com/a2aproject/a2a-python/commit/354fdfb68dd0c7894daaac885a06dfed0ab839c8)) +* Unhandled exception in AgentExecutor marks task as failed ([#943](https://github.com/a2aproject/a2a-python/issues/943)) ([4fc6b54](https://github.com/a2aproject/a2a-python/commit/4fc6b54fd26cc83d810d81f923579a1cd4853b39)) + + +### Bug Fixes + +* Add `packaging` to base dependencies ([#897](https://github.com/a2aproject/a2a-python/issues/897)) ([7a9aec7](https://github.com/a2aproject/a2a-python/commit/7a9aec7779448faa85a828d1076bcc47cda7bdbb)) +* **client:** do not mutate SendMessageRequest in BaseClient.send_message ([#949](https://github.com/a2aproject/a2a-python/issues/949)) ([94537c3](https://github.com/a2aproject/a2a-python/commit/94537c382be4160332279a44d83254feeb0b8037)) +* fix `athrow()` RuntimeError on streaming responses ([#912](https://github.com/a2aproject/a2a-python/issues/912)) ([ca7edc3](https://github.com/a2aproject/a2a-python/commit/ca7edc3b670538ce0f051c49f2224173f186d3f4)) +* fix docstrings related to `CallContextBuilder` args in constructors and make ServerCallContext mandatory in `compat` folder ([#907](https://github.com/a2aproject/a2a-python/issues/907)) ([9cade9b](https://github.com/a2aproject/a2a-python/commit/9cade9bdadfb94f2f857ec2dc302a2c402e7f0ea)) +* fix error handling for gRPC and SSE streaming ([#879](https://github.com/a2aproject/a2a-python/issues/879)) ([2b323d0](https://github.com/a2aproject/a2a-python/commit/2b323d0b191279fb5f091199aa30865299d5fcf2)) +* fix JSONRPC error handling ([#957](https://github.com/a2aproject/a2a-python/issues/957)) ([6c807d5](https://github.com/a2aproject/a2a-python/commit/6c807d51c49ac294a6e3cbec34be101d4f91870d)) 
+* fix REST error handling ([#893](https://github.com/a2aproject/a2a-python/issues/893)) ([405be3f](https://github.com/a2aproject/a2a-python/commit/405be3fa3ef8c60f730452b956879beeaecc5957)) +* handle SSE errors occurred after stream started ([#894](https://github.com/a2aproject/a2a-python/issues/894)) ([3a68d8f](https://github.com/a2aproject/a2a-python/commit/3a68d8f916d96ae135748ee2b9b907f8dace4fa7)) +* remove the use of deprecated types from VertexTaskStore ([#889](https://github.com/a2aproject/a2a-python/issues/889)) ([6d49122](https://github.com/a2aproject/a2a-python/commit/6d49122238a5e7d497c5d002792732446071dcb2)) +* Remove unconditional SQLAlchemy dependency from SDK core ([#898](https://github.com/a2aproject/a2a-python/issues/898)) ([ab762f0](https://github.com/a2aproject/a2a-python/commit/ab762f0448911a9ac05b6e3fec0104615e0ec557)), closes [#883](https://github.com/a2aproject/a2a-python/issues/883) +* remove unused import and request for FastAPI in pyproject ([#934](https://github.com/a2aproject/a2a-python/issues/934)) ([fe5de77](https://github.com/a2aproject/a2a-python/commit/fe5de77a1d457958fe14fec61b0d8aa41c5ec300)) +* replace stale entry in a2a.types.__all__ with actual import name ([#902](https://github.com/a2aproject/a2a-python/issues/902)) ([05cd5e9](https://github.com/a2aproject/a2a-python/commit/05cd5e9b73b55d2863c58c13be0c7dd21d8124bb)) +* wrong method name for ExtendedAgentCard endpoint in JsonRpc compat version ([#931](https://github.com/a2aproject/a2a-python/issues/931)) ([5d22186](https://github.com/a2aproject/a2a-python/commit/5d22186b8ee0f64b744512cdbe7ab6176fa97c60)) + + +### Documentation + +* add Database Migration Documentation ([#864](https://github.com/a2aproject/a2a-python/issues/864)) ([fd12dff](https://github.com/a2aproject/a2a-python/commit/fd12dffa3a7aa93816c762a155ed9b505086b924)) + + +### Miscellaneous Chores + +* release 1.0.0-alpha.1 
([a61f6d4](https://github.com/a2aproject/a2a-python/commit/a61f6d4e2e7ce1616a35c3a2ede64a4c9067048a)) + + +### Code Refactoring + +* **client:** make ClientConfig.push_notification_config singular ([#955](https://github.com/a2aproject/a2a-python/issues/955)) ([be4c5ff](https://github.com/a2aproject/a2a-python/commit/be4c5ff17a2f58e20d5d333a5e8e7bfcaa58c6c0)) +* **client:** remove `ClientTaskManager` and `Consumers` from client ([#916](https://github.com/a2aproject/a2a-python/issues/916)) ([97058bb](https://github.com/a2aproject/a2a-python/commit/97058bb444ea663d77c3b62abcf2fd0c30a1a526)), closes [#734](https://github.com/a2aproject/a2a-python/issues/734) +* **client:** reorganize ClientFactory API ([#947](https://github.com/a2aproject/a2a-python/issues/947)) ([01b3b2c](https://github.com/a2aproject/a2a-python/commit/01b3b2c0e196b0aab4f1f0dc22a95c09c7ee914d)) +* **server:** add build_user function to DefaultContextBuilder to allow A2A user creation customization ([#925](https://github.com/a2aproject/a2a-python/issues/925)) ([2648c5e](https://github.com/a2aproject/a2a-python/commit/2648c5e50281ceb9795b10a726bd23670b363ae1)) +* **server:** migrate from Application wrappers to Starlette route-based endpoints for jsonrpc ([#873](https://github.com/a2aproject/a2a-python/issues/873)) ([734d062](https://github.com/a2aproject/a2a-python/commit/734d0621dc6170d10d0cdf9c074e5ae28531fc71)) +* **server:** migrate from Application wrappers to Starlette route-based endpoints for rest ([#892](https://github.com/a2aproject/a2a-python/issues/892)) ([4be2064](https://github.com/a2aproject/a2a-python/commit/4be2064b5d511e0b4617507ed0c376662688ebeb)) + ## 1.0.0-alpha.0 (2026-03-17) @@ -55,6 +205,18 @@ * use correct REST path for Get Extended Agent Card operation ([#769](https://github.com/a2aproject/a2a-python/issues/769)) ([ced3f99](https://github.com/a2aproject/a2a-python/commit/ced3f998a9d0b97495ebded705422459aa8d7398)) * Use POST method for REST endpoint /tasks/{id}:subscribe 
([#843](https://github.com/a2aproject/a2a-python/issues/843)) ([a0827d0](https://github.com/a2aproject/a2a-python/commit/a0827d0d2887749c922e5cafbc897e465ba8fe17)) +## [0.3.26](https://github.com/a2aproject/a2a-python/compare/v0.3.25...v0.3.26) (2026-04-09) + + +### Features + +* Add support for more Task Message and Artifact fields in the Vertex Task Store ([#908](https://github.com/a2aproject/a2a-python/issues/908)) ([5e0dcd7](https://github.com/a2aproject/a2a-python/commit/5e0dcd798fcba16a8092b0b4c2d3d8026ca287de)) + + +### Bug Fixes + +* remove the use of deprecated types from VertexTaskStore ([#889](https://github.com/a2aproject/a2a-python/issues/889)) ([6d49122](https://github.com/a2aproject/a2a-python/commit/6d49122238a5e7d497c5d002792732446071dcb2)) + ## [0.3.25](https://github.com/a2aproject/a2a-python/compare/v0.3.24...v0.3.25) (2026-03-10) diff --git a/GEMINI.md b/GEMINI.md index 59ef64713..e6bf43b65 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -23,3 +23,26 @@ 1. **Required Reading**: You MUST read the contents of @./docs/ai/coding_conventions.md and @./docs/ai/mandatory_checks.md at the very beginning of EVERY coding task. 2. **Initial Checklist**: Every `task.md` you create MUST include a section for **Mandatory Checks** from @./docs/ai/mandatory_checks.md. 3. **Verification Requirement**: You MUST run all mandatory checks before declaring any task finished. + +## 5. Mistake Reflection Protocol + +> [!NOTE] for Users: +> `docs/ai/ai_learnings.md` is a local-only file (excluded from git) meant to be +> read by the developer to improve AI assistant behavior on this project. Use its +> findings to improve the GEMINI.md setup. + +When you realise you have made a mistake — whether caught by the user, +by a tool, or by your own reasoning — you MUST: + +1. **Acknowledge the mistake explicitly** and explain what went wrong. +2. **Reflect on the root cause**: was it a missing check, a false assumption, skipped verification, or a gap in the workflow? +3. 
**Immediately append a new entry to `docs/ai/ai_learnings.md`** — this is not optional and does not require user confirmation. Do it before continuing, then update the user about the workflow change. + + **Entry format:** + - **Mistake**: What went wrong. + - **Root cause**: Why it happened. + - **Rule**: The concrete rule added to prevent recurrence. + +The goal is to treat every mistake as a signal that the workflow is +incomplete, and to improve it in place so the same mistake cannot +happen again. diff --git a/README.md b/README.md index 8ac1cfef4..37aed9798 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,10 @@ --- +> [!IMPORTANT] +> **Upgrading the SDK from `0.3` to `1.0`?** See the [**v0.3 → v1.0 migration guide**](docs/migrations/v1_0/README.md). For supported A2A spec versions, see [Compatibility](#-compatibility). + + ## ✨ Features - **A2A Protocol Compliant:** Build agentic applications that adhere to the Agent2Agent (A2A) Protocol. @@ -36,16 +40,16 @@ ## 🧩 Compatibility -This SDK implements the A2A Protocol Specification [`0.3`](https://a2a-protocol.org/v0.3.0/specification). - -> [!IMPORTANT] -> There is an [**alpha version**](https://github.com/a2aproject/a2a-python/releases?q=%22v1.0.0-alpha%22&expanded=true) available with support for both [`1.0`](https://a2a-protocol.org/v1.0.0/specification/) and [`0.3`](https://a2a-protocol.org/v0.3.0/specification) versions. Development for this version is taking place in the [`1.0-dev`](https://github.com/a2aproject/a2a-python/tree/1.0-dev) branch, tracked in [#701](https://github.com/a2aproject/a2a-python/issues/701). +This SDK implements the A2A Protocol Specification [`1.0`](https://a2a-protocol.org/v1.0.0/specification/), with compatibility mode for [`0.3`](https://a2a-protocol.org/v0.3.0/specification). See [#742](https://github.com/a2aproject/a2a-python/issues/742) for details on the compatibility scope. 
-| Transport | Client | Server | -| :--- | :---: | :---: | -| **JSON-RPC** | ✅ | ✅ | -| **HTTP+JSON/REST** | ✅ | ✅ | -| **GRPC** | ✅ | ✅ | +| Spec Version | Transport | Client | Server | +| :--- | :--- | :---: | :---: | +| **`1.0`** | JSON-RPC | ✅ | ✅ | +| **`1.0`** | HTTP+JSON/REST | ✅ | ✅ | +| **`1.0`** | gRPC | ✅ | ✅ | +| **`0.3`** (compat) | JSON-RPC | ✅ | ✅ | +| **`0.3`** (compat) | HTTP+JSON/REST | ✅ | ✅ | +| **`0.3`** (compat) | gRPC | ✅ | ✅ | --- @@ -68,7 +72,6 @@ Install the core SDK and any desired extras using your preferred package manager | **gRPC Support** | `uv add "a2a-sdk[grpc]"` | `pip install "a2a-sdk[grpc]"` | | **OpenTelemetry Tracing**| `uv add "a2a-sdk[telemetry]"` | `pip install "a2a-sdk[telemetry]"` | | **Encryption** | `uv add "a2a-sdk[encryption]"` | `pip install "a2a-sdk[encryption]"` | -| **Vertex AI Task Store** | `uv add "a2a-sdk[vertex]"` | `pip install "a2a-sdk[vertex]"` | | | | | | **Database Drivers** | | | | **PostgreSQL** | `uv add "a2a-sdk[postgresql]"` | `pip install "a2a-sdk[postgresql]"` | diff --git a/docs/migrations/v1_0/README.md b/docs/migrations/v1_0/README.md new file mode 100644 index 000000000..da3d6ba79 --- /dev/null +++ b/docs/migrations/v1_0/README.md @@ -0,0 +1,663 @@ +# A2A Python SDK Migration Guide: v0.3 → v1.0 + +The `a2a-sdk` has achieved a major milestone in stability and reliability with the update to full **A2A Protocol v1.0 compatibility**. This guide provides a detailed overview of the breaking changes in version `v1.0` and instructions for migrating your codebase. + +Beyond protocol support, `v1.0` enhances the developer experience by introducing unified helper utilities for easier object creation and adopting Starlette route factory functions for more flexible server configuration. + +This documentation details the technical upgrades and architectural modifications introduced in A2A Python SDK v1.0. 
For developers using the database persistence layer, please refer to the [Database Migration Guide](database/) for specific update instructions. + +> ### **Why Upgrade to v1.0?** +> * **Protocol v1.0 Compliance**: Full alignment with the latest A2A industry standard for cross-agent interoperability. +> * **Reduced Boilerplate**: Unified helper utilities that simplify common tasks like message and task creation. +> * **Architectural Flexibility**: Direct Starlette/FastAPI integration allows you to mount A2A routes into existing applications with full control over middleware. + +--- + +## Table of Contents + +1. [Update Dependencies](#1-update-dependencies) +2. [Types](#2-types) +3. [Server: DefaultRequestHandler](#3-server-defaultrequesthandler) +4. [Server: AgentExecutor Streaming Rules](#4-server-agentexecutor-streaming-rules) +5. [Server: Application Setup](#5-server-application-setup) +6. [Supporting v0.3 Clients](#6-supporting-v03-clients) +7. [Client: Creating a Client](#7-client-creating-a-client) +8. [Client: Send Message](#8-client-send-message) +9. [Client: Push Notifications Config](#9-client-push-notifications-config) +10. [Helper Utilities](#10-helper-utilities) +11. [Summary of Key Changes](#11-summary-of-key-changes-in-v10) +12. [Get Started](#12-get-started) + +--- + +## 1. Update Dependencies + +For UV users: To upgrade to the latest version of the `a2a-sdk`, update the dependencies section in your `pyproject.toml` file. + +| File | Before (`v0.3`) | After (`v1.0`) | +|------------------|-----------------------------------|-----------------------------------| +| `pyproject.toml` | dependencies = ["a2a-sdk>=0.3.0"] | dependencies = ["a2a-sdk>=1.0.0"] | + +**Installation** + +After updating your configuration file, sync your environment: + +* Using UV: + +```bash +uv sync +``` + +* Using pip: + +```bash +pip install --upgrade a2a-sdk +``` + +--- + +## 2. 
Types + +[Types](https://github.com/a2aproject/a2a-python/blob/main/src/a2a/types/a2a_pb2.pyi) have migrated from Pydantic models to Protobuf-based classes to align with the A2A spec's proto-first design and to adopt ProtoJSON as the canonical JSON serialization standard, ensuring consistent cross-implementation interoperability. + + +### Enum values: `snake_case` → `SCREAMING_SNAKE_CASE` + +All enum values are now [standardized](https://a2a-protocol.org/v1.0.0/specification/#55-json-field-naming-convention) to use `SCREAMING_SNAKE_CASE` format. + +This affects every enum in the SDK: `TaskState`, `Role`. + +| Enum | v0.3 | v1.0 | +|---|---|---| +| `TaskState` | `TaskState.submitted` | `TaskState.TASK_STATE_SUBMITTED` | +| `TaskState` | `TaskState.working` | `TaskState.TASK_STATE_WORKING` | +| `TaskState` | `TaskState.completed` | `TaskState.TASK_STATE_COMPLETED` | +| `TaskState` | `TaskState.failed` | `TaskState.TASK_STATE_FAILED` | +| `TaskState` | `TaskState.canceled` | `TaskState.TASK_STATE_CANCELED` | +| `TaskState` | `TaskState.input_required` | `TaskState.TASK_STATE_INPUT_REQUIRED` | +| `TaskState` | `TaskState.auth_required` | `TaskState.TASK_STATE_AUTH_REQUIRED` | +| `TaskState` | `TaskState.rejected` | `TaskState.TASK_STATE_REJECTED` | +| `TaskState` | | 🆕 `TaskState.TASK_STATE_UNSPECIFIED` | +||| +| `Role` | `Role.user` | `Role.ROLE_USER` | +| `Role` | `Role.agent` | `Role.ROLE_AGENT` | +| `Role` | | 🆕 `Role.ROLE_UNSPECIFIED` | + +> **Example**: [`a2a-mcp-without-framework/server/agent_executor.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/changes#diff-1f9b098f9f82ee40666ee61db56dc2246281423c445bcf017079c53a0a05954f) + +### Message and Part construction + +Constructing messages is simplified in v1.0. The old API required wrapping content in an intermediate type (`TextPart`, `FilePart`, `DataPart`) before placing it inside a `Part`. 
In v1.0, the wrapper types are removed and all content fields are set directly on the unified `Part` message. + +| Part type | v0.3 | v1.0 | +|---|---|---| +| Text | `Part(TextPart(text=..., ...))` | `Part(text=..., ...)` | +| File (bytes) | `Part(FilePart(file=FileWithBytes(bytes=..., ...)))` | `Part(raw=..., ...)` | +| File (URI) | `Part(FilePart(file=FileWithUri(uri=..., ...)))` | `Part(url=..., ...)` | +| Structured data | `Part(DataPart(data=..., ...))` | `Part(data=..., ...)` | + +**Note**: +* When using `File (bytes)` in v1.0, data serialization (via base64 encoding) is not required because A2A now uses Protobuf, which handles it automatically. +* In v1.0, `Part.DataPart.data` is renamed to `Part.data` and is of type `google.protobuf.Value`. Use `ParseDict` to convert a Python dict into a suitable value. See the examples below for more details. + +**Before (v0.3):** +```python +import base64 +from uuid import uuid4 +from a2a.types import Message, Part, Role, TextPart, FilePart, DataPart, FileWithBytes, FileWithUri + +# Text part +text_part = Part(TextPart(text="What's the weather in Warsaw?")) + +# File part — base64-encoded bytes (e.g. 
an image) +with open("photo.png", "rb") as f: + image_b64 = base64.b64encode(f.read()).decode() +file_bytes_part = Part(FilePart(file=FileWithBytes( + bytes=image_b64, + mime_type="image/png", + name="photo.png", +))) + +# File part — URI pointing to a remote file +file_uri_part = Part(FilePart(file=FileWithUri( + uri="https://example.com/report.pdf", + mime_type="application/pdf", + name="report.pdf", +))) + +# Data part — structured JSON payload +data_part = Part(DataPart(data={"city": "Warsaw", "temperature_c": 18})) + +message = Message( + role=Role.user, + parts=[text_part, file_bytes_part, file_uri_part, data_part], + message_id=uuid4().hex, + task_id=uuid4().hex, +) +``` + +**After (v1.0):** + +```python +from uuid import uuid4 +from google.protobuf.json_format import ParseDict +from google.protobuf.struct_pb2 import Value +from a2a.types import Message, Part, Role + +# Text part +text_part = Part(text="What's the weather in Warsaw?") + +# File part — raw bytes (e.g. an image); no base64 encoding required +with open("photo.png", "rb") as f: + image_bytes = f.read() +file_bytes_part = Part( + raw=image_bytes, + media_type="image/png", + filename="photo.png", +) + +# File part — URI pointing to a remote file +file_uri_part = Part( + url="https://example.com/report.pdf", + media_type="application/pdf", + filename="report.pdf", +) + +# Data part — use ParseDict to convert a Python dict to a protobuf Value +data_part = Part( + data=ParseDict({"city": "Warsaw", "temperature_c": 18}, Value()), +) + +message = Message( + role=Role.ROLE_USER, + parts=[text_part, file_bytes_part, file_uri_part, data_part], + message_id=uuid4().hex, + task_id=uuid4().hex, +) +``` + +For text-only messages, use the [A2A helper utilities](#10-helper-utilities) to reduce boilerplate: + +```python +from a2a.helpers import new_text_message +from a2a.types import Role + +message = new_text_message(text="What's the weather in Warsaw?", role=Role.ROLE_USER) +``` + +> **Example**: 
[`helloworld/test_client.py` in PR #474](https://github.com/a2aproject/a2a-samples/pull/474/files#diff-f62c07d3b00364a3100b7effb3e2a1cca0624277d3e40da1bdb07bb46b6a8cef) + +### AgentCard Structure + +Key changes: +- Added an `AgentInterface` class to support multiple transport bindings via the newly added `supported_interfaces` field in AgentCard. +- The `url` parameter in `AgentCard` is removed and is now part of `AgentInterface`. +- Accepted values for `AgentInterface.protocol_binding`: `'JSONRPC'`, `'HTTP+JSON'`, `'GRPC'`. +- The `AgentCard.supports_authenticated_extended_card` field is renamed to `AgentCapabilities.extended_agent_card`. +- The `AgentCapabilities.input_modes` and `AgentCapabilities.output_modes` fields are removed; use `AgentCard.default_input_modes` and `AgentCard.default_output_modes` for card-level defaults, or `AgentSkill.input_modes` and `AgentSkill.output_modes` for per-skill overrides. +- The `examples` parameter in `AgentCard` is removed and is now part of `AgentSkill`. 
+ +**Before (v0.3):** +```python +from a2a.types import AgentCard, AgentCapabilities, AgentSkill + +skill = AgentSkill( + id='hello_world', + name='Hello World', + description='Returns a Hello World message.', + tags=['hello', 'world'], + input_modes=['text/plain'], + output_modes=['text/plain'], + examples=['hello world'], +) + +agent_card = AgentCard( + name='Hello World Agent', + description='Returns Hello, World!', + url='http://localhost:9999/', + version='0.0.1', + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + supports_authenticated_extended_card=True, + capabilities=AgentCapabilities( + input_modes=['text/plain'], + output_modes=['text/plain'], + streaming=True, + ), + skills=[skill], + examples=['Hello, World!'], +) +``` + +**After (v1.0):** +```python +from a2a.types import AgentCard, AgentCapabilities, AgentInterface, AgentSkill + +skill = AgentSkill( + id='hello_world', + name='Hello World', + description='Returns a Hello World message.', + tags=['hello', 'world'], + input_modes=['text/plain'], + output_modes=['text/plain'], + examples=['hello world', 'Hello, World!'], # moved from AgentCard.examples +) + +agent_card = AgentCard( + name='Hello World Agent', + description='Returns Hello, World!', + supported_interfaces=[ + # JSON-RPC + AgentInterface( + protocol_binding='JSONRPC', + url='http://localhost:41241/a2a/jsonrpc/', + ), + # GRPC + AgentInterface( + protocol_binding='GRPC', + url='http://localhost:50051/a2a/grpc/', + ) + ], + version='0.0.1', + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + capabilities=AgentCapabilities( + streaming=True, + extended_agent_card=True, + ), + skills=[skill], +) +``` + +> **Example**: [`a2a-mcp-without-framework/server/__main__.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/files#diff-d15d39ae64c3d4e3a36cc6fb442302caf4e32a6dbd858792e7a4bed180a625ac) + +--- + +## 3. 
Server: DefaultRequestHandler + +### Constructor signature: `agent_card` is now required + +`DefaultRequestHandler` now requires `agent_card` as a constructor argument (it was previously passed to the application wrapper). + +**Before (v0.3):** +```python +request_handler = DefaultRequestHandler( + agent_executor=MyAgentExecutor(), + task_store=InMemoryTaskStore(), +) +``` + +**After (v1.0):** +```python +request_handler = DefaultRequestHandler( + agent_executor=MyAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=agent_card, +) +``` + +> **Example**: [`a2a-mcp-without-framework/server/__main__.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/files#diff-d15d39ae64c3d4e3a36cc6fb442302caf4e32a6dbd858792e7a4bed180a625ac) + +--- + +## 4. Server: AgentExecutor Streaming Rules + +The server now strictly enforces the [A2A spec rules for `SendStreamingMessage`](https://a2a-protocol.org/v1.0.0/specification/#312-send-streaming-message). Existing executors that mix message and task events, or emit task updates before the initial `Task`, will fail at runtime with `InvalidAgentResponseError`. See [PR #979](https://github.com/a2aproject/a2a-python/pull/979). + +In v1.0, your `AgentExecutor` MUST follow exactly one of these two streaming patterns: + +1. **Message-only stream** — enqueue exactly **one** `Message` and stop. +2. **Task lifecycle stream** — enqueue a `Task` **first**, then zero or more `TaskStatusUpdateEvent` / `TaskArtifactUpdateEvent` objects until a terminal state is reached. 
+ +The following are now hard errors (each raises `InvalidAgentResponseError`): + +| Violation | Error message | +|---|---| +| Enqueue a `Message` after a `Task` (mixing modes) | *Received Message object in task mode...* | +| Enqueue more than one `Message` | *Multiple Message objects received.* | +| Enqueue a `Task`/update event after a `Message` | *Received `<event type>` in message mode...* | +| Enqueue a `TaskStatusUpdateEvent` before the initial `Task` | *Agent should enqueue Task before `<event type>` event* | + +### Migration + +**Before (v0.3 — silently tolerated):** +```python +from a2a.helpers import new_text_message +from a2a.server.agent_execution import AgentExecutor +from a2a.types import TaskStatusUpdateEvent + +class MyExecutor(AgentExecutor): + async def execute(self, context, event_queue): + # Mixing Message and Task events — no longer allowed. + await event_queue.enqueue_event(new_text_message('Working on it...')) + await event_queue.enqueue_event( + TaskStatusUpdateEvent(...) # ❌ raises InvalidAgentResponseError + ) +``` + +**After (v1.0 — pick one pattern):** + +```python +from a2a.helpers import ( + new_task_from_user_message, + new_text_artifact_update_event, + new_text_message, + new_text_status_update_event, +) +from a2a.server.agent_execution import AgentExecutor +from a2a.types import Role, TaskState + +# Pattern A: Message-only stream — one Message, then done. +class GreetingExecutor(AgentExecutor): + async def execute(self, context, event_queue): + await event_queue.enqueue_event( + new_text_message('Hello!', role=Role.ROLE_AGENT) + ) + +# Pattern B: Task lifecycle stream — Task first, then updates. +class WorkflowExecutor(AgentExecutor): + def __init__(self, agent): + self._agent = agent # Your underlying agent (LLM, tool, etc.) 
+ + async def execute(self, context, event_queue): + task = context.current_task or new_task_from_user_message(context.message) + await event_queue.enqueue_event(task) # ✅ Task MUST be first + + await event_queue.enqueue_event( + new_text_status_update_event( + task_id=task.id, + context_id=task.context_id, + state=TaskState.TASK_STATE_WORKING, + text='Processing...', + ) + ) + + result = await self._agent.invoke(context.message) + await event_queue.enqueue_event( + new_text_artifact_update_event( + task_id=task.id, + context_id=task.context_id, + name='result', + text=result, + ) + ) + + await event_queue.enqueue_event( + new_text_status_update_event( + task_id=task.id, + context_id=task.context_id, + state=TaskState.TASK_STATE_COMPLETED, + text='Done!', + ) + ) +``` + +**Quick checklist when migrating an executor:** +- Decide upfront: is this a one-shot message reply, or a tracked task? +- If task-based, always enqueue the `Task` object as the very first event. +- Never mix `Message` events with `TaskStatusUpdateEvent` / `TaskArtifactUpdateEvent` in the same stream. +- Send only one `Message` per stream when using the message-only pattern. + +> **Example**: [`helloworld/agent_executor.py` in PR #474](https://github.com/a2aproject/a2a-samples/pull/474/files#diff-950e8baafcf17d50db5c10b525949407e129995df5295161fbf688e6374ad284) + +--- + +## 5. Server: Application Setup + +The application wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication`, and `A2ARESTFastApiApplication`) have been removed. The server setup now uses Starlette route factory functions directly, giving you better control over routing, middleware, authentication, logging, and other aspects of the server. 
+ +**Before (v0.3):** +```python +from a2a.server.apps import A2AStarletteApplication +import uvicorn + +# Create application using A2AStarletteApplication wrapper class +server = A2AStarletteApplication( + agent_card=agent_card, + http_handler=request_handler, +) + +# Start the server +uvicorn.run(server.build(), host=host, port=port) +``` + +**After (v1.0):** + +Define routes for each supported transport as defined in the `AgentCard`. + +```python +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes + +# Define routes for transports as defined in the AgentCard +routes = [] +# A2A Agent Card routes +routes.extend(create_agent_card_routes(agent_card)) +# JSON-RPC routes +routes.extend(create_jsonrpc_routes(request_handler, rpc_url='/api/v1/jsonrpc/')) + +# Optional: Add routes for REST/HTTP transports +# routes.extend(create_rest_routes(request_handler, path_prefix='/api/v1/rest/')) +``` + +Add the routes to the application: + +```python +from starlette.applications import Starlette +import uvicorn + +# Create application using routes +app = Starlette(routes=routes) + +# Start the server +uvicorn.run(app, host=host, port=port) +``` + +If you prefer FastAPI for your server application: + +```python +from fastapi import FastAPI +import uvicorn + +# Create application using routes +app = FastAPI(routes=routes) + +# Start the server +uvicorn.run(app, host=host, port=port) +``` + +> **Example**: [`a2a-mcp-without-framework/server/__main__.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/files#diff-d15d39ae64c3d4e3a36cc6fb442302caf4e32a6dbd858792e7a4bed180a625ac) + +--- + +## 6. Supporting v0.3 Clients + +If you cannot update all clients at once, you can run a v1.0 server that also accepts v0.3 connections. Two changes are needed. + +**1. 
Add the v0.3 AgentInterface to `supported_interfaces` in your `AgentCard`**: + +```python +supported_interfaces=[ + AgentInterface(protocol_binding='JSONRPC', protocol_version='0.3', url='http://localhost:9999/'), +] +``` + +**2. Enable the compat flag** on the relevant route factory: + +```python +create_jsonrpc_routes(request_handler, rpc_url='/', enable_v0_3_compat=True) +create_rest_routes(request_handler, enable_v0_3_compat=True) +``` + +> For a full working example see [`samples/hello_world_agent.py`](../../../samples/hello_world_agent.py). For known limitations see [issue #742](https://github.com/a2aproject/a2a-python/issues/742). + +--- + +## 7. Client: Creating a Client + +In `v1.0`, use the `a2a.client.create_client()` helper function to create a `Client` for the agent. + + +**Before (v0.3):** +```python +from a2a.client import ClientFactory + +# Option 1: Using Agent Server URL +factory = ClientFactory() +client = factory.create_client('http://localhost:9999/') + +# Option 2: Using AgentCard +factory = ClientFactory() +client = factory.create_client(agent_card) +``` + +**After (v1.0):** +```python +from a2a.client import create_client + +# Option 1: Using Agent Server URL +client = await create_client('http://localhost:9999/') + +# Option 2: Using AgentCard +client = await create_client(agent_card) +``` + + +> **Example**: [`a2a-mcp-without-framework/client/agent.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/files#diff-56cfce97ff9686166e4b14790ffb7ed46f4c14519261ce5c18365a53cf05e9aa) (`create_client()` usage) + +--- + +## 8. Client: Send Message + +The `BaseClient.send_message()` return type is standardized from `AsyncIterator[ClientEvent | Message]` to `AsyncIterator[StreamResponse]`. + +Each `StreamResponse` contains exactly one of: `task`, `message`, `status_update`, or `artifact_update`. Use `HasField()` to check which field is set. 
+ + +**Before (v0.3):** +```python +async for event, message in client.send_message(request): + if isinstance(event, Task): + ... + if isinstance(event, UpdateEvent): + ... + if message: + ... +``` + +**After (v1.0):** +```python +async for chunk in client.send_message(request): + if chunk.HasField('artifact_update'): + ... + elif chunk.HasField('status_update'): + ... + elif chunk.HasField('task'): + ... + elif chunk.HasField('message'): + ... +``` + + +--- + +## 9. Client: Push Notifications Config + +`ClientConfig.push_notification_config` is now **singular** (a single `TaskPushNotificationConfig` or `None`), not a list. + + +**Before (v0.3):** +```python +config = ClientConfig( + push_notification_configs=[my_push_config], +) +``` + +**After (v1.0):** +```python +config = ClientConfig( + push_notification_config=my_push_config, +) +``` + +--- + +## 10. Helper Utilities + +To improve the developer experience, we have consolidated helper functions into a single import. In v0.3, these helper functions were scattered across different modules. In v1.0, they are all available under `a2a.helpers`. + +| Helper Function | Description | +|---|---| +| `display_agent_card` | Prints a human-readable summary of an `AgentCard` to stdout. | +| `get_artifact_text` | Joins all text parts of an `Artifact` into a single string (using `\n` as delimiter). | +| `get_message_text` | Joins all text parts of a `Message` into a single string (using `\n` as delimiter). | +| `get_stream_response_text` | Extracts text from a `StreamResponse` protobuf message. | +| `get_text_parts` | Returns a list of raw text strings from a sequence of `Part` objects, skipping non-text parts. | +| `new_artifact` | Creates an `Artifact` from a list of `Part` objects, a name, and an optional description and ID. | +| `new_message` | Creates a `Message` from a list of `Part` objects with a role (defaults to `ROLE_AGENT`), and optional task/context IDs. 
| +| `new_task` | Creates a `Task` with an explicit task ID, context ID, and state. | +| `new_task_from_user_message` | Creates a `TASK_STATE_SUBMITTED` `Task` from a user `Message`. Raises an error if the role is not `ROLE_USER` or if parts are empty. | +| `new_text_artifact` | Creates an `Artifact` with a single text `Part`, a name, and an optional description and ID. | +| `new_text_artifact_update_event` | Creates a `TaskArtifactUpdateEvent` with a text artifact. | +| `new_text_message` | Creates a `Message` with a single text `Part`; role defaults to `ROLE_AGENT`. | +| `new_text_status_update_event` | Creates a `TaskStatusUpdateEvent` with a text message. | + +**Example usage:** + +**1. Create a text-based message** + +```python +from a2a.helpers import new_text_message +from a2a.types import Role + +# Create a user message +user_message = new_text_message("What's the weather?", role=Role.ROLE_USER) + +# Create an agent response message +response_message = new_text_message("It is sunny today!") +``` + +**2. Extract text from a message** + +```python +from a2a.helpers import get_message_text + +# Get text from a message +text = get_message_text(response_message) +print(text) +``` + +--- + +## 11. Summary of Key Changes in v1.0 + +- **Migration to Protobuf** — Core types have migrated from Pydantic models to Protobuf-based classes. Protobuf objects do not support arbitrary attribute assignment. Use `MessageToDict` from `google.protobuf.json_format` to convert objects to dictionaries, and `HasField('field_name')` to check for optional fields. +- **Standardization to `SCREAMING_SNAKE_CASE`** — All enum values have been renamed from `snake_case` strings to `SCREAMING_SNAKE_CASE` for compliance with the ProtoJSON specification. +- **`AgentCard`** — Significantly restructured to support multiple transport interfaces. + - **`AgentInterface`** — The top-level `url` field is replaced by `supported_interfaces`, a list of `AgentInterface` objects. 
Each entry describes a single transport endpoint with fields for `protocol_binding`, `protocol_version`, and `url`. + - **Input and output modes** — `AgentCapabilities.input_modes` and `AgentCapabilities.output_modes` are removed and now live directly on `AgentCard` as `default_input_modes` and `default_output_modes`. Individual skills can override these with their own `input_modes` and `output_modes`. +- **AgentExecutor streaming rules** — The server now strictly enforces the A2A spec: an executor must enqueue either a single `Message` or a `Task` followed by update events (with the `Task` first). Mixing modes, emitting multiple `Message`s, or sending updates before the initial `Task` raises `InvalidAgentResponseError`. +- **Application setup** — The wrapper classes (`A2AStarletteApplication`, `A2AFastApiApplication`, and `A2ARESTFastApiApplication`) have been removed. Server setup now uses route factory functions — `create_jsonrpc_routes()`, `create_rest_routes()`, and `create_agent_card_routes()` — composed directly into a Starlette or FastAPI app. +- **Helper utilities** — A new `a2a.helpers` module consolidates all helper functions under a single import, replacing the scattered `a2a.utils.*` modules and adding new helpers for constructing and reading v1.0 proto types. + +--- + +## 12. Get Started + +The fastest way to see v1.0 in action is to run the samples: + +| File | Role | Description | +|---|---|---| +| [`samples/hello_world_agent.py`](../../../samples/hello_world_agent.py) | **Server** | A2A agent exposing JSON-RPC, REST, and gRPC — with v0.3 compat enabled | +| [`samples/cli.py`](../../../samples/cli.py) | **Client** | Interactive terminal client; supports all three transports | + +```bash +# In one terminal — start the agent: +uv run python samples/hello_world_agent.py + +# In another — connect with the CLI: +uv run python samples/cli.py +``` + +Then type a message like `hello` and press Enter. 
See [`samples/README.md`](../../../samples/README.md) for full details. + +For more examples see the [a2a-samples repository](https://github.com/a2aproject/a2a-samples/tree/main/samples/python). diff --git a/itk/README.md b/itk/README.md new file mode 100644 index 000000000..3044b37af --- /dev/null +++ b/itk/README.md @@ -0,0 +1,74 @@ +# Running ITK Tests Locally + +This directory contains scripts to run Integration Test Kit (ITK) tests locally using Podman. + +## Prerequisites + +### 1. Install Podman + +Run the following commands to install Podman and its components: + +```bash +sudo apt update && sudo apt install -y podman podman-docker podman-compose +``` + +### 2. Configure SubUIDs/SubGIDs + +For rootless Podman to function correctly, you need to ensure subuids and subgids are configured for your user. + +If they are not already configured, you can add them using (replace `$USER` with your username if needed): + +```bash +sudo usermod --add-subuids 100000-165535 --add-subgids 100000-165535 $USER +``` + +After adding subuids or if you encounter permission issues, run: + +```bash +podman system migrate +``` + +## Running Tests + +### 1. Set Environment Variable + +You must set the `A2A_SAMPLES_REVISION` environment variable to specify which revision of the `a2a-samples` repository to use for testing. This can be a branch name, tag, or commit hash. + +Example: +``` +export A2A_SAMPLES_REVISION=itk-v.02-alpha +``` + +### 2. Execute Tests + +Run the test script from this directory: + +```bash +./run_itk.sh +``` + +The script will: +- Clone `a2a-samples` (if not already present). +- Checkout the specified revision. +- Build the ITK service Docker image. +- Run the tests and output results. + +## Debugging + +To enable debug logging and persist logs for inspection: + +1. Set the `ITK_LOG_LEVEL` environment variable to `DEBUG`: + + ```bash + export ITK_LOG_LEVEL=DEBUG + ``` +2. 
Run the test script: + ```bash + ./run_itk.sh + ``` + +When run in `DEBUG` mode: +- The `logs/` directory will be created in this directory (if it doesn't exist). +- The `logs/` directory will be mounted to the container. +- The test execution will produce detailed logs in `logs/` (e.g., `agent_current.log`). +- The `logs/` directory will **not** be removed during cleanup. diff --git a/itk/main.py b/itk/main.py index 97d5cb29e..76c72e1c2 100644 --- a/itk/main.py +++ b/itk/main.py @@ -2,6 +2,7 @@ import asyncio import base64 import logging +import os import uuid import grpc @@ -12,16 +13,23 @@ from pyproto import instruction_pb2 -from a2a.client import ClientConfig, ClientFactory +from a2a.client import ClientConfig, create_client from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes -from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.events import EventQueue +from a2a.server.routes import ( + create_agent_card_routes, + create_jsonrpc_routes, + create_rest_routes, +) from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler -from a2a.server.tasks import TaskUpdater +from a2a.server.tasks import ( + TaskUpdater, + BasePushNotificationSender, + InMemoryPushNotificationConfigStore, +) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( @@ -31,12 +39,16 @@ Message, Part, SendMessageRequest, + Task, TaskState, + TaskStatus, + TaskPushNotificationConfig, ) from a2a.utils import TransportProtocol - -logging.basicConfig(level=logging.INFO) +log_level_str = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper() +log_level = getattr(logging, log_level_str, logging.INFO) 
+logging.basicConfig(level=log_level) logger = logging.getLogger(__name__) @@ -102,7 +114,9 @@ def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message: ) -async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: +async def handle_call_agent( + call: instruction_pb2.CallAgent, +) -> list[str]: """Handles the CallAgent instruction by invoking another agent.""" logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport) @@ -127,39 +141,47 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: selected_transport == TransportProtocol.GRPC ) + if call.HasField('push_notification'): + url = call.push_notification.url + if not url: + raise ValueError('URL not specified in push_notification behavior') + if not url.startswith(('http://', 'https://')): + url = f'http://{url}' + config.push_notification_config = TaskPushNotificationConfig( + url=f'{url}/notifications', + token='itk-token', # noqa: S106 + ) + try: - client = await ClientFactory.connect( + client = await create_client( call.agent_card_uri, client_config=config, ) # Wrap nested instruction - async with client: - nested_msg = wrap_instruction_to_request(call.instruction) - request = SendMessageRequest(message=nested_msg) - - results: list[str] = [] - async for event in client.send_message(request): - # Event is StreamResponse - logger.info('Event: %s', event) - stream_resp = event - - message = None - if stream_resp.HasField('message'): - message = stream_resp.message - elif stream_resp.HasField( - 'task' - ) and stream_resp.task.status.HasField('message'): - message = stream_resp.task.status.message - elif stream_resp.HasField( - 'status_update' - ) and stream_resp.status_update.status.HasField('message'): - message = stream_resp.status_update.status.message - - if message: - results.extend( - part.text for part in message.parts if part.text - ) + nested_msg = wrap_instruction_to_request(call.instruction) + request = 
SendMessageRequest(message=nested_msg) + + results = [] + async for event in client.send_message(request): + # Event is streaming response and task + logger.info('Event: %s', event) + stream_resp = event + + message = None + if stream_resp.HasField('message'): + message = stream_resp.message + elif stream_resp.HasField( + 'task' + ) and stream_resp.task.status.HasField('message'): + message = stream_resp.task.status.message + elif stream_resp.HasField( + 'status_update' + ) and stream_resp.status_update.status.HasField('message'): + message = stream_resp.status_update.status.message + + if message: + results.extend(part.text for part in message.parts if part.text) except Exception as e: logger.exception('Failed to call outbound agent') @@ -170,7 +192,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: return results -async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]: +async def handle_instruction( + inst: instruction_pb2.Instruction, +) -> list[str]: """Recursively handles instructions.""" if inst.HasField('call_agent'): return await handle_call_agent(inst.call_agent) @@ -199,7 +223,16 @@ async def execute( context.context_id, ) - await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED) + # Explicitly create the task by sending it to the queue + task = Task( + id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + history=[context.message] if context.message else [], + ) + async with task_updater._lock: # noqa: SLF001 + await event_queue.enqueue_event(task) + await task_updater.update_status(TaskState.TASK_STATE_WORKING) instruction = extract_instruction(context.message) @@ -292,48 +325,65 @@ async def main_async(http_port: int, grpc_port: int) -> None: name='ITK v10 Agent', description='Python agent using SDK 1.0.', version='1.0.0', - capabilities=AgentCapabilities(streaming=True), + capabilities=AgentCapabilities( + streaming=True, 
push_notifications=True, extended_agent_card=True + ), default_input_modes=['text/plain'], default_output_modes=['text/plain'], supported_interfaces=interfaces, ) task_store = InMemoryTaskStore() + push_config_store = InMemoryPushNotificationConfigStore() + push_sender = BasePushNotificationSender( + httpx_client=httpx.AsyncClient(), + config_store=push_config_store, + ) + handler = DefaultRequestHandler( agent_executor=V10AgentExecutor(), + agent_card=agent_card, task_store=task_store, queue_manager=InMemoryQueueManager(), + push_config_store=push_config_store, + push_sender=push_sender, ) - app = FastAPI() + handler_extended = DefaultRequestHandler( + agent_executor=V10AgentExecutor(), + agent_card=agent_card, + task_store=task_store, + queue_manager=InMemoryQueueManager(), + push_config_store=push_config_store, + push_sender=push_sender, + extended_agent_card=agent_card, + ) agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=handler, - extended_agent_card=agent_card, + request_handler=handler_extended, rpc_url='/', enable_v0_3_compat=True, ) - app.mount( - '/jsonrpc', - FastAPI(routes=jsonrpc_routes + agent_card_routes), - ) - rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=handler, enable_v0_3_compat=True, ) + + app = FastAPI() + app.mount( + '/jsonrpc', + FastAPI(routes=jsonrpc_routes + agent_card_routes), + ) app.mount('/rest', FastAPI(routes=rest_routes + agent_card_routes)) server = grpc.aio.server() - compat_servicer = CompatGrpcHandler(agent_card, handler) + compat_servicer = CompatGrpcHandler(handler) a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(compat_servicer, server) - servicer = GrpcHandler(agent_card, handler) + servicer = GrpcHandler(handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) server.add_insecure_port(f'127.0.0.1:{grpc_port}') @@ -346,7 +396,7 @@ 
async def main_async(http_port: int, grpc_port: int) -> None: ) config = uvicorn.Config( - app, host='127.0.0.1', port=http_port, log_level='info' + app, host='127.0.0.1', port=http_port, log_level=log_level_str.lower() ) uvicorn_server = uvicorn.Server(config) diff --git a/itk/run_itk.sh b/itk/run_itk.sh index 908a5fbc5..21736f171 100755 --- a/itk/run_itk.sh +++ b/itk/run_itk.sh @@ -1,6 +1,9 @@ #!/bin/bash set -ex +# Set default log level +export ITK_LOG_LEVEL="${ITK_LOG_LEVEL:-INFO}" + # Initialize default exit code RESULT=1 @@ -63,15 +66,28 @@ ITK_DIR=$(pwd) # Stop existing container if any docker rm -f itk-service || true +# Create logs directory if debug +if [ "${ITK_LOG_LEVEL^^}" = "DEBUG" ]; then + mkdir -p "$ITK_DIR/logs" +fi + +DOCKER_MOUNT_LOGS="" +if [ "${ITK_LOG_LEVEL^^}" = "DEBUG" ]; then + DOCKER_MOUNT_LOGS="-v $ITK_DIR/logs:/app/logs" +fi + docker run -d --name itk-service \ -v "$A2A_PYTHON_ROOT:/app/agents/repo" \ -v "$ITK_DIR:/app/agents/repo/itk" \ + $DOCKER_MOUNT_LOGS \ + -e ITK_LOG_LEVEL="$ITK_LOG_LEVEL" \ -p 8000:8000 \ itk_service # 5.1. Fix dubious ownership for git (needed for uv-dynamic-versioning) -docker exec itk-service git config --global --add safe.directory /app/agents/repo -docker exec itk-service git config --global --add safe.directory /app/agents/repo/itk +docker exec -u root itk-service git config --system --add safe.directory /app/agents/repo +docker exec -u root itk-service git config --system --add safe.directory /app/agents/repo/itk +docker exec -u root itk-service git config --system core.multiPackIndex false # 6. 
Verify service is up and send post request MAX_RETRIES=30 @@ -103,14 +119,16 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"], "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"], - "protocols": ["jsonrpc", "grpc"] + "protocols": ["jsonrpc", "grpc"], + "behavior": "send_message" }, { "name": "Star Topology (No Go v03) - HTTP_JSON", "sdks": ["current", "python_v10", "python_v03", "go_v10"], "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], - "protocols": ["http_json"] + "protocols": ["http_json"], + "behavior": "send_message" }, { "name": "Star Topology (Full) - JSONRPC & GRPC (Streaming)", @@ -118,7 +136,8 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"], "protocols": ["jsonrpc", "grpc"], - "streaming": true + "streaming": true, + "behavior": "send_message" }, { "name": "Star Topology (No Go v03) - HTTP_JSON (Streaming)", @@ -126,7 +145,24 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], "protocols": ["http_json"], - "streaming": true + "streaming": true, + "behavior": "send_message" + }, + { + "name": "Push Notification Test - JSONRPC & GRPC", + "sdks": ["current", "python_v10", "python_v03", "go_v03"], + "traversal": "euler", + "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], + "protocols": ["jsonrpc", "grpc"], + "behavior": "push_notification" + }, + { + "name": "Push Notification Test - HTTP_JSON", + "sdks": ["current", "python_v10", "python_v03"], + "traversal": "euler", + "edges": ["0->1", "0->2", "1->0", "2->0"], + "protocols": ["http_json"], + "behavior": "push_notification" } ] }') diff --git a/pyproject.toml b/pyproject.toml index 724749865..a61a90e47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-11,7 +11,7 @@ dependencies = [ "httpx>=0.28.1", "httpx-sse>=0.4.0", "pydantic>=2.11.3", - "protobuf>=5.29.5", + "protobuf>=5.29.5,<7", "google-api-core>=1.26.0", "json-rpc>=1.15.0", "googleapis-common-protos>=1.70.0", @@ -43,7 +43,6 @@ mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] signing = ["PyJWT>=2.0.0"] sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] db-cli = ["alembic>=1.14.0"] -vertex = ["google-cloud-aiplatform>=1.140.0"] sql = ["a2a-sdk[postgresql,mysql,sqlite]"] @@ -55,7 +54,6 @@ all = [ "a2a-sdk[telemetry]", "a2a-sdk[signing]", "a2a-sdk[db-cli]", - "a2a-sdk[vertex]", ] [project.urls] diff --git a/release-please-config.json b/release-please-config.json deleted file mode 100644 index 063b8435a..000000000 --- a/release-please-config.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "release-type": "python", - "prerelease": true, - "prerelease-type": "alpha", - "versioning": "prerelease", - "packages": { - ".": {} - } -} diff --git a/samples/README.md b/samples/README.md new file mode 100644 index 000000000..e61264955 --- /dev/null +++ b/samples/README.md @@ -0,0 +1,58 @@ +# A2A Python SDK — Samples + +This directory contains runnable examples demonstrating how to build and interact with an A2A-compliant agent using the Python SDK. + +## Contents + +| File | Role | Description | +|---|---|---| +| `hello_world_agent.py` | **Server** | A2A agent server | +| `cli.py` | **Client** | Interactive terminal client | + +The samples are designed to work together out of the box: the agent listens on `http://127.0.0.1:41241`, which is the default URL used by the client. +--- + +## `hello_world_agent.py` — Agent Server + +Implements an A2A agent that responds to simple greeting messages (e.g., "hello", "how are you", "bye") with text replies, simulating a 1-second processing delay. 
+ +Demonstrates: +- Subclassing `AgentExecutor` and implementing `execute()` / `cancel()` +- Publishing streaming status updates and artifacts via `TaskUpdater` +- Exposing all three transports in both protocol versions (v1.0 and v0.3 compat) simultaneously: + - **JSON-RPC** (v1.0 and v0.3) at `http://127.0.0.1:41241/a2a/jsonrpc` + - **HTTP+JSON (REST)** (v1.0 and v0.3) at `http://127.0.0.1:41241/a2a/rest` + - **gRPC v1.0** on port `50051` + - **gRPC v0.3 (compat)** on port `50052` +- Serving the agent card at `http://127.0.0.1:41241/.well-known/agent-card.json` + +**Run:** + +```bash +uv run python samples/hello_world_agent.py +``` + +--- + +## `cli.py` — Client + +An interactive terminal client with full visibility into the streaming event flow. Each `TaskStatusUpdate` and `TaskArtifactUpdate` event is printed as it arrives. + +Features: +- Transport selection via `--transport` flag (`JSONRPC`, `HTTP+JSON`, `GRPC`) +- Session management (`context_id` persisted across messages, `task_id` per task) +- Graceful error handling for HTTP and gRPC failures + +**Run:** + +```bash +# Connect to the local hello_world_agent (default): +uv run python samples/cli.py + +# Connect to a different URL, using gRPC: +uv run python samples/cli.py --url http://192.168.1.10:41241 --transport GRPC +``` + +Then type a message like `hello` and press Enter. + +Type `/quit` or `/exit` to stop, or press `Ctrl+C`. 
diff --git a/samples/cli.py b/samples/cli.py index 6a4597fa9..beff26aa9 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -9,46 +9,54 @@ import grpc import httpx -from a2a.client import A2ACardResolver, ClientConfig, ClientFactory +from a2a.client import A2ACardResolver, ClientConfig, create_client +from a2a.helpers import get_artifact_text, get_message_text +from a2a.helpers.agent_card import display_agent_card from a2a.types import Message, Part, Role, SendMessageRequest, TaskState async def _handle_stream( stream: Any, current_task_id: str | None ) -> str | None: - async for event, task in stream: - if not task: - continue + async for event in stream: + if event.HasField('message'): + print('Message:', get_message_text(event.message, delimiter=' ')) + return None + if not current_task_id: - current_task_id = task.id - - if event: - if event.HasField('status_update'): - state_name = TaskState.Name(event.status_update.status.state) - print(f'TaskStatusUpdate [state={state_name}]:', end=' ') - if event.status_update.status.HasField('message'): - for part in event.status_update.status.message.parts: - if part.text: - print(part.text, end=' ') - print() - - if ( - event.status_update.status.state - == TaskState.TASK_STATE_COMPLETED - ): - current_task_id = None - print('--- Task Completed ---') - - elif event.HasField('artifact_update'): - print( - f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', - end=' ', + if event.HasField('task'): + current_task_id = event.task.id + print('--- Task Started ---') + print(f'Task [state={TaskState.Name(event.task.status.state)}]') + else: + raise ValueError(f'Unexpected first event: {event}') + + if event.HasField('status_update'): + state_name = TaskState.Name(event.status_update.status.state) + message_text = ( + ': ' + + get_message_text( + event.status_update.status.message, delimiter=' ' ) - for part in event.artifact_update.artifact.parts: - if part.text: - print(part.text, end=' ') - print() - + if 
event.status_update.status.HasField('message') + else '' + ) + print(f'TaskStatusUpdate [state={state_name}]{message_text}') + if state_name in ( + 'TASK_STATE_COMPLETED', + 'TASK_STATE_FAILED', + 'TASK_STATE_CANCELED', + 'TASK_STATE_REJECTED', + ): + current_task_id = None + print('--- Task Finished ---') + elif event.HasField('artifact_update'): + print( + f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', + get_artifact_text( + event.artifact_update.artifact, delimiter=' ' + ), + ) return current_task_id @@ -65,7 +73,9 @@ async def main() -> None: ) args = parser.parse_args() - config = ClientConfig() + config = ClientConfig( + grpc_channel_factory=grpc.aio.insecure_channel, + ) if args.transport: config.supported_protocol_bindings = [args.transport] @@ -77,9 +87,9 @@ async def main() -> None: resolver = A2ACardResolver(httpx_client, args.url) card = await resolver.get_agent_card() print('\n✓ Agent Card Found:') - print(f' Name: {card.name}') + display_agent_card(card) - client = await ClientFactory.connect(card, client_config=config) + client = await create_client(card, client_config=config) actual_transport = getattr(client, '_transport', client) print(f' Picked Transport: {actual_transport.__class__.__name__}') diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index e286fa130..a6e589ac0 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -1,3 +1,4 @@ +import argparse import asyncio import contextlib import logging @@ -12,10 +13,7 @@ from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue -from a2a.server.request_handlers import GrpcHandler -from a2a.server.request_handlers.default_request_handler import ( - DefaultRequestHandler, -) +from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler from a2a.server.routes import ( create_agent_card_routes, 
create_jsonrpc_routes, @@ -30,6 +28,9 @@ AgentProvider, AgentSkill, Part, + Task, + TaskState, + TaskStatus, a2a_pb2_grpc, ) @@ -78,6 +79,15 @@ async def execute( context_id, ) + await event_queue.enqueue_event( + Task( + id=task_id, + context_id=context_id, + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + history=[user_message], + ) + ) + updater = TaskUpdater( event_queue=event_queue, task_id=task_id, @@ -191,17 +201,17 @@ async def serve( task_store = InMemoryTaskStore() request_handler = DefaultRequestHandler( - agent_executor=SampleAgentExecutor(), task_store=task_store + agent_executor=SampleAgentExecutor(), + task_store=task_store, + agent_card=agent_card, ) rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=request_handler, path_prefix='/a2a/rest', enable_v0_3_compat=True, ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=request_handler, rpc_url='/a2a/jsonrpc', enable_v0_3_compat=True, @@ -216,12 +226,12 @@ async def serve( grpc_server = grpc.aio.server() grpc_server.add_insecure_port(f'{host}:{grpc_port}') - servicer = GrpcHandler(agent_card, request_handler) + servicer = GrpcHandler(request_handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, grpc_server) compat_grpc_server = grpc.aio.server() compat_grpc_server.add_insecure_port(f'{host}:{compat_grpc_port}') - compat_servicer = CompatGrpcHandler(agent_card, request_handler) + compat_servicer = CompatGrpcHandler(request_handler) a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server( compat_servicer, compat_grpc_server ) @@ -248,5 +258,18 @@ async def serve( if __name__ == '__main__': logging.basicConfig(level=logging.INFO) + parser = argparse.ArgumentParser(description='Sample A2A agent server') + parser.add_argument('--host', default='127.0.0.1') + parser.add_argument('--port', type=int, default=41241) + parser.add_argument('--grpc-port', type=int, default=50051) + parser.add_argument('--compat-grpc-port', type=int, 
default=50052) + args = parser.parse_args() with contextlib.suppress(KeyboardInterrupt): - asyncio.run(serve()) + asyncio.run( + serve( + host=args.host, + port=args.port, + grpc_port=args.grpc_port, + compat_grpc_port=args.compat_grpc_port, + ) + ) diff --git a/scripts/test_install_smoke.py b/scripts/test_install_smoke.py new file mode 100755 index 000000000..41ad029bb --- /dev/null +++ b/scripts/test_install_smoke.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +"""Smoke test for installations of a2a-sdk with various extras. + +This script verifies that the public API modules associated with a +given installation profile can be imported without pulling in modules +that belong to other (uninstalled) optional extras. + +It is designed to run WITHOUT pytest or any dev dependencies -- just +a clean venv with `pip install a2a-sdk[]`. + +Usage: + python scripts/test_install_smoke.py [profile] + + profile defaults to "base" and selects which set of modules to + smoke-test. Available profiles: + base -- `pip install a2a-sdk` + http-server -- `pip install a2a-sdk[http-server]` + grpc -- `pip install a2a-sdk[grpc]` + telemetry -- `pip install a2a-sdk[telemetry]` + sql -- `pip install a2a-sdk[sql]` + +Exit codes: + 0 - All imports for the profile succeeded + 1 - One or more imports failed +""" + +from __future__ import annotations + +import importlib +import sys + + +# Core modules that MUST be importable with only base dependencies. +# These are the public API surface that every user gets with +# `pip install a2a-sdk` (no extras). +# +# Do NOT add modules here that require optional extras (grpc, +# http-server, sql, signing, telemetry, vertex, etc.). +# Those modules are expected to fail without their extras installed +# and should use try/except ImportError guards internally. 
+CORE_MODULES = [ + 'a2a', + 'a2a.client', + 'a2a.client.auth', + 'a2a.client.base_client', + 'a2a.client.card_resolver', + 'a2a.client.client', + 'a2a.client.client_factory', + 'a2a.client.errors', + 'a2a.client.interceptors', + 'a2a.client.optionals', + 'a2a.client.transports', + 'a2a.server', + 'a2a.server.agent_execution', + 'a2a.server.context', + 'a2a.server.events', + 'a2a.server.request_handlers', + 'a2a.server.tasks', + 'a2a.types', + 'a2a.utils', + 'a2a.utils.constants', + 'a2a.utils.error_handlers', + 'a2a.utils.version_validator', + 'a2a.utils.proto_utils', + 'a2a.utils.task', + 'a2a.helpers.agent_card', + 'a2a.helpers.proto_helpers', +] + +# Modules that MUST be importable with only the base + `http-server` +# extras installed (no `grpc`, `sql`, `signing`, `telemetry`, etc.). +# +# A user building a Starlette/FastAPI A2A server with +# `pip install a2a-sdk[http-server]` should be able to import these +# without the gRPC stack being present on the system. +HTTP_SERVER_MODULES = [ + 'a2a.server.routes', + 'a2a.server.routes.agent_card_routes', + 'a2a.server.routes.common', + 'a2a.server.routes.jsonrpc_dispatcher', + 'a2a.server.routes.jsonrpc_routes', + 'a2a.server.routes.rest_dispatcher', + 'a2a.server.routes.rest_routes', +] + +# Modules that MUST be importable with only the base + `grpc` extras +# installed (no `http-server`, `sql`, `signing`, `telemetry`, etc.). +GRPC_MODULES = [ + 'a2a.server.request_handlers.grpc_handler', + 'a2a.client.transports.grpc', + 'a2a.compat.v0_3.grpc_handler', + 'a2a.compat.v0_3.grpc_transport', +] + +# Modules that MUST be importable with only the base + `telemetry` +# extras installed. +TELEMETRY_MODULES = [ + 'a2a.utils.telemetry', +] + +# Modules that MUST be importable with only the base + `sql` extras +# installed (covers postgresql/mysql/sqlite drivers via SQLAlchemy). 
+SQL_MODULES = [ + 'a2a.server.models', + 'a2a.server.tasks.database_task_store', + 'a2a.server.tasks.database_push_notification_config_store', +] + + +PROFILES: dict[str, list[str]] = { + 'base': CORE_MODULES, + 'http-server': CORE_MODULES + HTTP_SERVER_MODULES, + 'grpc': CORE_MODULES + GRPC_MODULES, + 'telemetry': CORE_MODULES + TELEMETRY_MODULES, + 'sql': CORE_MODULES + SQL_MODULES, +} + + +def main() -> int: + profile = sys.argv[1] if len(sys.argv) > 1 else 'base' + if profile not in PROFILES: + print(f'Unknown profile {profile!r}. Available: {sorted(PROFILES)}') + return 1 + + modules = PROFILES[profile] + failures: list[str] = [] + successes: list[str] = [] + + for module_name in modules: + try: + importlib.import_module(module_name) + successes.append(module_name) + except Exception as e: # noqa: BLE001, PERF203 + failures.append(f'{module_name}: {e}') + + print(f'Profile: {profile}') + print(f'Tested {len(modules)} modules') + print(f' Passed: {len(successes)}') + print(f' Failed: {len(failures)}') + + if failures: + print('\nFAILED imports:') + for failure in failures: + print(f' - {failure}') + return 1 + + print('\nAll modules imported successfully.') + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scripts/test_install_smoke.sh b/scripts/test_install_smoke.sh new file mode 100755 index 000000000..9f0a45fbd --- /dev/null +++ b/scripts/test_install_smoke.sh @@ -0,0 +1,104 @@ +#!/bin/bash +# Local equivalent of .github/workflows/install-smoke.yml. +# +# For each install profile, builds the wheel and installs it into a +# clean venv (no dev deps), then runs the import smoke test for that +# profile. By default runs every known profile; pass a profile name +# to run just one. 
+# +# Available profiles (must match those in scripts/test_install_smoke.py): +# base -- `pip install a2a-sdk` +# http-server -- `pip install a2a-sdk[http-server]` +# grpc -- `pip install a2a-sdk[grpc]` +# telemetry -- `pip install a2a-sdk[telemetry]` +# sql -- `pip install a2a-sdk[sql]` +# +# Usage: +# scripts/test_install_smoke.sh [profile] [python-version] +# +# Examples: +# scripts/test_install_smoke.sh # all profiles, default python +# scripts/test_install_smoke.sh '' 3.13 # all profiles on python 3.13 +# scripts/test_install_smoke.sh http-server # http-server only +# scripts/test_install_smoke.sh http-server 3.13 # http-server on python 3.13 +set -e +set -o pipefail + +REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$REPO_ROOT" + +ALL_PROFILES=(base http-server grpc telemetry sql) + +PROFILE_ARG="${1:-}" +PYTHON_VERSION="${2:-}" + +if [ -z "$PROFILE_ARG" ]; then + PROFILES=("${ALL_PROFILES[@]}") +else + PROFILES=("$PROFILE_ARG") +fi + +extras_for_profile() { + case "$1" in + base) echo "" ;; + http-server) echo "[http-server]" ;; + grpc) echo "[grpc]" ;; + telemetry) echo "[telemetry]" ;; + sql) echo "[sql]" ;; + *) + echo "Unknown profile '$1'. Available: ${ALL_PROFILES[*]}" >&2 + return 1 + ;; + esac +} + +# Validate profiles up-front so we fail fast. 
+for profile in "${PROFILES[@]}"; do + extras_for_profile "$profile" >/dev/null +done + +echo "--- Building wheel ---" +rm -rf dist +uv build --wheel +WHEEL=$(ls dist/*.whl) + +FAILED_PROFILES=() + +for profile in "${PROFILES[@]}"; do + extras=$(extras_for_profile "$profile") + venv_dir=".venv-smoke-${profile}" + + echo + echo "==================================================================" + echo " Profile: $profile (extras='$extras')" + echo "==================================================================" + + echo "--- Creating clean venv at $venv_dir ---" + rm -rf "$venv_dir" + if [ -n "$PYTHON_VERSION" ]; then + uv venv "$venv_dir" --python "$PYTHON_VERSION" + else + uv venv "$venv_dir" + fi + + echo "--- Installing built wheel with '$profile' dependencies only ---" + VIRTUAL_ENV="$venv_dir" uv pip install "${WHEEL}${extras}" + + echo "--- Installed packages ---" + VIRTUAL_ENV="$venv_dir" uv pip list + + echo "--- Running import smoke test ---" + if ! "$venv_dir/bin/python" scripts/test_install_smoke.py "$profile"; then + FAILED_PROFILES+=("$profile") + fi +done + +echo +echo "==================================================================" +if [ ${#FAILED_PROFILES[@]} -eq 0 ]; then + echo " All profiles passed: ${PROFILES[*]}" + exit 0 +fi + +echo " Failed profiles: ${FAILED_PROFILES[*]}" >&2 +exit 1 diff --git a/scripts/test_minimal_install.py b/scripts/test_minimal_install.py deleted file mode 100755 index 076df4c0f..000000000 --- a/scripts/test_minimal_install.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -"""Smoke test for minimal (base-only) installation of a2a-sdk. - -This script verifies that all core public API modules can be imported -when only the base dependencies are installed (no optional extras). - -It is designed to run WITHOUT pytest or any dev dependencies -- just -a clean venv with `pip install a2a-sdk`. 
- -Usage: - python scripts/test_minimal_install.py - -Exit codes: - 0 - All core imports succeeded - 1 - One or more core imports failed -""" - -from __future__ import annotations - -import importlib -import sys - - -# Core modules that MUST be importable with only base dependencies. -# These are the public API surface that every user gets with -# `pip install a2a-sdk` (no extras). -# -# Do NOT add modules here that require optional extras (grpc, -# http-server, sql, signing, telemetry, vertex, etc.). -# Those modules are expected to fail without their extras installed -# and should use try/except ImportError guards internally. -CORE_MODULES = [ - 'a2a', - 'a2a.client', - 'a2a.client.auth', - 'a2a.client.base_client', - 'a2a.client.card_resolver', - 'a2a.client.client', - 'a2a.client.client_factory', - 'a2a.client.errors', - 'a2a.client.helpers', - 'a2a.client.interceptors', - 'a2a.client.optionals', - 'a2a.client.transports', - 'a2a.server', - 'a2a.server.agent_execution', - 'a2a.server.context', - 'a2a.server.events', - 'a2a.server.request_handlers', - 'a2a.server.tasks', - 'a2a.types', - 'a2a.utils', - 'a2a.utils.artifact', - 'a2a.utils.constants', - 'a2a.utils.error_handlers', - 'a2a.utils.helpers', - 'a2a.utils.message', - 'a2a.utils.parts', - 'a2a.utils.proto_utils', - 'a2a.utils.task', -] - - -def main() -> int: - failures: list[str] = [] - successes: list[str] = [] - - for module_name in CORE_MODULES: - try: - importlib.import_module(module_name) - successes.append(module_name) - except Exception as e: # noqa: BLE001, PERF203 - failures.append(f'{module_name}: {e}') - - print(f'Tested {len(CORE_MODULES)} core modules') - print(f' Passed: {len(successes)}') - print(f' Failed: {len(failures)}') - - if failures: - print('\nFAILED imports:') - for failure in failures: - print(f' - {failure}') - return 1 - - print('\nAll core modules imported successfully.') - return 0 - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/src/a2a/client/__init__.py 
b/src/a2a/client/__init__.py index 188ab4c80..d33c09481 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -12,13 +12,16 @@ ClientCallContext, ClientConfig, ) -from a2a.client.client_factory import ClientFactory, minimal_agent_card +from a2a.client.client_factory import ( + ClientFactory, + create_client, + minimal_agent_card, +) from a2a.client.errors import ( A2AClientError, A2AClientTimeoutError, AgentCardResolutionError, ) -from a2a.client.helpers import create_text_message_object from a2a.client.interceptors import ClientCallInterceptor @@ -36,6 +39,6 @@ 'ClientFactory', 'CredentialService', 'InMemoryContextCredentialStore', - 'create_text_message_object', + 'create_client', 'minimal_agent_card', ] diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 53fd38cdb..763f23fb5 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -104,10 +104,10 @@ def _apply_client_config(self, request: SendMessageRequest) -> None: request.configuration.return_immediately |= self._config.polling if ( not request.configuration.HasField('task_push_notification_config') - and self._config.push_notification_configs + and self._config.push_notification_config ): request.configuration.task_push_notification_config.CopyFrom( - self._config.push_notification_configs[0] + self._config.push_notification_config ) if ( not request.configuration.accepted_output_modes diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index 6d98a5361..815916014 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -6,10 +6,9 @@ import httpx -from google.protobuf.json_format import ParseError +from google.protobuf.json_format import ParseDict, ParseError from a2a.client.errors import AgentCardResolutionError -from a2a.client.helpers import parse_agent_card from a2a.types.a2a_pb2 import ( AgentCard, ) @@ -19,6 +18,111 @@ logger = logging.getLogger(__name__) +def 
parse_agent_card(agent_card_data: dict[str, Any]) -> AgentCard: + """Parse AgentCard JSON dictionary and handle backward compatibility.""" + _handle_extended_card_compatibility(agent_card_data) + _handle_connection_fields_compatibility(agent_card_data) + _handle_security_compatibility(agent_card_data) + + return ParseDict(agent_card_data, AgentCard(), ignore_unknown_fields=True) + + +def _handle_extended_card_compatibility( + agent_card_data: dict[str, Any], +) -> None: + """Map legacy supportsAuthenticatedExtendedCard to capabilities.""" + if agent_card_data.pop('supportsAuthenticatedExtendedCard', None): + capabilities = agent_card_data.setdefault('capabilities', {}) + if 'extendedAgentCard' not in capabilities: + capabilities['extendedAgentCard'] = True + + +def _handle_connection_fields_compatibility( + agent_card_data: dict[str, Any], +) -> None: + """Map legacy connection and transport fields to supportedInterfaces.""" + main_url = agent_card_data.pop('url', None) + main_transport = agent_card_data.pop('preferredTransport', 'JSONRPC') + version = agent_card_data.pop('protocolVersion', '0.3.0') + additional_interfaces = ( + agent_card_data.pop('additionalInterfaces', None) or [] + ) + + if 'supportedInterfaces' not in agent_card_data and main_url: + supported_interfaces = [] + supported_interfaces.append( + { + 'url': main_url, + 'protocolBinding': main_transport, + 'protocolVersion': version, + } + ) + supported_interfaces.extend( + { + 'url': iface.get('url'), + 'protocolBinding': iface.get('transport'), + 'protocolVersion': version, + } + for iface in additional_interfaces + ) + agent_card_data['supportedInterfaces'] = supported_interfaces + + +def _map_legacy_security( + sec_list: list[dict[str, list[str]]], +) -> list[dict[str, Any]]: + """Convert a legacy security requirement list into the 1.0.0 Protobuf format.""" + return [ + { + 'schemes': { + scheme_name: {'list': scopes} + for scheme_name, scopes in sec_dict.items() + } + } + for sec_dict in 
sec_list + ] + + +def _handle_security_compatibility(agent_card_data: dict[str, Any]) -> None: + """Map legacy security requirements and schemas to their 1.0.0 Protobuf equivalents.""" + legacy_security = agent_card_data.pop('security', None) + if ( + 'securityRequirements' not in agent_card_data + and legacy_security is not None + ): + agent_card_data['securityRequirements'] = _map_legacy_security( + legacy_security + ) + + for skill in agent_card_data.get('skills', []): + legacy_skill_sec = skill.pop('security', None) + if 'securityRequirements' not in skill and legacy_skill_sec is not None: + skill['securityRequirements'] = _map_legacy_security( + legacy_skill_sec + ) + + security_schemes = agent_card_data.get('securitySchemes', {}) + if security_schemes: + type_mapping = { + 'apiKey': 'apiKeySecurityScheme', + 'http': 'httpAuthSecurityScheme', + 'oauth2': 'oauth2SecurityScheme', + 'openIdConnect': 'openIdConnectSecurityScheme', + 'mutualTLS': 'mtlsSecurityScheme', + } + for scheme in security_schemes.values(): + scheme_type = scheme.pop('type', None) + if scheme_type in type_mapping: + # Map legacy 'in' to modern 'location' + if scheme_type == 'apiKey' and 'in' in scheme: + scheme['location'] = scheme.pop('in') + + mapped_name = type_mapping[scheme_type] + new_scheme_wrapper = {mapped_name: scheme.copy()} + scheme.clear() + scheme.update(new_scheme_wrapper) + + class A2ACardResolver: """Agent Card resolver.""" diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 1f94a4426..3fbf4f287 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -71,10 +71,8 @@ class ClientConfig: accepted_output_modes: list[str] = dataclasses.field(default_factory=list) """The set of accepted output modes for the client.""" - push_notification_configs: list[TaskPushNotificationConfig] = ( - dataclasses.field(default_factory=list) - ) - """Push notification configurations to use for every request.""" + push_notification_config: 
TaskPushNotificationConfig | None = None + """Push notification configuration to use for every request.""" class ClientCallContext(BaseModel): diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index c5d5e8aa4..a59189ade 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -3,7 +3,7 @@ import logging from collections.abc import Callable -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import httpx @@ -56,32 +56,35 @@ class ClientFactory: - """ClientFactory is used to generate the appropriate client for the agent. + """Factory for creating clients that communicate with A2A agents. - The factory is configured with a `ClientConfig` and optionally a list of - `Consumer`s to use for all generated `Client`s. The expected use is: - - .. code-block:: python + The factory is configured with a `ClientConfig` and optionally custom + transport producers registered via `register`. Example usage: factory = ClientFactory(config) - # Optionally register custom client implementations - factory.register('my_customer_transport', NewCustomTransportClient) - # Then with an agent card make a client with additional interceptors + # Optionally register custom transport implementations + factory.register('my_custom_transport', custom_transport_producer) + # Create a client from an AgentCard client = factory.create(card, interceptors) + # Or resolve an AgentCard from a URL and create a client + client = await factory.create_from_url('https://example.com') - Now the client can be used consistently regardless of the transport. This + The client can be used consistently regardless of the transport. This aligns the client configuration with the server's capabilities. 
""" def __init__( self, - config: ClientConfig, + config: ClientConfig | None = None, ): - client = config.httpx_client or httpx.AsyncClient() - client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT) - config.httpx_client = client + config = config or ClientConfig() + httpx_client = config.httpx_client or httpx.AsyncClient() + httpx_client.headers.setdefault( + VERSION_HEADER, PROTOCOL_VERSION_CURRENT + ) self._config = config + self._httpx_client = httpx_client self._registry: dict[str, TransportProducer] = {} self._register_defaults(config.supported_protocol_bindings) @@ -112,13 +115,13 @@ def jsonrpc_transport_producer( ) return CompatJsonRpcTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) return JsonRpcTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) @@ -151,13 +154,13 @@ def rest_transport_producer( ) return CompatRestTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) return RestTransport( - cast('httpx.AsyncClient', config.httpx_client), + self._httpx_client, card, url, ) @@ -252,73 +255,45 @@ def _find_best_interface( return best_gt_1_0 or best_ge_0_3 or best_no_version - @classmethod - async def connect( # noqa: PLR0913 - cls, - agent: str | AgentCard, - client_config: ClientConfig | None = None, + async def create_from_url( + self, + url: str, interceptors: list[ClientCallInterceptor] | None = None, relative_card_path: str | None = None, resolver_http_kwargs: dict[str, Any] | None = None, - extra_transports: dict[str, TransportProducer] | None = None, signature_verifier: Callable[[AgentCard], None] | None = None, ) -> Client: - """Convenience method for constructing a client. - - Constructs a client that connects to the specified agent. Note that - creating multiple clients via this method is less efficient than - constructing an instance of ClientFactory and reusing that. - - .. 
code-block:: python + """Create a `Client` by resolving an `AgentCard` from a URL. - # This will search for an AgentCard at /.well-known/agent-card.json - my_agent_url = 'https://travel.agents.example.com' - client = await ClientFactory.connect(my_agent_url) + Resolves the agent card from the given URL using the factory's + configured httpx client, then creates a client via `create`. + If the agent card is already available, use `create` directly + instead. Args: - agent: The base URL of the agent, or the AgentCard to connect to. - client_config: The ClientConfig to use when connecting to the agent. - - interceptors: A list of interceptors to use for each request. These - are used for things like attaching credentials or http headers - to all outbound requests. - relative_card_path: If the agent field is a URL, this value is used as - the relative path when resolving the agent card. See - A2AAgentCardResolver.get_agent_card for more details. - resolver_http_kwargs: Dictionary of arguments to provide to the httpx - client when resolving the agent card. This value is provided to - A2AAgentCardResolver.get_agent_card as the http_kwargs parameter. - extra_transports: Additional transport protocols to enable when - constructing the client. - signature_verifier: A callable used to verify the agent card's signatures. + url: The base URL of the agent. The agent card will be fetched + from `/.well-known/agent-card.json` by default. + interceptors: A list of interceptors to use for each request. + These are used for things like attaching credentials or http + headers to all outbound requests. + relative_card_path: The relative path when resolving the agent + card. See `A2ACardResolver.get_agent_card` for details. + resolver_http_kwargs: Dictionary of arguments to provide to the + httpx client when resolving the agent card. + signature_verifier: A callable used to verify the agent card's + signatures. Returns: A `Client` object. 
""" - client_config = client_config or ClientConfig() - if isinstance(agent, str): - if not client_config.httpx_client: - async with httpx.AsyncClient() as client: - resolver = A2ACardResolver(client, agent) - card = await resolver.get_agent_card( - relative_card_path=relative_card_path, - http_kwargs=resolver_http_kwargs, - signature_verifier=signature_verifier, - ) - else: - resolver = A2ACardResolver(client_config.httpx_client, agent) - card = await resolver.get_agent_card( - relative_card_path=relative_card_path, - http_kwargs=resolver_http_kwargs, - signature_verifier=signature_verifier, - ) - else: - card = agent - factory = cls(client_config) - for label, generator in (extra_transports or {}).items(): - factory.register(label, generator) - return factory.create(card, interceptors) + resolver = A2ACardResolver(self._httpx_client, url) + card = await resolver.get_agent_card( + relative_card_path=relative_card_path, + http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, + ) + return self.create(card, interceptors) def register(self, label: str, generator: TransportProducer) -> None: """Register a new transport producer for a given transport label.""" @@ -389,6 +364,48 @@ def create( ) +async def create_client( # noqa: PLR0913 + agent: str | AgentCard, + client_config: ClientConfig | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + relative_card_path: str | None = None, + resolver_http_kwargs: dict[str, Any] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, +) -> Client: + """Create a `Client` for an agent from a URL or `AgentCard`. + + Convenience function that constructs a `ClientFactory` internally. + For reusing a factory across multiple agents or registering custom + transports, use `ClientFactory` directly instead. + + Args: + agent: The base URL of the agent, or an `AgentCard` to use + directly. + client_config: Optional `ClientConfig`. 
A default config is + created if not provided. + interceptors: A list of interceptors to use for each request. + relative_card_path: The relative path when resolving the agent + card. Only used when `agent` is a URL. + resolver_http_kwargs: Dictionary of arguments to provide to the + httpx client when resolving the agent card. + signature_verifier: A callable used to verify the agent card's + signatures. + + Returns: + A `Client` object. + """ + factory = ClientFactory(client_config) + if isinstance(agent, str): + return await factory.create_from_url( + agent, + interceptors=interceptors, + relative_card_path=relative_card_path, + resolver_http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, + ) + return factory.create(agent, interceptors) + + def minimal_agent_card( url: str, transports: list[str] | None = None ) -> AgentCard: diff --git a/src/a2a/client/helpers.py b/src/a2a/client/helpers.py deleted file mode 100644 index fc7bfdbdf..000000000 --- a/src/a2a/client/helpers.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Helper functions for the A2A client.""" - -from typing import Any -from uuid import uuid4 - -from google.protobuf.json_format import ParseDict - -from a2a.types.a2a_pb2 import AgentCard, Message, Part, Role - - -def parse_agent_card(agent_card_data: dict[str, Any]) -> AgentCard: - """Parse AgentCard JSON dictionary and handle backward compatibility.""" - _handle_extended_card_compatibility(agent_card_data) - _handle_connection_fields_compatibility(agent_card_data) - _handle_security_compatibility(agent_card_data) - - return ParseDict(agent_card_data, AgentCard(), ignore_unknown_fields=True) - - -def _handle_extended_card_compatibility( - agent_card_data: dict[str, Any], -) -> None: - """Map legacy supportsAuthenticatedExtendedCard to capabilities.""" - if agent_card_data.pop('supportsAuthenticatedExtendedCard', None): - capabilities = agent_card_data.setdefault('capabilities', {}) - if 'extendedAgentCard' not in capabilities: - 
capabilities['extendedAgentCard'] = True - - -def _handle_connection_fields_compatibility( - agent_card_data: dict[str, Any], -) -> None: - """Map legacy connection and transport fields to supportedInterfaces.""" - main_url = agent_card_data.pop('url', None) - main_transport = agent_card_data.pop('preferredTransport', 'JSONRPC') - version = agent_card_data.pop('protocolVersion', '0.3.0') - additional_interfaces = ( - agent_card_data.pop('additionalInterfaces', None) or [] - ) - - if 'supportedInterfaces' not in agent_card_data and main_url: - supported_interfaces = [] - supported_interfaces.append( - { - 'url': main_url, - 'protocolBinding': main_transport, - 'protocolVersion': version, - } - ) - supported_interfaces.extend( - { - 'url': iface.get('url'), - 'protocolBinding': iface.get('transport'), - 'protocolVersion': version, - } - for iface in additional_interfaces - ) - agent_card_data['supportedInterfaces'] = supported_interfaces - - -def _map_legacy_security( - sec_list: list[dict[str, list[str]]], -) -> list[dict[str, Any]]: - """Convert a legacy security requirement list into the 1.0.0 Protobuf format.""" - return [ - { - 'schemes': { - scheme_name: {'list': scopes} - for scheme_name, scopes in sec_dict.items() - } - } - for sec_dict in sec_list - ] - - -def _handle_security_compatibility(agent_card_data: dict[str, Any]) -> None: - """Map legacy security requirements and schemas to their 1.0.0 Protobuf equivalents.""" - legacy_security = agent_card_data.pop('security', None) - if ( - 'securityRequirements' not in agent_card_data - and legacy_security is not None - ): - agent_card_data['securityRequirements'] = _map_legacy_security( - legacy_security - ) - - for skill in agent_card_data.get('skills', []): - legacy_skill_sec = skill.pop('security', None) - if 'securityRequirements' not in skill and legacy_skill_sec is not None: - skill['securityRequirements'] = _map_legacy_security( - legacy_skill_sec - ) - - security_schemes = 
agent_card_data.get('securitySchemes', {}) - if security_schemes: - type_mapping = { - 'apiKey': 'apiKeySecurityScheme', - 'http': 'httpAuthSecurityScheme', - 'oauth2': 'oauth2SecurityScheme', - 'openIdConnect': 'openIdConnectSecurityScheme', - 'mutualTLS': 'mtlsSecurityScheme', - } - for scheme in security_schemes.values(): - scheme_type = scheme.pop('type', None) - if scheme_type in type_mapping: - # Map legacy 'in' to modern 'location' - if scheme_type == 'apiKey' and 'in' in scheme: - scheme['location'] = scheme.pop('in') - - mapped_name = type_mapping[scheme_type] - new_scheme_wrapper = {mapped_name: scheme.copy()} - scheme.clear() - scheme.update(new_scheme_wrapper) - - -def create_text_message_object( - role: Role = Role.ROLE_USER, content: str = '' -) -> Message: - """Create a Message object containing a single text Part. - - Args: - role: The role of the message sender (user or agent). Defaults to Role.ROLE_USER. - content: The text content of the message. Defaults to an empty string. - - Returns: - A `Message` object with a new UUID message_id. - """ - return Message( - role=role, parts=[Part(text=content)], message_id=str(uuid4()) - ) diff --git a/src/a2a/client/service_parameters.py b/src/a2a/client/service_parameters.py index cef250807..39fe79ce1 100644 --- a/src/a2a/client/service_parameters.py +++ b/src/a2a/client/service_parameters.py @@ -1,7 +1,10 @@ from collections.abc import Callable from typing import TypeAlias -from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, + get_requested_extensions, +) ServiceParameters: TypeAlias = dict[str, str] @@ -44,17 +47,18 @@ def create_from( def with_a2a_extensions(extensions: list[str]) -> ServiceParametersUpdate: - """Create a ServiceParametersUpdate that adds A2A extensions. + """Create a ServiceParametersUpdate that merges A2A extension URIs. - Args: - extensions: List of extension strings. 
- - Returns: - A function that updates ServiceParameters with the extensions header. + Unions the supplied URIs with any already present in the A2A-Extensions + parameter, deduplicating and emitting them in sorted order. Repeated + calls accumulate rather than overwrite. """ def update(parameters: ServiceParameters) -> None: - if extensions: - parameters[HTTP_EXTENSION_HEADER] = ','.join(extensions) + if not extensions: + return + existing = parameters.get(HTTP_EXTENSION_HEADER, '') + merged = sorted(get_requested_extensions([existing, *extensions])) + parameters[HTTP_EXTENSION_HEADER] = ','.join(merged) return update diff --git a/src/a2a/client/transports/http_helpers.py b/src/a2a/client/transports/http_helpers.py index eca386bd4..0a73ed83c 100644 --- a/src/a2a/client/transports/http_helpers.py +++ b/src/a2a/client/transports/http_helpers.py @@ -12,6 +12,10 @@ from a2a.client.errors import A2AClientError, A2AClientTimeoutError +def _default_sse_error_handler(sse_data: str) -> NoReturn: + raise A2AClientError(f'SSE stream error event received: {sse_data}') + + @contextmanager def handle_http_exceptions( status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn] @@ -71,9 +75,22 @@ async def send_http_stream_request( url: str, status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn] | None = None, + sse_error_handler: Callable[[str], NoReturn] = _default_sse_error_handler, **kwargs: Any, ) -> AsyncGenerator[str]: - """Sends a streaming HTTP request, yielding SSE data strings and handling exceptions.""" + """Sends a streaming HTTP request, yielding SSE data strings and handling exceptions. + + Args: + httpx_client: The async HTTP client. + method: The HTTP method (e.g. 'POST', 'GET'). + url: The URL to send the request to. + status_error_handler: Handler for HTTP status errors. Should raise an + appropriate domain-specific exception. + sse_error_handler: Handler for SSE error events. 
Called with the + raw SSE data string when an ``event: error`` SSE event is received. + Should raise an appropriate domain-specific exception. + **kwargs: Additional keyword arguments forwarded to ``aconnect_sse``. + """ with handle_http_exceptions(status_error_handler): async with _SSEEventSource( httpx_client, method, url, **kwargs @@ -97,6 +114,8 @@ async def send_http_stream_request( async for sse in event_source.aiter_sse(): if not sse.data: continue + if sse.event == 'error': + sse_error_handler(sse.data) yield sse.data diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index eca6c4897..252ea439d 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -1,7 +1,7 @@ import logging from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, NoReturn from uuid import uuid4 import httpx @@ -350,6 +350,7 @@ async def _send_stream_request( 'POST', self.url, None, + self._handle_sse_error, json=rpc_request_payload, **http_kwargs, ): @@ -360,3 +361,10 @@ async def _send_stream_request( json_rpc_response.result, StreamResponse() ) yield response + + def _handle_sse_error(self, sse_data: str) -> NoReturn: + """Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error.""" + json_rpc_response = JSONRPC20Response.from_json(sse_data) + if json_rpc_response.error: + raise self._create_jsonrpc_error(json_rpc_response.error) + raise A2AClientError(f'SSE stream error: {sse_data}') diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index ed40d31c7..3dfe95927 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -41,6 +41,47 @@ logger = logging.getLogger(__name__) +def _parse_rest_error( + error_payload: dict[str, Any], + fallback_message: str, +) -> Exception | None: + """Parses a REST error payload and returns the appropriate A2AError. 
+ + Args: + error_payload: The parsed JSON error payload. + fallback_message: Message to use if the payload has no ``message``. + + Returns: + The mapped A2AError if a known reason was found, otherwise ``None``. + """ + error_data = error_payload.get('error', {}) + message = error_data.get('message', fallback_message) + details = error_data.get('details', []) + if not isinstance(details, list): + return None + + # The `details` array can contain multiple different error objects. + # We extract the first `ErrorInfo` object because it contains the + # specific `reason` code needed to map this back to a Python A2AError. + for d in details: + if ( + isinstance(d, dict) + and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo' + ): + reason = d.get('reason') + metadata = d.get('metadata') or {} + if isinstance(reason, str): + exception_cls = A2A_REASON_TO_ERROR.get(reason) + if exception_cls: + exc = exception_cls(message) + if metadata: + exc.data = metadata + return exc + break + + return None + + @trace_class(kind=SpanKind.CLIENT) class RestTransport(ClientTransport): """A REST transport for the A2A client.""" @@ -294,39 +335,12 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: """Handles HTTP status errors and raises the appropriate A2AError.""" try: error_payload = e.response.json() - error_data = error_payload.get('error', {}) - - message = error_data.get('message', str(e)) - details = error_data.get('details', []) - if not isinstance(details, list): - details = [] - - # The `details` array can contain multiple different error objects. - # We extract the first `ErrorInfo` object because it contains the - # specific `reason` code needed to map this back to a Python A2AError. 
- error_info = {} - for d in details: - if ( - isinstance(d, dict) - and d.get('@type') - == 'type.googleapis.com/google.rpc.ErrorInfo' - ): - error_info = d - break - reason = error_info.get('reason') - metadata = error_info.get('metadata') or {} - - if isinstance(reason, str): - exception_cls = A2A_REASON_TO_ERROR.get(reason) - if exception_cls: - exc = exception_cls(message) - if metadata: - exc.data = metadata - raise exc from e + mapped = _parse_rest_error(error_payload, str(e)) + if mapped: + raise mapped from e except (json.JSONDecodeError, ValueError): pass - # Fallback mappings for status codes if 'type' is missing or unknown status_code = e.response.status_code if status_code == httpx.codes.NOT_FOUND: raise MethodNotFoundError( @@ -335,6 +349,14 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: raise A2AClientError(f'HTTP Error {status_code}: {e}') from e + def _handle_sse_error(self, sse_data: str) -> NoReturn: + """Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError.""" + error_payload = json.loads(sse_data) + mapped = _parse_rest_error(error_payload, sse_data) + if mapped: + raise mapped + raise A2AClientError(sse_data) + async def _send_stream_request( self, method: str, @@ -352,6 +374,7 @@ async def _send_stream_request( method, f'{self.url}{path}', self._handle_http_error, + self._handle_sse_error, json=json, **http_kwargs, ): diff --git a/src/a2a/compat/v0_3/context_builders.py b/src/a2a/compat/v0_3/context_builders.py new file mode 100644 index 000000000..1874853f5 --- /dev/null +++ b/src/a2a/compat/v0_3/context_builders.py @@ -0,0 +1,75 @@ +"""Context builders that add v0.3 backwards-compatibility for extensions. + +The current spec uses ``A2A-Extensions`` (RFC 6648, no ``X-`` prefix). v0.3 +clients still send the old ``X-A2A-Extensions`` name, so the v0.3 compat +adapters wrap the default builders with these classes to recognize both names. 
+""" + +from typing import TYPE_CHECKING + +from a2a.compat.v0_3.extension_headers import LEGACY_HTTP_EXTENSION_HEADER +from a2a.extensions.common import get_requested_extensions +from a2a.server.context import ServerCallContext + + +if TYPE_CHECKING: + import grpc + + from starlette.requests import Request + + from a2a.server.request_handlers.grpc_handler import ( + GrpcServerCallContextBuilder, + ) + from a2a.server.routes.common import ServerCallContextBuilder + + +def _get_legacy_grpc_extensions( + context: 'grpc.aio.ServicerContext', +) -> list[str]: + md = context.invocation_metadata() + if md is None: + return [] + lower_key = LEGACY_HTTP_EXTENSION_HEADER.lower() + return [ + e if isinstance(e, str) else e.decode('utf-8') + for k, e in md + if k.lower() == lower_key + ] + + +class V03ServerCallContextBuilder: + """Wraps a ServerCallContextBuilder to also accept the legacy header. + + Recognizes the v0.3 ``X-A2A-Extensions`` HTTP header in addition to the + spec ``A2A-Extensions``. + """ + + def __init__(self, inner: 'ServerCallContextBuilder') -> None: + self._inner = inner + + def build(self, request: 'Request') -> ServerCallContext: + """Builds a ServerCallContext, merging legacy extension headers.""" + context = self._inner.build(request) + context.requested_extensions |= get_requested_extensions( + request.headers.getlist(LEGACY_HTTP_EXTENSION_HEADER) + ) + return context + + +class V03GrpcServerCallContextBuilder: + """Wraps a GrpcServerCallContextBuilder to also accept the legacy metadata. + + Recognizes the v0.3 ``X-A2A-Extensions`` gRPC metadata key in addition to + the spec ``A2A-Extensions``. 
+ """ + + def __init__(self, inner: 'GrpcServerCallContextBuilder') -> None: + self._inner = inner + + def build(self, context: 'grpc.aio.ServicerContext') -> ServerCallContext: + """Builds a ServerCallContext, merging legacy extension metadata.""" + server_context = self._inner.build(context) + server_context.requested_extensions |= get_requested_extensions( + _get_legacy_grpc_extensions(context) + ) + return server_context diff --git a/src/a2a/compat/v0_3/extension_headers.py b/src/a2a/compat/v0_3/extension_headers.py new file mode 100644 index 000000000..e1421a0b0 --- /dev/null +++ b/src/a2a/compat/v0_3/extension_headers.py @@ -0,0 +1,27 @@ +"""Shared header name constants for v0.3 extension compatibility. + +The current spec uses ``A2A-Extensions``. v0.3 used the ``X-`` prefixed +``X-A2A-Extensions`` form. v0.3 compat servers and clients accept/emit both +names so they can interoperate with peers that only know the legacy one. +""" + +from a2a.client.service_parameters import ServiceParameters +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +LEGACY_HTTP_EXTENSION_HEADER = f'X-{HTTP_EXTENSION_HEADER}' + + +def add_legacy_extension_header(parameters: ServiceParameters) -> None: + """Mirrors the ``A2A-Extensions`` parameter under its legacy name in-place. + + Used by v0.3 compat client transports so that requests can be understood + by older v0.3 servers that only recognize ``X-A2A-Extensions``. 
+ """ + if ( + HTTP_EXTENSION_HEADER in parameters + and LEGACY_HTTP_EXTENSION_HEADER not in parameters + ): + parameters[LEGACY_HTTP_EXTENSION_HEADER] = parameters[ + HTTP_EXTENSION_HEADER + ] diff --git a/src/a2a/compat/v0_3/grpc_handler.py b/src/a2a/compat/v0_3/grpc_handler.py index c9db99557..b7bec26ea 100644 --- a/src/a2a/compat/v0_3/grpc_handler.py +++ b/src/a2a/compat/v0_3/grpc_handler.py @@ -12,14 +12,13 @@ from a2a.compat.v0_3 import ( a2a_v0_3_pb2, a2a_v0_3_pb2_grpc, - conversions, proto_utils, ) from a2a.compat.v0_3 import ( types as types_v03, ) +from a2a.compat.v0_3.context_builders import V03GrpcServerCallContextBuilder from a2a.compat.v0_3.request_handler import RequestHandler03 -from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.context import ServerCallContext from a2a.server.request_handlers.grpc_handler import ( _ERROR_CODE_MAP, @@ -27,9 +26,7 @@ GrpcServerCallContextBuilder, ) from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard from a2a.utils.errors import A2AError, InvalidParamsError -from a2a.utils.helpers import maybe_await, validate logger = logging.getLogger(__name__) @@ -42,29 +39,21 @@ class CompatGrpcHandler(a2a_v0_3_pb2_grpc.A2AServiceServicer): def __init__( self, - agent_card: AgentCard, request_handler: RequestHandler, context_builder: GrpcServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, ): """Initializes the CompatGrpcHandler. Args: - agent_card: The AgentCard describing the agent's capabilities (v1.0). request_handler: The underlying `RequestHandler` instance to delegate requests to. context_builder: The CallContextBuilder object. If none the DefaultCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. 
""" - self.agent_card = agent_card self.handler03 = RequestHandler03(request_handler=request_handler) - self._context_builder = ( + self._context_builder = V03GrpcServerCallContextBuilder( context_builder or DefaultGrpcServerCallContextBuilder() ) - self.card_modifier = card_modifier async def _handle_unary( self, @@ -76,7 +65,6 @@ async def _handle_unary( try: server_context = self._context_builder.build(context) result = await handler_func(server_context) - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) else: @@ -93,7 +81,6 @@ async def _handle_stream( server_context = self._context_builder.build(context) async for item in handler_func(server_context): yield item - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) @@ -131,19 +118,6 @@ async def abort_context( f'Unknown error type: {error}', ) - def _set_extension_metadata( - self, - context: grpc.aio.ServicerContext, - server_context: ServerCallContext, - ) -> None: - if server_context.activated_extensions: - context.set_trailing_metadata( - [ - (HTTP_EXTENSION_HEADER.lower(), e) - for e in sorted(server_context.activated_extensions) - ] - ) - async def SendMessage( self, request: a2a_v0_3_pb2.SendMessageRequest, @@ -179,10 +153,6 @@ async def SendStreamingMessage( ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: """Handles the 'SendStreamingMessage' gRPC method (v0.3).""" - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: @@ -242,10 +212,6 @@ async def TaskSubscription( ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: """Handles the 'TaskSubscription' gRPC method (v0.3).""" - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( server_context: 
ServerCallContext, ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: @@ -269,10 +235,6 @@ async def CreateTaskPushNotificationConfig( ) -> a2a_v0_3_pb2.TaskPushNotificationConfig: """Handles the 'CreateTaskPushNotificationConfig' gRPC method (v0.3).""" - @validate( - lambda _: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> a2a_v0_3_pb2.TaskPushNotificationConfig: @@ -360,12 +322,19 @@ async def GetAgentCard( request: a2a_v0_3_pb2.GetAgentCardRequest, context: grpc.aio.ServicerContext, ) -> a2a_v0_3_pb2.AgentCard: - """Get the agent card for the agent served (v0.3).""" - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - return proto_utils.ToProto.agent_card( - conversions.to_compat_agent_card(card_to_serve) + """Get the extended agent card for the agent served (v0.3).""" + + async def _handler( + server_context: ServerCallContext, + ) -> a2a_v0_3_pb2.AgentCard: + req_v03 = types_v03.GetAuthenticatedExtendedCardRequest(id=0) + res_v03 = await self.handler03.on_get_extended_agent_card( + req_v03, server_context + ) + return proto_utils.ToProto.agent_card(res_v03) + + return await self._handle_unary( + context, _handler, a2a_v0_3_pb2.AgentCard() ) async def DeleteTaskPushNotificationConfig( diff --git a/src/a2a/compat/v0_3/grpc_transport.py b/src/a2a/compat/v0_3/grpc_transport.py index 32ce7f27b..95314e3f1 100644 --- a/src/a2a/compat/v0_3/grpc_transport.py +++ b/src/a2a/compat/v0_3/grpc_transport.py @@ -30,6 +30,7 @@ from a2a.compat.v0_3 import ( types as types_v03, ) +from a2a.compat.v0_3.extension_headers import add_legacy_extension_header from a2a.types import a2a_pb2 from a2a.utils.constants import PROTOCOL_VERSION_0_3, VERSION_HEADER from a2a.utils.telemetry import SpanKind, trace_class @@ -361,7 +362,9 @@ def _get_grpc_metadata( metadata = [(VERSION_HEADER.lower(), 
PROTOCOL_VERSION_0_3)] if context and context.service_parameters: - for key, value in context.service_parameters.items(): + params = dict(context.service_parameters) + add_legacy_extension_header(params) + for key, value in params.items(): metadata.append((key.lower(), value)) return metadata diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index d01a7e11c..580034e9b 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -1,6 +1,6 @@ import logging -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterable, AsyncIterator from typing import TYPE_CHECKING, Any from sse_starlette.sse import EventSourceResponse @@ -11,7 +11,6 @@ from starlette.requests import Request from a2a.server.request_handlers.request_handler import RequestHandler - from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True else: @@ -24,8 +23,8 @@ _package_starlette_installed = False -from a2a.compat.v0_3 import conversions from a2a.compat.v0_3 import types as types_v03 +from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder from a2a.compat.v0_3.request_handler import RequestHandler03 from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( @@ -42,8 +41,7 @@ ServerCallContextBuilder, ) from a2a.utils import constants -from a2a.utils.errors import ExtendedAgentCardNotConfiguredError -from a2a.utils.helpers import maybe_await, validate_version +from a2a.utils.version_validator import validate_version logger = logging.getLogger(__name__) @@ -65,23 +63,15 @@ class JSONRPC03Adapter: 'agent/getAuthenticatedExtendedCard': types_v03.GetAuthenticatedExtendedCardRequest, } - def __init__( # noqa: PLR0913 + def __init__( self, - agent_card: 'AgentCard', http_handler: 'RequestHandler', - extended_agent_card: 'AgentCard | None' = None, context_builder: 'ServerCallContextBuilder | None' = None, 
- card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None, - extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None, ): - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier self.handler = RequestHandler03( request_handler=http_handler, ) - self._context_builder = ( + self._context_builder = V03ServerCallContextBuilder( context_builder or DefaultServerCallContextBuilder() ) @@ -227,7 +217,7 @@ async def _process_non_streaming_request( ) ) elif method == 'agent/getAuthenticatedExtendedCard': - res_card = await self.get_authenticated_extended_card( + res_card = await self.handler.on_get_extended_agent_card( request_obj, context ) result = types_v03.GetAuthenticatedExtendedCardResponse( @@ -244,31 +234,6 @@ async def _process_non_streaming_request( ) ) - async def get_authenticated_extended_card( - self, - request: types_v03.GetAuthenticatedExtendedCardRequest, - context: ServerCallContext, - ) -> types_v03.AgentCard: - """Handles the 'agent/getAuthenticatedExtendedCard' JSON-RPC method.""" - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - - base_card = self.extended_agent_card - if base_card is None: - base_card = self.agent_card - - card_to_serve = base_card - if self.extended_card_modifier and context: - card_to_serve = await maybe_await( - self.extended_card_modifier(base_card, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(base_card)) - - return conversions.to_compat_agent_card(card_to_serve) - @validate_version(constants.PROTOCOL_VERSION_0_3) async def _process_streaming_request( self, diff --git a/src/a2a/compat/v0_3/jsonrpc_transport.py b/src/a2a/compat/v0_3/jsonrpc_transport.py index 557a63a16..caccd2811 
100644 --- a/src/a2a/compat/v0_3/jsonrpc_transport.py +++ b/src/a2a/compat/v0_3/jsonrpc_transport.py @@ -19,6 +19,7 @@ ) from a2a.compat.v0_3 import conversions from a2a.compat.v0_3 import types as types_v03 +from a2a.compat.v0_3.extension_headers import add_legacy_extension_header from a2a.types.a2a_pb2 import ( AgentCard, CancelTaskRequest, @@ -424,6 +425,7 @@ async def _send_stream_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) async for sse_data in send_http_stream_request( self.httpx_client, @@ -485,6 +487,7 @@ async def _send_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) request = self.httpx_client.build_request( 'POST', diff --git a/src/a2a/compat/v0_3/request_handler.py b/src/a2a/compat/v0_3/request_handler.py index 6ec675312..d79a5cc5d 100644 --- a/src/a2a/compat/v0_3/request_handler.py +++ b/src/a2a/compat/v0_3/request_handler.py @@ -9,9 +9,7 @@ from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import Task from a2a.utils import proto_utils as core_proto_utils -from a2a.utils.errors import ( - TaskNotFoundError, -) +from a2a.utils.errors import TaskNotFoundError logger = logging.getLogger(__name__) @@ -170,3 +168,15 @@ async def on_delete_task_push_notification_config( await self.request_handler.on_delete_task_push_notification_config( v10_req, context ) + + async def on_get_extended_agent_card( + self, + request: types_v03.GetAuthenticatedExtendedCardRequest, + context: ServerCallContext, + ) -> types_v03.AgentCard: + """Gets the authenticated extended agent card using v0.3 protocol types.""" + v10_req = conversions.to_core_get_extended_agent_card_request(request) + v10_card = await 
self.request_handler.on_get_extended_agent_card( + v10_req, context + ) + return conversions.to_compat_agent_card(v10_card) diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index 27aba2aad..38687054f 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -11,8 +11,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response + from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler - from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True else: @@ -31,9 +31,8 @@ _package_starlette_installed = False -from a2a.compat.v0_3 import conversions +from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder from a2a.compat.v0_3.rest_handler import REST03Handler -from a2a.server.context import ServerCallContext from a2a.server.routes.common import ( DefaultServerCallContextBuilder, ServerCallContextBuilder, @@ -43,10 +42,8 @@ rest_stream_error_handler, ) from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, InvalidRequestError, ) -from a2a.utils.helpers import maybe_await logger = logging.getLogger(__name__) @@ -58,23 +55,13 @@ class REST03Adapter: Defines v0.3 REST request processors and their routes, as well as managing response generation including Server-Sent Events (SSE). 
""" - def __init__( # noqa: PLR0913 + def __init__( self, - agent_card: 'AgentCard', http_handler: 'RequestHandler', - extended_agent_card: 'AgentCard | None' = None, context_builder: 'ServerCallContextBuilder | None' = None, - card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None, - extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None, ): - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier - self.handler = REST03Handler( - agent_card=agent_card, request_handler=http_handler - ) - self._context_builder = ( + self.handler = REST03Handler(request_handler=http_handler) + self._context_builder = V03ServerCallContextBuilder( context_builder or DefaultServerCallContextBuilder() ) @@ -113,39 +100,6 @@ async def event_generator( event_generator(method(request, call_context)) ) - async def handle_get_agent_card( - self, request: Request, call_context: ServerCallContext - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint.""" - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - v03_card = conversions.to_compat_agent_card(card_to_serve) - return v03_card.model_dump(mode='json', exclude_none=True) - - async def handle_authenticated_agent_card( - self, request: Request, call_context: ServerCallContext - ) -> dict[str, Any]: - """Hook for per credential agent card response.""" - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card - - if not card_to_serve: - card_to_serve = self.agent_card - - if self.extended_card_modifier: - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, call_context) - ) - 
elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - v03_card = conversions.to_compat_agent_card(card_to_serve) - return v03_card.model_dump(mode='json', exclude_none=True) - def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: """Constructs a dictionary of API routes and their corresponding handlers.""" routes: dict[tuple[str, str], Callable[[Request], Any]] = { @@ -191,10 +145,9 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: ('/v1/tasks', 'GET'): functools.partial( self._handle_request, self.handler.list_tasks ), + ('/v1/card', 'GET'): functools.partial( + self._handle_request, self.handler.on_get_extended_agent_card + ), } - if self.agent_card.capabilities.extended_agent_card: - routes[('/v1/card', 'GET')] = functools.partial( - self._handle_request, self.handle_authenticated_agent_card - ) return routes diff --git a/src/a2a/compat/v0_3/rest_handler.py b/src/a2a/compat/v0_3/rest_handler.py index 470f94b3e..bd5fcd2e6 100644 --- a/src/a2a/compat/v0_3/rest_handler.py +++ b/src/a2a/compat/v0_3/rest_handler.py @@ -10,7 +10,6 @@ from starlette.requests import Request from a2a.server.request_handlers.request_handler import RequestHandler - from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True else: @@ -29,11 +28,8 @@ from a2a.compat.v0_3.request_handler import RequestHandler03 from a2a.server.context import ServerCallContext from a2a.utils import constants -from a2a.utils.helpers import ( - validate, - validate_version, -) from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils.version_validator import validate_version logger = logging.getLogger(__name__) @@ -45,16 +41,13 @@ class REST03Handler: def __init__( self, - agent_card: 'AgentCard', request_handler: 'RequestHandler', ): """Initializes the REST03Handler. Args: - agent_card: The AgentCard describing the agent's capabilities (v1.0). 
request_handler: The underlying `RequestHandler` instance to delegate requests to (v1.0). """ - self.agent_card = agent_card self.handler03 = RequestHandler03(request_handler=request_handler) @validate_version(constants.PROTOCOL_VERSION_0_3) @@ -84,10 +77,6 @@ async def on_message_send( return MessageToDict(pb2_v03_resp) @validate_version(constants.PROTOCOL_VERSION_0_3) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def on_message_send_stream( self, request: Request, @@ -142,10 +131,6 @@ async def on_cancel_task( return MessageToDict(pb2_v03_task) @validate_version(constants.PROTOCOL_VERSION_0_3) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def on_subscribe_to_task( self, request: Request, @@ -208,10 +193,6 @@ async def get_push_notification( return MessageToDict(pb2_v03_config) @validate_version(constants.PROTOCOL_VERSION_0_3) - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def set_push_notification( self, request: Request, @@ -317,3 +298,16 @@ async def list_tasks( ) -> dict[str, Any]: """Handles the 'tasks/list' REST method.""" raise NotImplementedError('list tasks not implemented') + + @validate_version(constants.PROTOCOL_VERSION_0_3) + async def on_get_extended_agent_card( + self, + request: Request, + context: ServerCallContext, + ) -> dict[str, Any]: + """Handles the 'v1/agent/authenticatedExtendedAgentCard' REST method.""" + rpc_req = types_v03.GetAuthenticatedExtendedCardRequest(id=0) + v03_resp = await self.handler03.on_get_extended_agent_card( + rpc_req, context + ) + return v03_resp.model_dump(mode='json', exclude_none=True) diff --git a/src/a2a/compat/v0_3/rest_transport.py b/src/a2a/compat/v0_3/rest_transport.py index 0ba38538d..bcaed2949 100644 --- a/src/a2a/compat/v0_3/rest_transport.py +++ 
b/src/a2a/compat/v0_3/rest_transport.py @@ -25,6 +25,7 @@ from a2a.compat.v0_3 import ( types as types_v03, ) +from a2a.compat.v0_3.extension_headers import add_legacy_extension_header from a2a.types.a2a_pb2 import ( AgentCard, CancelTaskRequest, @@ -380,6 +381,7 @@ async def _send_stream_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) async for sse_data in send_http_stream_request( self.httpx_client, @@ -414,6 +416,7 @@ async def _execute_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) request = self.httpx_client.build_request( method, diff --git a/src/a2a/contrib/tasks/__init__.py b/src/a2a/contrib/tasks/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/a2a/contrib/tasks/vertex_task_converter.py b/src/a2a/contrib/tasks/vertex_task_converter.py deleted file mode 100644 index 6f23dad2e..000000000 --- a/src/a2a/contrib/tasks/vertex_task_converter.py +++ /dev/null @@ -1,159 +0,0 @@ -try: - from google.genai import types as genai_types - from vertexai import types as vertexai_types -except ImportError as e: - raise ImportError( - 'vertex_task_converter requires vertexai. 
' - 'Install with: ' - "'pip install a2a-sdk[vertex]'" - ) from e - -import base64 -import json - -from a2a.compat.v0_3.types import ( - Artifact, - DataPart, - FilePart, - FileWithBytes, - FileWithUri, - Part, - Task, - TaskState, - TaskStatus, - TextPart, -) - - -_TO_SDK_TASK_STATE = { - vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown, - vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted, - vertexai_types.A2aTaskState.WORKING: TaskState.working, - vertexai_types.A2aTaskState.COMPLETED: TaskState.completed, - vertexai_types.A2aTaskState.CANCELLED: TaskState.canceled, - vertexai_types.A2aTaskState.FAILED: TaskState.failed, - vertexai_types.A2aTaskState.REJECTED: TaskState.rejected, - vertexai_types.A2aTaskState.INPUT_REQUIRED: TaskState.input_required, - vertexai_types.A2aTaskState.AUTH_REQUIRED: TaskState.auth_required, -} - -_SDK_TO_STORED_TASK_STATE = {v: k for k, v in _TO_SDK_TASK_STATE.items()} - - -def to_sdk_task_state(stored_state: vertexai_types.A2aTaskState) -> TaskState: - """Converts a proto A2aTask.State to a TaskState enum.""" - return _TO_SDK_TASK_STATE.get(stored_state, TaskState.unknown) - - -def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState: - """Converts a TaskState enum to a proto A2aTask.State enum value.""" - return _SDK_TO_STORED_TASK_STATE.get( - task_state, vertexai_types.A2aTaskState.STATE_UNSPECIFIED - ) - - -def to_stored_part(part: Part) -> genai_types.Part: - """Converts a SDK Part to a proto Part.""" - if isinstance(part.root, TextPart): - return genai_types.Part(text=part.root.text) - if isinstance(part.root, DataPart): - data_bytes = json.dumps(part.root.data).encode('utf-8') - return genai_types.Part( - inline_data=genai_types.Blob( - mime_type='application/json', data=data_bytes - ) - ) - if isinstance(part.root, FilePart): - file_content = part.root.file - if isinstance(file_content, FileWithBytes): - decoded_bytes = base64.b64decode(file_content.bytes) - return 
genai_types.Part( - inline_data=genai_types.Blob( - mime_type=file_content.mime_type or '', data=decoded_bytes - ) - ) - if isinstance(file_content, FileWithUri): - return genai_types.Part( - file_data=genai_types.FileData( - mime_type=file_content.mime_type or '', - file_uri=file_content.uri, - ) - ) - raise ValueError(f'Unsupported part type: {type(part.root)}') - - -def to_sdk_part(stored_part: genai_types.Part) -> Part: - """Converts a proto Part to a SDK Part.""" - if stored_part.text: - return Part(root=TextPart(text=stored_part.text)) - if stored_part.inline_data: - encoded_bytes = base64.b64encode( - stored_part.inline_data.data or b'' - ).decode('utf-8') - return Part( - root=FilePart( - file=FileWithBytes( - mime_type=stored_part.inline_data.mime_type, - bytes=encoded_bytes, - ) - ) - ) - if stored_part.file_data and stored_part.file_data.file_uri: - return Part( - root=FilePart( - file=FileWithUri( - mime_type=stored_part.file_data.mime_type, - uri=stored_part.file_data.file_uri, - ) - ) - ) - - raise ValueError(f'Unsupported part: {stored_part}') - - -def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact: - """Converts a SDK Artifact to a proto TaskArtifact.""" - return vertexai_types.TaskArtifact( - artifact_id=artifact.artifact_id, - parts=[to_stored_part(part) for part in artifact.parts], - ) - - -def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact: - """Converts a proto TaskArtifact to a SDK Artifact.""" - return Artifact( - artifact_id=stored_artifact.artifact_id, - parts=[to_sdk_part(part) for part in stored_artifact.parts], - ) - - -def to_stored_task(task: Task) -> vertexai_types.A2aTask: - """Converts a SDK Task to a proto A2aTask.""" - return vertexai_types.A2aTask( - context_id=task.context_id, - metadata=task.metadata, - state=to_stored_task_state(task.status.state), - output=vertexai_types.TaskOutput( - artifacts=[ - to_stored_artifact(artifact) - for artifact in task.artifacts or [] - ] - ), - 
) - - -def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task: - """Converts a proto A2aTask to a SDK Task.""" - return Task( - id=a2a_task.name.split('/')[-1], - context_id=a2a_task.context_id, - status=TaskStatus(state=to_sdk_task_state(a2a_task.state)), - metadata=a2a_task.metadata or {}, - artifacts=[ - to_sdk_artifact(artifact) - for artifact in a2a_task.output.artifacts or [] - ] - if a2a_task.output - else [], - history=[], - ) diff --git a/src/a2a/contrib/tasks/vertex_task_store.py b/src/a2a/contrib/tasks/vertex_task_store.py deleted file mode 100644 index ccd9fffba..000000000 --- a/src/a2a/contrib/tasks/vertex_task_store.py +++ /dev/null @@ -1,225 +0,0 @@ -import logging - - -try: - import vertexai - - from google.genai import errors as genai_errors - from vertexai import types as vertexai_types -except ImportError as e: - raise ImportError( - 'VertexTaskStore requires vertexai. ' - 'Install with: ' - "'pip install a2a-sdk[vertex]'" - ) from e - -from a2a.compat.v0_3.conversions import to_compat_task, to_core_task -from a2a.compat.v0_3.types import Task as CompatTask -from a2a.contrib.tasks import vertex_task_converter -from a2a.server.context import ServerCallContext -from a2a.server.tasks.task_store import TaskStore -from a2a.types.a2a_pb2 import ListTasksRequest, ListTasksResponse, Task - - -logger = logging.getLogger(__name__) - - -class VertexTaskStore(TaskStore): - """Implementation of TaskStore using Vertex AI Agent Engine Task Store. - - Stores task objects in Vertex AI Agent Engine Task Store. - """ - - def __init__( - self, - client: vertexai.Client, # type: ignore - agent_engine_resource_id: str, - ) -> None: - """Initializes the VertexTaskStore. - - Args: - client: The Vertex AI client. - agent_engine_resource_id: The resource ID of the agent engine. 
- """ - self._client = client - self._agent_engine_resource_id = agent_engine_resource_id - - async def save(self, task: Task, context: ServerCallContext) -> None: - """Saves or updates a task in the store.""" - compat_task = to_compat_task(task) - previous_task = await self._get_stored_task(compat_task.id) - if previous_task is None: - await self._create(compat_task) - else: - await self._update(previous_task, compat_task) - - async def _create(self, sdk_task: CompatTask) -> None: - stored_task = vertex_task_converter.to_stored_task(sdk_task) - await self._client.aio.agent_engines.a2a_tasks.create( - name=self._agent_engine_resource_id, - a2a_task_id=sdk_task.id, - config=vertexai_types.CreateAgentEngineTaskConfig( - context_id=stored_task.context_id, - metadata=stored_task.metadata, - output=stored_task.output, - ), - ) - - def _get_status_change_event( - self, - previous_task: CompatTask, - task: CompatTask, - event_sequence_number: int, - ) -> vertexai_types.TaskEvent | None: - if task.status.state != previous_task.status.state: - return vertexai_types.TaskEvent( - event_data=vertexai_types.TaskEventData( - state_change=vertexai_types.TaskStateChange( - new_state=vertex_task_converter.to_stored_task_state( - task.status.state - ), - ), - ), - event_sequence_number=event_sequence_number, - ) - return None - - def _get_metadata_change_event( - self, - previous_task: CompatTask, - task: CompatTask, - event_sequence_number: int, - ) -> vertexai_types.TaskEvent | None: - if task.metadata != previous_task.metadata: - return vertexai_types.TaskEvent( - event_data=vertexai_types.TaskEventData( - metadata_change=vertexai_types.TaskMetadataChange( - new_metadata=task.metadata, - ) - ), - event_sequence_number=event_sequence_number, - ) - return None - - def _get_artifacts_change_event( - self, - previous_task: CompatTask, - task: CompatTask, - event_sequence_number: int, - ) -> vertexai_types.TaskEvent | None: - if task.artifacts != previous_task.artifacts: - 
task_artifact_change = vertexai_types.TaskArtifactChange() - event = vertexai_types.TaskEvent( - event_data=vertexai_types.TaskEventData( - output_change=vertexai_types.TaskOutputChange( - task_artifact_change=task_artifact_change - ) - ), - event_sequence_number=event_sequence_number, - ) - task_artifacts = ( - {artifact.artifact_id: artifact for artifact in task.artifacts} - if task.artifacts - else {} - ) - previous_task_artifacts = ( - { - artifact.artifact_id: artifact - for artifact in previous_task.artifacts - } - if previous_task.artifacts - else {} - ) - for artifact in previous_task_artifacts.values(): - if artifact.artifact_id not in task_artifacts: - if not task_artifact_change.deleted_artifact_ids: - task_artifact_change.deleted_artifact_ids = [] - task_artifact_change.deleted_artifact_ids.append( - artifact.artifact_id - ) - for artifact in task_artifacts.values(): - if artifact.artifact_id not in previous_task_artifacts: - if not task_artifact_change.added_artifacts: - task_artifact_change.added_artifacts = [] - task_artifact_change.added_artifacts.append( - vertex_task_converter.to_stored_artifact(artifact) - ) - elif artifact != previous_task_artifacts[artifact.artifact_id]: - if not task_artifact_change.updated_artifacts: - task_artifact_change.updated_artifacts = [] - task_artifact_change.updated_artifacts.append( - vertex_task_converter.to_stored_artifact(artifact) - ) - if task_artifact_change != vertexai_types.TaskArtifactChange(): - return event - return None - - async def _update( - self, previous_stored_task: vertexai_types.A2aTask, task: CompatTask - ) -> None: - previous_task = vertex_task_converter.to_sdk_task(previous_stored_task) - events = [] - event_sequence_number = previous_stored_task.next_event_sequence_number - - status_event = self._get_status_change_event( - previous_task, task, event_sequence_number - ) - if status_event: - events.append(status_event) - event_sequence_number += 1 - - metadata_event = 
self._get_metadata_change_event( - previous_task, task, event_sequence_number - ) - if metadata_event: - events.append(metadata_event) - event_sequence_number += 1 - - artifacts_event = self._get_artifacts_change_event( - previous_task, task, event_sequence_number - ) - if artifacts_event: - events.append(artifacts_event) - event_sequence_number += 1 - - if not events: - return - await self._client.aio.agent_engines.a2a_tasks.events.append( - name=self._agent_engine_resource_id + '/a2aTasks/' + task.id, - task_events=events, - ) - - async def _get_stored_task( - self, task_id: str - ) -> vertexai_types.A2aTask | None: - try: - a2a_task = await self._client.aio.agent_engines.a2a_tasks.get( - name=self._agent_engine_resource_id + '/a2aTasks/' + task_id, - ) - except genai_errors.APIError as e: - if e.status == 'NOT_FOUND': - logger.debug('Task %s not found in store.', task_id) - return None - raise - return a2a_task - - async def get( - self, task_id: str, context: ServerCallContext - ) -> Task | None: - """Retrieves a task from the database by ID.""" - a2a_task = await self._get_stored_task(task_id) - if a2a_task is None: - return None - return to_core_task(vertex_task_converter.to_sdk_task(a2a_task)) - - async def list( - self, - params: ListTasksRequest, - context: ServerCallContext, - ) -> ListTasksResponse: - """Retrieves a list of tasks from the store.""" - raise NotImplementedError - - async def delete(self, task_id: str, context: ServerCallContext) -> None: - """The backend doesn't support deleting tasks, so this is not implemented.""" - raise NotImplementedError diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 0595216ed..06ccf8f40 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -1,7 +1,7 @@ from a2a.types.a2a_pb2 import AgentCard, AgentExtension -HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' +HTTP_EXTENSION_HEADER = 'A2A-Extensions' def get_requested_extensions(values: list[str]) -> set[str]: diff 
--git a/src/a2a/helpers/__init__.py b/src/a2a/helpers/__init__.py new file mode 100644 index 000000000..c42429d43 --- /dev/null +++ b/src/a2a/helpers/__init__.py @@ -0,0 +1,34 @@ +"""Helper functions for the A2A Python SDK.""" + +from a2a.helpers.agent_card import display_agent_card +from a2a.helpers.proto_helpers import ( + get_artifact_text, + get_message_text, + get_stream_response_text, + get_text_parts, + new_artifact, + new_message, + new_task, + new_task_from_user_message, + new_text_artifact, + new_text_artifact_update_event, + new_text_message, + new_text_status_update_event, +) + + +__all__ = [ + 'display_agent_card', + 'get_artifact_text', + 'get_message_text', + 'get_stream_response_text', + 'get_text_parts', + 'new_artifact', + 'new_message', + 'new_task', + 'new_task_from_user_message', + 'new_text_artifact', + 'new_text_artifact_update_event', + 'new_text_message', + 'new_text_status_update_event', +] diff --git a/src/a2a/helpers/agent_card.py b/src/a2a/helpers/agent_card.py new file mode 100644 index 000000000..0962e67fb --- /dev/null +++ b/src/a2a/helpers/agent_card.py @@ -0,0 +1,76 @@ +"""Utility functions for inspecting AgentCard instances.""" + +from a2a.types.a2a_pb2 import AgentCard + + +def display_agent_card(card: AgentCard) -> None: + """Print a human-readable summary of an AgentCard to stdout. + + Args: + card: The AgentCard proto message to display. 
+ """ + width = 52 + sep = '=' * width + thin = '-' * width + + lines: list[str] = [sep, 'AgentCard'.center(width), sep] + + lines += [ + '--- General ---', + f'Name : {card.name}', + f'Description : {card.description}', + f'Version : {card.version}', + ] + if card.documentation_url: + lines.append(f'Docs URL : {card.documentation_url}') + if card.icon_url: + lines.append(f'Icon URL : {card.icon_url}') + if card.HasField('provider'): + url_suffix = f' ({card.provider.url})' if card.provider.url else '' + lines.append(f'Provider : {card.provider.organization}{url_suffix}') + + lines += ['', '--- Interfaces ---'] + for i, iface in enumerate(card.supported_interfaces): + binding = f'{iface.protocol_binding} {iface.protocol_version}'.strip() + parts = [ + p + for p in [binding, f'tenant={iface.tenant}' if iface.tenant else ''] + if p + ] + suffix = f' ({", ".join(parts)})' if parts else '' + line = f' [{i}] {iface.url}{suffix}' + lines.append(line) + + lines += [ + '', + '--- Capabilities ---', + f'Streaming : {card.capabilities.streaming}', + f'Push notifications : {card.capabilities.push_notifications}', + f'Extended agent card : {card.capabilities.extended_agent_card}', + ] + + lines += [ + '', + '--- I/O Modes ---', + f'Input : {", ".join(card.default_input_modes) or "(none)"}', + f'Output : {", ".join(card.default_output_modes) or "(none)"}', + ] + + lines += ['', '--- Skills ---'] + if card.skills: + for skill in card.skills: + lines += [ + thin, + f' ID : {skill.id}', + f' Name : {skill.name}', + f' Description : {skill.description}', + f' Tags : {", ".join(skill.tags) or "(none)"}', + ] + if skill.examples: + for ex in skill.examples: + lines.append(f' Example : {ex}') + else: + lines.append(' (none)') + + lines.append(sep) + print('\n'.join(lines)) diff --git a/src/a2a/helpers/proto_helpers.py b/src/a2a/helpers/proto_helpers.py new file mode 100644 index 000000000..6cc6350b6 --- /dev/null +++ b/src/a2a/helpers/proto_helpers.py @@ -0,0 +1,469 @@ +"""Unified 
helper functions for creating and handling A2A types.""" + +import uuid + +from collections.abc import Sequence +from typing import Any + +from google.protobuf import struct_pb2 +from google.protobuf.json_format import ParseDict + +from a2a.types.a2a_pb2 import ( + Artifact, + Message, + Part, + Role, + StreamResponse, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) + + +# --- Message Helpers --- + + +def new_message( + parts: list[Part], + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a list of Parts.""" + return Message( + role=role, + parts=parts, + message_id=str(uuid.uuid4()), + task_id=task_id, + context_id=context_id, + ) + + +def new_text_message( + text: str, + media_type: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single text Part.""" + return new_message( + parts=[new_text_part(text, media_type=media_type)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + +def get_message_text(message: Message, delimiter: str = '\n') -> str: + """Extracts and joins all text content from a Message's parts.""" + return delimiter.join(get_text_parts(message.parts)) + + +def new_data_message( + data: Any, + media_type: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single data Part. + + Args: + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content (e.g., "text/plain", "application/json", "image/png"). + context_id: Optional context ID. + task_id: Optional task ID. + role: The role of the message sender (default: ROLE_AGENT). + + Returns: + A Message with a single data Part. 
+ """ + return new_message( + parts=[new_data_part(data, media_type=media_type)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + +def new_raw_message( # noqa: PLR0913 + raw: bytes, + media_type: str | None = None, + filename: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single raw bytes Part. + + Args: + raw: The raw bytes content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + context_id: Optional context ID. + task_id: Optional task ID. + role: The role of the message sender (default: ROLE_AGENT). + + Returns: + A Message with a single raw Part. + """ + return new_message( + parts=[new_raw_part(raw, media_type=media_type, filename=filename)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + +def new_url_message( # noqa: PLR0913 + url: str, + media_type: str | None = None, + filename: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single URL Part. + + Args: + url: The URL pointing to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + context_id: Optional context ID. + task_id: Optional task ID. + role: The role of the message sender (default: ROLE_AGENT). + + Returns: + A Message with a single URL Part. 
+ """ + return new_message( + parts=[new_url_part(url, media_type=media_type, filename=filename)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + +# --- Artifact Helpers --- + + +def new_artifact( + parts: list[Part], + name: str, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object.""" + return Artifact( + artifact_id=artifact_id or str(uuid.uuid4()), + parts=parts, + name=name, + description=description, + ) + + +def new_text_artifact( + name: str, + text: str, + media_type: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single text Part.""" + return new_artifact( + [new_text_part(text, media_type=media_type)], + name, + description, + artifact_id=artifact_id, + ) + + +def new_data_artifact( + name: str, + data: Any, + media_type: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single data Part. + + Args: + name: The name of the artifact. + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content (e.g., "text/plain", "application/json", "image/png"). + description: Optional description. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + An Artifact with a single data Part. + """ + return new_artifact( + [new_data_part(data, media_type=media_type)], + name, + description, + artifact_id=artifact_id, + ) + + +def new_raw_artifact( # noqa: PLR0913 + name: str, + raw: bytes, + media_type: str | None = None, + filename: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single raw bytes Part. + + Args: + name: The name of the artifact. + raw: The raw bytes content. 
+ media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + description: Optional description. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + An Artifact with a single raw Part. + """ + return new_artifact( + [new_raw_part(raw, media_type=media_type, filename=filename)], + name, + description, + artifact_id=artifact_id, + ) + + +def new_url_artifact( # noqa: PLR0913 + name: str, + url: str, + media_type: str | None = None, + filename: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single URL Part. + + Args: + name: The name of the artifact. + url: The URL pointing to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + description: Optional description. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + An Artifact with a single URL Part. 
+ """ + return new_artifact( + [new_url_part(url, media_type=media_type, filename=filename)], + name, + description, + artifact_id=artifact_id, + ) + + +def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str: + """Extracts and joins all text content from an Artifact's parts.""" + return delimiter.join(get_text_parts(artifact.parts)) + + +# --- Task Helpers --- + + +def new_task_from_user_message(user_message: Message) -> Task: + """Creates a new Task object from an initial user message.""" + if user_message.role != Role.ROLE_USER: + raise ValueError('Message must be from a user') + if not user_message.parts: + raise ValueError('Message parts cannot be empty') + for part in user_message.parts: + if part.HasField('text') and not part.text: + raise ValueError('Message.text cannot be empty') + + return Task( + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + id=user_message.task_id or str(uuid.uuid4()), + context_id=user_message.context_id or str(uuid.uuid4()), + history=[user_message], + ) + + +def new_task( + task_id: str, + context_id: str, + state: TaskState, + artifacts: list[Artifact] | None = None, + history: list[Message] | None = None, +) -> Task: + """Creates a Task object with a specified status.""" + if history is None: + history = [] + if artifacts is None: + artifacts = [] + + return Task( + status=TaskStatus(state=state), + id=task_id, + context_id=context_id, + artifacts=artifacts, + history=history, + ) + + +# --- Part Helpers --- + + +def new_text_part( + text: str, + media_type: str | None = None, +) -> Part: + """Creates a Part with text content. + + Args: + text: The text content. + media_type: Optional MIME type (e.g. 'text/plain', 'text/markdown'). + + Returns: + A Part with the text field set. + """ + return Part(text=text, media_type=media_type or '') + + +def new_data_part( + data: Any, + media_type: str | None = None, +) -> Part: + """Creates a Part with structured data (google.protobuf.Value). 
+ + Args: + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content (e.g., "text/plain", "application/json", "image/png"). + + Returns: + A Part with the data field set. + """ + return Part( + data=ParseDict(data, struct_pb2.Value()), + media_type=media_type or '', + ) + + +def new_raw_part( + raw: bytes, + media_type: str | None = None, + filename: str | None = None, +) -> Part: + """Creates a Part with raw bytes content. + + Args: + raw: The raw bytes content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + + Returns: + A Part with the raw field set. + """ + return Part( + raw=raw, + media_type=media_type or '', + filename=filename or '', + ) + + +def new_url_part( + url: str, + media_type: str | None = None, + filename: str | None = None, +) -> Part: + """Creates a Part with a URL pointing to file content. + + Args: + url: The URL to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + + Returns: + A Part with the url field set. 
+ """ + return Part( + url=url, + media_type=media_type or '', + filename=filename or '', + ) + + +def get_text_parts(parts: Sequence[Part]) -> list[str]: + """Extracts text content from all text Parts.""" + return [part.text for part in parts if part.HasField('text')] + + +# --- Event & Stream Helpers --- + + +def new_text_status_update_event( + task_id: str, + context_id: str, + state: TaskState, + text: str, +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent with a single text message.""" + return TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=state, + message=new_text_message( + text=text, + role=Role.ROLE_AGENT, + context_id=context_id, + task_id=task_id, + ), + ), + ) + + +def new_text_artifact_update_event( # noqa: PLR0913 + task_id: str, + context_id: str, + name: str, + text: str, + append: bool = False, + last_chunk: bool = False, + artifact_id: str | None = None, +) -> TaskArtifactUpdateEvent: + """Creates a TaskArtifactUpdateEvent with a single text artifact.""" + return TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + artifact=new_text_artifact( + name=name, text=text, artifact_id=artifact_id + ), + append=append, + last_chunk=last_chunk, + ) + + +def get_stream_response_text( + response: StreamResponse, delimiter: str = '\n' +) -> str: + """Extracts text content from a StreamResponse.""" + if response.HasField('message'): + return get_message_text(response.message, delimiter) + if response.HasField('task'): + texts = [ + get_artifact_text(a, delimiter) for a in response.task.artifacts + ] + return delimiter.join(t for t in texts if t) + if response.HasField('status_update'): + if response.status_update.status.HasField('message'): + return get_message_text( + response.status_update.status.message, delimiter + ) + return '' + if response.HasField('artifact_update'): + return get_artifact_text(response.artifact_update.artifact, delimiter) + return '' diff --git 
a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py new file mode 100644 index 000000000..5479a38c1 --- /dev/null +++ b/src/a2a/server/agent_execution/active_task.py @@ -0,0 +1,754 @@ +# ruff: noqa: TRY301, SLF001 +from __future__ import annotations + +import asyncio +import logging +import uuid + +from typing import TYPE_CHECKING, Any, cast + +from a2a.server.agent_execution.context import RequestContext + + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Callable + + from a2a.server.agent_execution.agent_executor import AgentExecutor + from a2a.server.context import ServerCallContext + from a2a.server.tasks.push_notification_sender import ( + PushNotificationSender, + ) + from a2a.server.tasks.task_manager import TaskManager + +from a2a.server.events.event_queue_v2 import ( + AsyncQueue, + Event, + EventQueueSource, + QueueShutDown, + _create_async_queue, +) +from a2a.server.tasks import PushNotificationEvent +from a2a.types.a2a_pb2 import ( + Message, + Task, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.utils.errors import ( + InvalidAgentResponseError, + InvalidParamsError, + TaskNotFoundError, +) + + +logger = logging.getLogger(__name__) + + +TERMINAL_TASK_STATES = { + TaskState.TASK_STATE_COMPLETED, + TaskState.TASK_STATE_CANCELED, + TaskState.TASK_STATE_FAILED, + TaskState.TASK_STATE_REJECTED, +} +INTERRUPTED_TASK_STATES = { + TaskState.TASK_STATE_AUTH_REQUIRED, + TaskState.TASK_STATE_INPUT_REQUIRED, +} + + +class _RequestStarted: + def __init__(self, request_id: uuid.UUID, request_context: RequestContext): + self.request_id = request_id + self.request_context = request_context + + +class _RequestCompleted: + def __init__(self, request_id: uuid.UUID): + self.request_id = request_id + + +class ActiveTask: + """Manages the lifecycle and execution of an active A2A task. 
+ + It coordinates between the agent's execution (the producer), the + persistence and state management (the TaskManager), and the event + distribution to subscribers (the consumer). + + Concurrency Guarantees: + - This class is designed to be highly concurrent. It manages an internal + producer-consumer model using `asyncio.Task`s. + - `self._lock` (asyncio.Lock) ensures mutually exclusive access for critical + lifecycle state changes, such as starting the task, subscribing, and + determining if cleanup is safe to trigger. + + mutation to the observable result state (like `_exception`, + or `_is_finished`) notifies waiting coroutines (like `wait()`). + - `self._is_finished` (asyncio.Event) provides a thread-safe, non-blocking way + for external observers and internal loops to check if the ActiveTask has + permanently ceased execution and closed its queues. + """ + + def __init__( + self, + agent_executor: AgentExecutor, + task_id: str, + task_manager: TaskManager, + push_sender: PushNotificationSender | None = None, + on_cleanup: Callable[[ActiveTask], None] | None = None, + ) -> None: + """Initializes the ActiveTask. + + Args: + agent_executor: The executor to run the agent logic (producer). + task_id: The unique identifier of the task being managed. + task_manager: The manager for task state and database persistence. + push_sender: Optional sender for out-of-band push notifications. + on_cleanup: Optional callback triggered when the task is fully finished + and the last subscriber has disconnected. Used to prune + the task from the ActiveTaskRegistry. 
+ """ + # --- Core Dependencies --- + self._agent_executor = agent_executor + self._task_id = task_id + self._event_queue_agent = EventQueueSource() + self._event_queue_subscribers = EventQueueSource( + create_default_sink=False + ) + self._task_manager = task_manager + self._push_sender = push_sender + self._on_cleanup = on_cleanup + + # --- Synchronization Primitives --- + # `_lock` protects structural lifecycle changes: start(), subscribe() counting, + # and _maybe_cleanup() race conditions. + self._lock = asyncio.Lock() + + # `_request_lock` protects parallel request processing. + self._request_lock = asyncio.Lock() + + # _task_created is set when initial version of task is stored in DB. + self._task_created = asyncio.Event() + + # `_is_finished` is set EXACTLY ONCE when the consumer loop exits, signifying + # the absolute end of the task's active lifecycle. + self._is_finished = asyncio.Event() + + # --- Lifecycle State --- + # The background task executing the agent logic. + self._producer_task: asyncio.Task[None] | None = None + # The background task reading from _event_queue and updating the DB. + self._consumer_task: asyncio.Task[None] | None = None + + # Tracks how many active SSE/gRPC streams are currently tailing this task. + # Protected by `_lock`. + self._reference_count = 0 + + # Holds any fatal exception that crashed the producer or consumer. + # TODO: Synchronize exception handling (ideally mix it in the queue). 
+ self._exception: Exception | None = None + + # Queue for incoming requests + self._request_queue: AsyncQueue[tuple[RequestContext, uuid.UUID]] = ( + _create_async_queue() + ) + + @property + def task_id(self) -> str: + """The ID of the task.""" + return self._task_id + + async def enqueue_request( + self, request_context: RequestContext + ) -> uuid.UUID: + """Enqueues a request for the active task to process.""" + request_id = uuid.uuid4() + await self._request_queue.put((request_context, request_id)) + return request_id + + async def start( + self, + call_context: ServerCallContext, + create_task_if_missing: bool = False, + ) -> None: + """Starts the active task background processes. + + Concurrency Guarantee: + Uses `self._lock` to ensure the producer and consumer tasks are strictly + singleton instances for the lifetime of this ActiveTask. + """ + logger.debug('ActiveTask[%s]: Starting', self._task_id) + async with self._lock: + if self._is_finished.is_set(): + raise InvalidParamsError( + f'Task {self._task_id} is already completed. Cannot start it again.' 
+ ) + + if ( + self._producer_task is not None + and self._consumer_task is not None + ): + logger.debug( + 'ActiveTask[%s]: Already started, ignoring start request', + self._task_id, + ) + return + + logger.debug( + 'ActiveTask[%s]: Executing setup (call_context: %s, create_task_if_missing: %s)', + self._task_id, + call_context, + create_task_if_missing, + ) + try: + self._task_manager._call_context = call_context + task = await self._task_manager.get_task() + logger.debug('TASK (start): %s', task) + + if task: + self._task_created.set() + if task.status.state in TERMINAL_TASK_STATES: + raise InvalidParamsError( + message=f'Task {task.id} is in terminal state: {task.status.state}' + ) + elif not create_task_if_missing: + raise TaskNotFoundError + + except Exception: + logger.debug( + 'ActiveTask[%s]: Setup failed, cleaning up', + self._task_id, + ) + self._is_finished.set() + if self._reference_count == 0 and self._on_cleanup: + self._on_cleanup(self) + raise + + # Spawn the background tasks that drive the lifecycle. + self._reference_count += 1 + self._producer_task = asyncio.create_task( + self._run_producer(), name=f'producer:{self._task_id}' + ) + self._consumer_task = asyncio.create_task( + self._run_consumer(), name=f'consumer:{self._task_id}' + ) + logger.debug( + 'ActiveTask[%s]: Background tasks created', self._task_id + ) + + async def _run_producer(self) -> None: + """Executes the agent logic. + + This method encapsulates the external `AgentExecutor.execute` call. It ensures + that regardless of how the agent finishes (success, unhandled exception, or + cancellation), the underlying `_event_queue` is safely closed, which signals + the consumer to wind down. + + Concurrency Guarantee: + Runs as a detached asyncio.Task. Safe to cancel. 
+ """ + logger.debug('Producer[%s]: Started', self._task_id) + request_context = None + try: + while True: + ( + request_context, + request_id, + ) = await self._request_queue.get() + await self._request_lock.acquire() + # TODO: Should we create task manager every time? + self._task_manager._call_context = request_context.call_context + + request_context.current_task = ( + await self._task_manager.get_task() + ) + + logger.debug( + 'Producer[%s]: Executing agent task %s', + self._task_id, + request_context.current_task, + ) + + try: + await self._event_queue_agent.enqueue_event( + cast( + 'Event', + _RequestStarted(request_id, request_context), + ) + ) + + await self._agent_executor.execute( + request_context, self._event_queue_agent + ) + logger.debug( + 'Producer[%s]: Execution finished successfully', + self._task_id, + ) + finally: + logger.debug( + 'Producer[%s]: Enqueuing request completed event', + self._task_id, + ) + await self._event_queue_agent.enqueue_event( + cast('Event', _RequestCompleted(request_id)) + ) + self._request_queue.task_done() + except asyncio.CancelledError: + logger.debug('Producer[%s]: Cancelled', self._task_id) + + except QueueShutDown: + logger.debug('Producer[%s]: Queue shut down', self._task_id) + + except Exception as e: + logger.exception( + 'Producer[%s]: Execution failed', + self._task_id, + ) + # Create task and mark as failed. + if request_context: + await self._task_manager.ensure_task_id( + self._task_id, + request_context.context_id or '', + ) + self._task_created.set() + async with self._lock: + await self._mark_task_as_failed(e) + + finally: + self._request_queue.shutdown(immediate=True) + await self._event_queue_agent.close(immediate=False) + await self._event_queue_subscribers.close(immediate=False) + logger.debug('Producer[%s]: Completed', self._task_id) + + async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 + """Consumes events from the agent and updates system state. 
+ + This continuous loop dequeues events emitted by the producer, updates the + database via `TaskManager`, and intercepts critical task states (e.g., + INPUT_REQUIRED, COMPLETED, FAILED) to cache the final result. + + Concurrency Guarantee: + Runs as a detached asyncio.Task. The loop ends gracefully when the producer + closes the queue (raising `QueueShutDown`). Upon termination, it formally sets + `_is_finished`, unblocking all global subscribers and wait() calls. + """ + logger.debug('Consumer[%s]: Started', self._task_id) + task_mode = None + message_to_save = None + # TODO: Make helper methods + # TODO: Support Task enqueue + try: + try: + try: + while True: + # Dequeue event. This raises QueueShutDown when finished. + logger.debug( + 'Consumer[%s]: Waiting for event', + self._task_id, + ) + new_task = None + event = await self._event_queue_agent.dequeue_event() + logger.debug( + 'Consumer[%s]: Dequeued event %s', + self._task_id, + type(event).__name__, + ) + + try: + if isinstance(event, _RequestCompleted): + logger.debug( + 'Consumer[%s]: Request completed', + self._task_id, + ) + self._request_lock.release() + elif isinstance(event, _RequestStarted): + logger.debug( + 'Consumer[%s]: Request started', + self._task_id, + ) + message_to_save = event.request_context.message + + elif isinstance(event, Message): + if task_mode is not None: + if task_mode: + raise InvalidAgentResponseError( + 'Received Message object in task mode. Use TaskStatusUpdateEvent or TaskArtifactUpdateEvent instead.' + ) + raise InvalidAgentResponseError( + 'Multiple Message objects received.' + ) + task_mode = False + logger.debug( + 'Consumer[%s]: Setting result to Message: %s', + self._task_id, + event, + ) + else: + if task_mode is False: + raise InvalidAgentResponseError( + f'Received {type(event).__name__} in message mode. Use Task with TaskStatusUpdateEvent and TaskArtifactUpdateEvent instead.' 
+ ) + + if isinstance(event, Task): + existing_task = ( + await self._task_manager.get_task() + ) + if existing_task: + logger.error( + 'Task %s already exists. Ignoring task replacement.', + self._task_id, + ) + else: + await ( + self._task_manager.save_task_event( + event + ) + ) + # Initial task should already contain the message. + message_to_save = None + else: + if ( + isinstance(event, TaskStatusUpdateEvent) + and not self._task_created.is_set() + ): + task = ( + await self._task_manager.get_task() + ) + if task is None: + raise InvalidAgentResponseError( + f'Agent should enqueue Task before {type(event).__name__} event' + ) + + new_task = ( + await self._task_manager.ensure_task_id( + self._task_id, + event.context_id, + ) + ) + + if message_to_save is not None: + new_task = self._task_manager.update_with_message( + message_to_save, + new_task, + ) + await ( + self._task_manager.save_task_event( + new_task + ) + ) + message_to_save = None + + task_mode = True + # Save structural events (like TaskStatusUpdate) to DB. + + self._task_manager.context_id = event.context_id + if not isinstance(event, Task): + await self._task_manager.process(event) + + # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states + new_task = await self._task_manager.get_task() + if new_task is None: + raise RuntimeError( + f'Task {self.task_id} not found' + ) + if isinstance(event, Task): + event = new_task + is_interrupted = ( + new_task.status.state + in INTERRUPTED_TASK_STATES + ) + is_terminal = ( + new_task.status.state + in TERMINAL_TASK_STATES + ) + + # If we hit a breakpoint or terminal state, lock in the result. 
+ if is_interrupted or is_terminal: + logger.debug( + 'Consumer[%s]: Setting first result as Task (state=%s)', + self._task_id, + new_task.status.state, + ) + + if is_terminal: + logger.debug( + 'Consumer[%s]: Reached terminal state %s', + self._task_id, + new_task.status.state, + ) + if not self._is_finished.is_set(): + async with self._lock: + # TODO: what about _reference_count when task is failing? + self._reference_count -= 1 + # _maybe_cleanup() is called in finally block. + + # Terminate the ActiveTask globally. + self._is_finished.set() + self._request_queue.shutdown(immediate=True) + + if is_interrupted: + logger.debug( + 'Consumer[%s]: Interrupted with state %s', + self._task_id, + new_task.status.state, + ) + + if ( + self._push_sender + and self._task_id + and isinstance(event, PushNotificationEvent) + ): + logger.debug( + 'Consumer[%s]: Sending push notification', + self._task_id, + ) + await self._push_sender.send_notification( + self._task_id, event + ) + + self._task_created.set() + + finally: + if new_task is not None: + new_task_copy = Task() + new_task_copy.CopyFrom(new_task) + new_task = new_task_copy + if isinstance(event, Task): + new_task_copy = Task() + new_task_copy.CopyFrom(event) + event = new_task_copy + + logger.debug( + 'Consumer[%s]: Enqueuing\nEvent: %s\nNew Task: %s\n', + self._task_id, + event, + new_task, + ) + await self._event_queue_subscribers.enqueue_event( + cast('Any', (event, new_task)) + ) + self._event_queue_agent.task_done() + except QueueShutDown: + logger.debug( + 'Consumer[%s]: Event queue shut down', self._task_id + ) + except Exception as e: + logger.exception('Consumer[%s]: Failed', self._task_id) + # TODO: Make the task in database as failed. + async with self._lock: + await self._mark_task_as_failed(e) + finally: + # The consumer is dead. The ActiveTask is permanently finished. 
+ self._is_finished.set() + self._request_queue.shutdown(immediate=True) + await self._event_queue_agent.close(immediate=True) + + logger.debug('Consumer[%s]: Finishing', self._task_id) + await self._maybe_cleanup() + finally: + logger.debug('Consumer[%s]: Completed', self._task_id) + + async def subscribe( # noqa: PLR0912, PLR0915 + self, + *, + request: RequestContext | None = None, + include_initial_task: bool = False, + replace_status_update_with_task: bool = False, + ) -> AsyncGenerator[Event, None]: + """Creates a queue tap and yields events as they are produced. + + Concurrency Guarantee: + Uses `_lock` to safely increment and decrement `_reference_count`. + Safely detaches its queue tap when the client disconnects or the task finishes, + triggering `_maybe_cleanup()` to potentially garbage collect the ActiveTask. + """ + logger.debug('Subscribe[%s]: New subscriber', self._task_id) + + async with self._lock: + if self._exception: + logger.debug( + 'Subscribe[%s]: Failed, exception already set', + self._task_id, + ) + raise self._exception + if self._is_finished.is_set(): + raise InvalidParamsError( + f'Task {self._task_id} is already completed.' 
+ ) + self._reference_count += 1 + logger.debug( + 'Subscribe[%s]: Subscribers count: %d', + self._task_id, + self._reference_count, + ) + + tapped_queue = await self._event_queue_subscribers.tap() + request_id = await self.enqueue_request(request) if request else None + + try: + if include_initial_task: + logger.debug( + 'Subscribe[%s]: Including initial task', + self._task_id, + ) + task = await self.get_task() + yield task + + while True: + try: + if self._exception: + raise self._exception + + dequeued = await tapped_queue.dequeue_event() + event, updated_task = cast('Any', dequeued) + logger.debug( + 'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n', + self._task_id, + event, + updated_task, + ) + if replace_status_update_with_task and isinstance( + event, TaskStatusUpdateEvent + ): + logger.debug( + 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s', + self._task_id, + updated_task, + ) + event = updated_task + if self._exception: + raise self._exception from None + if isinstance(event, _RequestCompleted): + if ( + request_id is not None + and event.request_id == request_id + ): + logger.debug( + 'Subscriber[%s]: Request completed', + self._task_id, + ) + return + continue + elif isinstance(event, _RequestStarted): + logger.debug( + 'Subscriber[%s]: Request started', + self._task_id, + ) + continue + try: + yield event + finally: + tapped_queue.task_done() + except (QueueShutDown, asyncio.CancelledError): + if self._exception: + raise self._exception from None + break + finally: + logger.debug('Subscribe[%s]: Unsubscribing', self._task_id) + await tapped_queue.close(immediate=True) + async with self._lock: + self._reference_count -= 1 + # Evaluate if this was the last subscriber on a finished task. + await self._maybe_cleanup() + + async def cancel(self, call_context: ServerCallContext) -> Task: + """Cancels the running active task. 
+ + Concurrency Guarantee: + Uses `_lock` to ensure we don't attempt to cancel a producer that is + already winding down or hasn't started. It fires the cancellation signal + and blocks until the consumer processes the cancellation events. + """ + logger.debug('Cancel[%s]: Cancelling task', self._task_id) + + # TODO: Conflicts with call_context on the pending request. + self._task_manager._call_context = call_context + + task = await self._task_manager.get_task() + request_context = RequestContext( + call_context=call_context, + task_id=self._task_id, + context_id=task.context_id if task else None, + task=task, + ) + + async with self._lock: + if not self._is_finished.is_set() and self._producer_task: + logger.debug( + 'Cancel[%s]: Cancelling producer task', self._task_id + ) + self._producer_task.cancel() + try: + await self._agent_executor.cancel( + request_context, self._event_queue_agent + ) + except Exception as e: + logger.exception( + 'Cancel[%s]: Agent cancel failed', self._task_id + ) + await self._mark_task_as_failed(e) + raise + else: + logger.debug( + 'Cancel[%s]: Task already finished [%s] or producer not started [%s], not cancelling', + self._task_id, + self._is_finished.is_set(), + self._producer_task, + ) + + await self._is_finished.wait() + task = await self._task_manager.get_task() + if not task: + raise RuntimeError('Task should have been created') + return task + + async def _maybe_cleanup(self) -> None: + """Triggers cleanup if task is finished and has no subscribers. + + Concurrency Guarantee: + Protected by `_lock` to prevent race conditions where a new subscriber + attaches at the exact moment the task decides to garbage collect itself. 
+ """ + async with self._lock: + logger.debug( + 'Cleanup[%s]: Subscribers count: %d is_finished: %s', + self._task_id, + self._reference_count, + self._is_finished.is_set(), + ) + + if ( + self._is_finished.is_set() + and self._reference_count == 0 + and self._on_cleanup + ): + logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id) + self._on_cleanup(self) + + async def _mark_task_as_failed(self, exception: Exception) -> None: + if self._exception is None: + self._exception = exception + if self._task_created.is_set(): + try: + task = await self._task_manager.get_task() + if task is not None: + await self._event_queue_agent.enqueue_event( + TaskStatusUpdateEvent( + task_id=task.id, + context_id=task.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_FAILED, + ), + ) + ) + except QueueShutDown: + pass + + async def get_task(self) -> Task: + """Get task from db.""" + # TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation). + await self._task_created.wait() + task = await self._task_manager.get_task() + if not task: + raise RuntimeError('Task should have been created') + return task diff --git a/src/a2a/server/agent_execution/active_task_registry.py b/src/a2a/server/agent_execution/active_task_registry.py new file mode 100644 index 000000000..9c1299ab3 --- /dev/null +++ b/src/a2a/server/agent_execution/active_task_registry.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import asyncio +import logging + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from a2a.server.agent_execution.agent_executor import AgentExecutor + from a2a.server.context import ServerCallContext + from a2a.server.tasks.push_notification_sender import PushNotificationSender + from a2a.server.tasks.task_store import TaskStore + +from a2a.server.agent_execution.active_task import ActiveTask +from a2a.server.tasks.task_manager import TaskManager + + +logger = logging.getLogger(__name__) + + +class ActiveTaskRegistry: + """A registry for 
active ActiveTask instances.""" + + def __init__( + self, + agent_executor: AgentExecutor, + task_store: TaskStore, + push_sender: PushNotificationSender | None = None, + ): + self._agent_executor = agent_executor + self._task_store = task_store + self._push_sender = push_sender + self._active_tasks: dict[str, ActiveTask] = {} + self._lock = asyncio.Lock() + self._cleanup_tasks: set[asyncio.Task[None]] = set() + + async def get_or_create( + self, + task_id: str, + call_context: ServerCallContext, + context_id: str | None = None, + create_task_if_missing: bool = False, + ) -> ActiveTask: + """Retrieves an existing ActiveTask or creates a new one.""" + async with self._lock: + if task_id in self._active_tasks: + return self._active_tasks[task_id] + + task_manager = TaskManager( + task_id=task_id, + context_id=context_id, + task_store=self._task_store, + initial_message=None, + context=call_context, + ) + + active_task = ActiveTask( + agent_executor=self._agent_executor, + task_id=task_id, + task_manager=task_manager, + push_sender=self._push_sender, + on_cleanup=self._on_active_task_cleanup, + ) + self._active_tasks[task_id] = active_task + + await active_task.start( + call_context=call_context, + create_task_if_missing=create_task_if_missing, + ) + return active_task + + def _on_active_task_cleanup(self, active_task: ActiveTask) -> None: + """Called by ActiveTask when it's finished and has no subscribers.""" + logger.debug('Active task %s cleanup scheduled', active_task.task_id) + task = asyncio.create_task(self._remove_task(active_task.task_id)) + self._cleanup_tasks.add(task) + task.add_done_callback(self._cleanup_tasks.discard) + + async def _remove_task(self, task_id: str) -> None: + async with self._lock: + self._active_tasks.pop(task_id, None) + logger.debug('Removed active task for %s from registry', task_id) + + async def get(self, task_id: str) -> ActiveTask | None: + """Retrieves an existing task.""" + async with self._lock: + return 
self._active_tasks.get(task_id) diff --git a/src/a2a/server/agent_execution/agent_executor.py b/src/a2a/server/agent_execution/agent_executor.py index e03232b35..1c3866047 100644 --- a/src/a2a/server/agent_execution/agent_executor.py +++ b/src/a2a/server/agent_execution/agent_executor.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from a2a.server.agent_execution.context import RequestContext -from a2a.server.events.event_queue import EventQueue +from a2a.server.events.event_queue_v2 import EventQueue class AgentExecutor(ABC): @@ -23,6 +23,44 @@ async def execute( return once the agent's execution for this request is complete or yields control (e.g., enters an input-required state). + Request Lifecycle & AgentExecutor Responsibilities: + - **Concurrency**: The framework guarantees single execution per request; + `execute()` will not be called concurrently for the same request context. + - **Exception Handling**: Unhandled exceptions raised by `execute()` will be + caught by the framework and result in the task transitioning to + `TaskState.TASK_STATE_ERROR`. + - **Post-Completion**: Once `execute()` completes (returns or raises), the + executor must not access the `context` or `event_queue` anymore. + - **Terminal States**: Before completing the call normally, the executor + SHOULD publish a `TaskStatusUpdateEvent` to transition the task to a + terminal state (e.g., `TASK_STATE_COMPLETED`) or an interrupted state + (`TASK_STATE_INPUT_REQUIRED` or `TASK_STATE_AUTH_REQUIRED`). + - **Interrupted Workflows**: + - `TASK_STATE_INPUT_REQUIRED`: The executor publishes a `TaskStatusUpdateEvent` with + `TaskState.TASK_STATE_INPUT_REQUIRED` and returns to yield control. + The request will resume once user input is provided. + - `TASK_STATE_AUTH_REQUIRED`: There are in-bound and out-of-bound auth models. + In both scenarios, the agent publishes a `TaskStatusUpdateEvent` with + `TaskState.TASK_STATE_AUTH_REQUIRED`. + - In-bound: The agent should return from `execute()`. 
The framework will + call `execute()` again once the user response is received. + - Out-of-bound: The agent should not return from `execute()`. It should wait + for the out-of-band auth provider to complete the authentication and then + continue execution. + + - **Cancellation Workflow**: When a cancellation request is received, the + async task running `execute()` is cancelled (raising an `asyncio.CancelledError`), + and `cancel()` is explicitly called by the framework. + + Allowed Workflows: + - Immediate response: Enqueue a SINGLE `Message` object. + - Asynchronous/Long-running: Enqueue a `Task` object, perform work, and emit + multiple `TaskStatusUpdateEvent` / `TaskArtifactUpdateEvent` objects over time. + + Note that the framework waits with response to the send_message request with + `return_immediately=True` parameter until the first event (Message or Task) + is enqueued by AgentExecutor. + Args: context: The request context containing the message, task ID, etc. event_queue: The queue to publish events to. 
diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 91284f37c..5fcdf8697 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -1,5 +1,6 @@ from typing import Any +from a2a.helpers.proto_helpers import get_message_text from a2a.server.context import ServerCallContext from a2a.server.id_generator import ( IDGenerator, @@ -12,7 +13,6 @@ SendMessageRequest, Task, ) -from a2a.utils import get_message_text from a2a.utils.errors import InvalidParamsError @@ -120,7 +120,7 @@ def current_task(self) -> Task | None: return self._current_task @current_task.setter - def current_task(self, task: Task) -> None: + def current_task(self, task: Task | None) -> None: """Sets the current task object.""" self._current_task = task @@ -151,14 +151,6 @@ def metadata(self) -> dict[str, Any]: return dict(self._params.metadata) return {} - def add_activated_extension(self, uri: str) -> None: - """Add an extension to the set of activated extensions for this request. - - This causes the extension to be indicated back to the client in the - response. 
- """ - self._call_context.activated_extensions.add(uri) - @property def tenant(self) -> str: """The tenant associated with this request.""" @@ -166,7 +158,7 @@ def tenant(self) -> str: @property def requested_extensions(self) -> set[str]: - """Extensions that the client requested to activate.""" + """Extensions that the client requested for this interaction.""" return self._call_context.requested_extensions def _check_or_generate_task_id(self) -> None: diff --git a/src/a2a/server/context.py b/src/a2a/server/context.py index 6196a69d6..833ca44c4 100644 --- a/src/a2a/server/context.py +++ b/src/a2a/server/context.py @@ -23,4 +23,3 @@ class ServerCallContext(BaseModel): user: User = Field(default_factory=UnauthenticatedUser) tenant: str = Field(default='') requested_extensions: set[str] = Field(default_factory=set) - activated_extensions: set[str] = Field(default_factory=set) diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index a29394795..8414e2d17 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -5,7 +5,7 @@ from pydantic import ValidationError -from a2a.server.events.event_queue import Event, EventQueue, QueueShutDown +from a2a.server.events.event_queue import Event, EventQueueLegacy, QueueShutDown from a2a.types.a2a_pb2 import ( Message, Task, @@ -22,7 +22,7 @@ class EventConsumer: """Consumer to read events from the agent event queue.""" - def __init__(self, queue: EventQueue): + def __init__(self, queue: EventQueueLegacy): """Initializes the EventConsumer. Args: diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index 25598d15b..bb4d7b9b4 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -92,73 +92,6 @@ async def enqueue_event(self, event: Event) -> None: Only main queue can enqueue events. Child queues can only dequeue events. 
""" - @abstractmethod - async def dequeue_event(self) -> Event: - """Pulls an event from the queue.""" - - @abstractmethod - def task_done(self) -> None: - """Signals that a work on dequeued event is complete.""" - - @abstractmethod - async def tap( - self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE - ) -> 'EventQueue': - """Creates a child queue that receives future events. - - Note: The tapped queue may receive some old events if the incoming event - queue is lagging behind and hasn't dispatched them yet. - """ - - @abstractmethod - async def close(self, immediate: bool = False) -> None: - """Closes the queue. - - For parent queue: it closes the main queue and all its child queues. - For child queue: it closes only child queue. - - It is safe to call it multiple times. - If immediate is True, the queue will be closed without waiting for all events to be processed. - If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls). - - WARNING: Closing the parent queue with immediate=False is a deadlock risk if there are unconsumed events - in any of the child sinks and the consumer has crashed without draining its queue. - It is highly recommended to wrap graceful shutdowns with a timeout, e.g., - `asyncio.wait_for(queue.close(immediate=False), timeout=...)`. - """ - - @abstractmethod - def is_closed(self) -> bool: - """[DEPRECATED] Checks if the queue is closed. - - NOTE: Relying on this for enqueue logic introduces race conditions. - It is maintained primarily for backwards compatibility, workarounds for - Python 3.10/3.12 async queues in consumers, and for the test suite. - """ - - @abstractmethod - async def __aenter__(self) -> Self: - """Enters the async context manager, returning the queue itself. - - WARNING: See `__aexit__` for important deadlock risks associated with - exiting this context manager if unconsumed events remain. 
- """ - - @abstractmethod - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exits the async context manager, ensuring close() is called. - - WARNING: The context manager calls `close(immediate=False)` by default. - If a consumer exits the `async with` block early (e.g., due to an exception - or an explicit `break`) while unconsumed events remain in the queue, - `__aexit__` will deadlock waiting for `task_done()` to be called on those events. - """ - @trace_class(kind=SpanKind.SERVER) class EventQueueLegacy(EventQueue): @@ -180,7 +113,7 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: self._queue: AsyncQueue[Event] = _create_async_queue( maxsize=max_queue_size ) - self._children: list[EventQueue] = [] + self._children: list[EventQueueLegacy] = [] self._is_closed = False self._lock = asyncio.Lock() logger.debug('EventQueue initialized.') diff --git a/src/a2a/server/events/event_queue_v2.py b/src/a2a/server/events/event_queue_v2.py index 5642bfbc6..224cb8e56 100644 --- a/src/a2a/server/events/event_queue_v2.py +++ b/src/a2a/server/events/event_queue_v2.py @@ -28,7 +28,11 @@ class EventQueueSource(EventQueue): in `_incoming_queue` and distributed to all child Sinks by a background dispatcher task. 
""" - def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: + def __init__( + self, + max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE, + create_default_sink: bool = True, + ) -> None: """Initializes the EventQueueSource.""" if max_queue_size <= 0: raise ValueError('max_queue_size must be greater than 0') @@ -41,10 +45,15 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: self._is_closed = False # Internal sink for backward compatibility - self._default_sink = EventQueueSink( - parent=self, max_queue_size=max_queue_size - ) - self._sinks.add(self._default_sink) + self._default_sink: EventQueueSink | None + if create_default_sink: + self._default_sink = EventQueueSink( + parent=self, max_queue_size=max_queue_size + ) + self._sinks.add(self._default_sink) + else: + self._default_sink = None + self._dispatcher_task = asyncio.create_task(self._dispatch_loop()) self._dispatcher_task_expected_to_cancel = False @@ -54,6 +63,8 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: @property def queue(self) -> AsyncQueue[Event]: """Returns the underlying asyncio.Queue of the default sink.""" + if self._default_sink is None: + raise ValueError('No default sink available.') return self._default_sink.queue async def _dispatch_loop(self) -> None: @@ -182,15 +193,29 @@ async def enqueue_event(self, event: Event) -> None: return async def dequeue_event(self) -> Event: - """Dequeues an event from the default internal sink queue.""" + """Pulls an event from the default internal sink queue.""" + if self._default_sink is None: + raise ValueError('No default sink available.') return await self._default_sink.dequeue_event() def task_done(self) -> None: - """Signals that a formerly enqueued task is complete via the default internal sink queue.""" + """Signals that a work on dequeued event is complete via the default internal sink queue.""" + if self._default_sink is None: + raise ValueError('No default sink available.') 
self._default_sink.task_done() async def close(self, immediate: bool = False) -> None: - """Closes the queue for future push events and also closes all child sinks.""" + """Closes the queue and all its child sinks. + + It is safe to call it multiple times. + If immediate is True, the queue will be closed without waiting for all events to be processed. + If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls). + + WARNING: Closing the parent queue with immediate=False is a deadlock risk if there are unconsumed events + in any of the child sinks and the consumer has crashed without draining its queue. + It is highly recommended to wrap graceful shutdowns with a timeout, e.g., + `asyncio.wait_for(queue.close(immediate=False), timeout=...)`. + """ logger.debug('Closing EventQueueSource: immediate=%s', immediate) async with self._lock: # No more tap() allowed. @@ -215,7 +240,12 @@ async def close(self, immediate: bool = False) -> None: ) def is_closed(self) -> bool: - """Checks if the queue is closed.""" + """[DEPRECATED] Checks if the queue is closed. + + NOTE: Relying on this for enqueue logic introduces race conditions. + It is maintained primarily for backwards compatibility, workarounds for + Python 3.10/3.12 async queues in consumers, and for the test suite. + """ return self._is_closed async def test_only_join_incoming_queue(self) -> None: @@ -223,7 +253,11 @@ async def test_only_join_incoming_queue(self) -> None: await self._join_incoming_queue() async def __aenter__(self) -> Self: - """Enters the async context manager, returning the queue itself.""" + """Enters the async context manager, returning the queue itself. + + WARNING: See `__aexit__` for important deadlock risks associated with + exiting this context manager if unconsumed events remain. 
+ """ return self async def __aexit__( @@ -232,7 +266,13 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - """Exits the async context manager, ensuring close() is called.""" + """Exits the async context manager, ensuring close() is called. + + WARNING: The context manager calls `close(immediate=False)` by default. + If a consumer exits the `async with` block early (e.g., due to an exception + or an explicit `break`) while unconsumed events remain in the queue, + `__aexit__` will deadlock waiting for `task_done()` to be called on those events. + """ await self.close() @@ -275,26 +315,35 @@ async def enqueue_event(self, event: Event) -> None: raise RuntimeError('Cannot enqueue to a sink-only queue') async def dequeue_event(self) -> Event: - """Dequeues an event from the sink queue.""" + """Pulls an event from the sink queue.""" logger.debug('Attempting to dequeue event (waiting).') event = await self._queue.get() logger.debug('Dequeued event: %s', event) return event def task_done(self) -> None: - """Signals that a formerly enqueued task is complete in this sink queue.""" + """Signals that a work on dequeued event is complete in this sink queue.""" logger.debug('Marking task as done in EventQueueSink.') self._queue.task_done() async def tap( self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE ) -> 'EventQueueSink': - """Taps the event queue to create a new child queue that receives future events.""" + """Creates a child queue that receives future events. + + Note: The tapped queue may receive some old events if the incoming event + queue is lagging behind and hasn't dispatched them yet. + """ # Delegate tap to the parent source so all sinks are flat under the source return await self._parent.tap(max_queue_size=max_queue_size) async def close(self, immediate: bool = False) -> None: - """Closes the child sink queue.""" + """Closes the child sink queue. + + It is safe to call it multiple times. 
+ If immediate is True, the queue will be closed without waiting for all events to be processed. + If immediate is False, the queue will be closed after all events are processed (and confirmed with task_done() calls). + """ logger.debug('Closing EventQueueSink.') async with self._lock: self._is_closed = True @@ -308,11 +357,20 @@ async def close(self, immediate: bool = False) -> None: await self._queue.join() def is_closed(self) -> bool: - """Checks if the sink queue is closed.""" + """[DEPRECATED] Checks if the queue is closed. + + NOTE: Relying on this for enqueue logic introduces race conditions. + It is maintained primarily for backwards compatibility, workarounds for + Python 3.10/3.12 async queues in consumers, and for the test suite. + """ return self._is_closed async def __aenter__(self) -> Self: - """Enters the async context manager, returning the queue itself.""" + """Enters the async context manager, returning the queue itself. + + WARNING: See `__aexit__` for important deadlock risks associated with + exiting this context manager if unconsumed events remain. + """ return self async def __aexit__( @@ -321,5 +379,11 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - """Exits the async context manager, ensuring close() is called.""" + """Exits the async context manager, ensuring close() is called. + + WARNING: The context manager calls `close(immediate=False)` by default. + If a consumer exits the `async with` block early (e.g., due to an exception + or an explicit `break`) while unconsumed events remain in the queue, + `__aexit__` will deadlock waiting for `task_done()` to be called on those events. 
+ """ await self.close() diff --git a/src/a2a/server/events/in_memory_queue_manager.py b/src/a2a/server/events/in_memory_queue_manager.py index ddff52419..0beb354f9 100644 --- a/src/a2a/server/events/in_memory_queue_manager.py +++ b/src/a2a/server/events/in_memory_queue_manager.py @@ -1,6 +1,6 @@ import asyncio -from a2a.server.events.event_queue import EventQueue, EventQueueLegacy +from a2a.server.events.event_queue import EventQueueLegacy from a2a.server.events.queue_manager import ( NoTaskQueue, QueueManager, @@ -23,10 +23,10 @@ class InMemoryQueueManager(QueueManager): def __init__(self) -> None: """Initializes the InMemoryQueueManager.""" - self._task_queue: dict[str, EventQueue] = {} + self._task_queue: dict[str, EventQueueLegacy] = {} self._lock = asyncio.Lock() - async def add(self, task_id: str, queue: EventQueue) -> None: + async def add(self, task_id: str, queue: EventQueueLegacy) -> None: """Adds a new event queue for a task ID. Raises: @@ -37,22 +37,22 @@ async def add(self, task_id: str, queue: EventQueue) -> None: raise TaskQueueExists self._task_queue[task_id] = queue - async def get(self, task_id: str) -> EventQueue | None: + async def get(self, task_id: str) -> EventQueueLegacy | None: """Retrieves the event queue for a task ID. Returns: - The `EventQueue` instance for the `task_id`, or `None` if not found. + The `EventQueueLegacy` instance for the `task_id`, or `None` if not found. """ async with self._lock: if task_id not in self._task_queue: return None return self._task_queue[task_id] - async def tap(self, task_id: str) -> EventQueue | None: + async def tap(self, task_id: str) -> EventQueueLegacy | None: """Taps the event queue for a task ID to create a child queue. Returns: - A new child `EventQueue` instance, or `None` if the task ID is not found. + A new child `EventQueueLegacy` instance, or `None` if the task ID is not found. 
""" async with self._lock: if task_id not in self._task_queue: @@ -71,11 +71,11 @@ async def close(self, task_id: str) -> None: queue = self._task_queue.pop(task_id) await queue.close() - async def create_or_tap(self, task_id: str) -> EventQueue: + async def create_or_tap(self, task_id: str) -> EventQueueLegacy: """Creates a new event queue for a task ID if one doesn't exist, otherwise taps the existing one. Returns: - A new or child `EventQueue` instance for the `task_id`. + A new or child `EventQueueLegacy` instance for the `task_id`. """ async with self._lock: if task_id not in self._task_queue: diff --git a/src/a2a/server/events/queue_manager.py b/src/a2a/server/events/queue_manager.py index ed69aae68..b3ec204a5 100644 --- a/src/a2a/server/events/queue_manager.py +++ b/src/a2a/server/events/queue_manager.py @@ -1,21 +1,21 @@ from abc import ABC, abstractmethod -from a2a.server.events.event_queue import EventQueue +from a2a.server.events.event_queue import EventQueueLegacy class QueueManager(ABC): """Interface for managing the event queue lifecycles per task.""" @abstractmethod - async def add(self, task_id: str, queue: EventQueue) -> None: + async def add(self, task_id: str, queue: EventQueueLegacy) -> None: """Adds a new event queue associated with a task ID.""" @abstractmethod - async def get(self, task_id: str) -> EventQueue | None: + async def get(self, task_id: str) -> EventQueueLegacy | None: """Retrieves the event queue for a task ID.""" @abstractmethod - async def tap(self, task_id: str) -> EventQueue | None: + async def tap(self, task_id: str) -> EventQueueLegacy | None: """Creates a child event queue (tap) for an existing task ID.""" @abstractmethod @@ -23,7 +23,7 @@ async def close(self, task_id: str) -> None: """Closes and removes the event queue for a task ID.""" @abstractmethod - async def create_or_tap(self, task_id: str) -> EventQueue: + async def create_or_tap(self, task_id: str) -> EventQueueLegacy: """Creates a queue if one doesn't exist, 
otherwise taps the existing one.""" diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 194e81a45..34654cb58 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -3,7 +3,10 @@ import logging from a2a.server.request_handlers.default_request_handler import ( - DefaultRequestHandler, + LegacyRequestHandler, +) +from a2a.server.request_handlers.default_request_handler_v2 import ( + DefaultRequestHandlerV2, ) from a2a.server.request_handlers.request_handler import ( RequestHandler, @@ -40,11 +43,15 @@ def __init__(self, *args, **kwargs): ) from _original_error +DefaultRequestHandler = DefaultRequestHandlerV2 + __all__ = [ 'DefaultGrpcServerCallContextBuilder', 'DefaultRequestHandler', + 'DefaultRequestHandlerV2', 'GrpcHandler', 'GrpcServerCallContextBuilder', + 'LegacyRequestHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 67b51e248..e803b567f 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,7 +1,7 @@ import asyncio import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import cast from a2a.server.agent_execution import ( @@ -14,13 +14,13 @@ from a2a.server.events import ( Event, EventConsumer, - EventQueue, EventQueueLegacy, InMemoryQueueManager, QueueManager, ) from a2a.server.request_handlers.request_handler import ( RequestHandler, + validate, validate_request_params, ) from a2a.server.tasks import ( @@ -32,8 +32,10 @@ TaskStore, ) from a2a.types.a2a_pb2 import ( + AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, 
GetTaskRequest, ListTaskPushNotificationConfigsRequest, @@ -48,6 +50,7 @@ TaskState, ) from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, InternalError, InvalidParamsError, PushNotificationNotSupportedError, @@ -74,7 +77,7 @@ @trace_class(kind=SpanKind.SERVER) -class DefaultRequestHandler(RequestHandler): +class LegacyRequestHandler(RequestHandler): """Default request handler for all incoming requests. This handler provides default implementations for all A2A JSON-RPC methods, @@ -89,27 +92,39 @@ def __init__( # noqa: PLR0913 self, agent_executor: AgentExecutor, task_store: TaskStore, + agent_card: AgentCard, queue_manager: QueueManager | None = None, push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, request_context_builder: RequestContextBuilder | None = None, + extended_agent_card: AgentCard | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] + ] + | None = None, ) -> None: """Initializes the DefaultRequestHandler. Args: agent_executor: The `AgentExecutor` instance to run agent logic. task_store: The `TaskStore` instance to manage task persistence. + agent_card: The `AgentCard` describing the agent's capabilities. queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`. push_config_store: The `PushNotificationConfigStore` instance for managing push notification configurations. Defaults to None. push_sender: The `PushNotificationSender` instance for sending push notifications. Defaults to None. request_context_builder: The `RequestContextBuilder` instance used to build request contexts. Defaults to `SimpleRequestContextBuilder`. + extended_agent_card: An optional, distinct `AgentCard` to be served at the extended card endpoint. + extended_card_modifier: An optional callback to dynamically modify the extended `AgentCard` before it is served. 
""" self.agent_executor = agent_executor self.task_store = task_store + self._agent_card = agent_card self._queue_manager = queue_manager or InMemoryQueueManager() self._push_config_store = push_config_store self._push_sender = push_sender + self.extended_agent_card = extended_agent_card + self.extended_card_modifier = extended_card_modifier self._request_context_builder = ( request_context_builder or SimpleRequestContextBuilder( @@ -224,7 +239,7 @@ async def on_cancel_task( return result async def _run_event_stream( - self, request: RequestContext, queue: EventQueue + self, request: RequestContext, queue: EventQueueLegacy ) -> None: """Runs the agent's `execute` method and closes the queue afterwards. @@ -239,7 +254,9 @@ async def _setup_message_execution( self, params: SendMessageRequest, context: ServerCallContext, - ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: + ) -> tuple[ + TaskManager, str, EventQueueLegacy, ResultAggregator, asyncio.Task + ]: """Common setup logic for both streaming and non-streaming message handling. 
Returns: @@ -397,6 +414,10 @@ async def push_notification_callback(event: Event) -> None: return result @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) async def on_message_send_stream( self, params: SendMessageRequest, @@ -486,6 +507,11 @@ async def _cleanup_producer( self._running_agents.pop(task_id, None) @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_create_task_push_notification_config( self, params: TaskPushNotificationConfig, @@ -512,6 +538,11 @@ async def on_create_task_push_notification_config( return params @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_get_task_push_notification_config( self, params: GetTaskPushNotificationConfigRequest, @@ -538,9 +569,13 @@ async def on_get_task_push_notification_config( if config.id == config_id: return config - raise InternalError(message='Push notification config not found') + raise TaskNotFoundError @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) async def on_subscribe_to_task( self, params: SubscribeToTaskRequest, @@ -584,6 +619,11 @@ async def on_subscribe_to_task( yield event @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_list_task_push_notification_configs( self, params: ListTaskPushNotificationConfigsRequest, @@ -610,6 +650,11 @@ async def 
on_list_task_push_notification_configs( ) @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) async def on_delete_task_push_notification_config( self, params: DeleteTaskPushNotificationConfigRequest, @@ -629,3 +674,28 @@ async def on_delete_task_push_notification_config( raise TaskNotFoundError await self._push_config_store.delete_info(task_id, context, config_id) + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.extended_agent_card, + error_message='The agent does not support authenticated extended cards', + ) + async def on_get_extended_agent_card( + self, + params: GetExtendedAgentCardRequest, + context: ServerCallContext, + ) -> AgentCard: + """Default handler for 'GetExtendedAgentCard'. + + Requires `capabilities.extended_agent_card` to be true. + """ + extended_card = self.extended_agent_card + if not extended_card: + raise ExtendedAgentCardNotConfiguredError + + if self.extended_card_modifier: + extended_card = await self.extended_card_modifier( + extended_card, context + ) + + return extended_card diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py new file mode 100644 index 000000000..ecdc0cfef --- /dev/null +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -0,0 +1,482 @@ +from __future__ import annotations + +import asyncio # noqa: TC003 +import logging + +from typing import TYPE_CHECKING, Any, cast + +from a2a.server.agent_execution import ( + AgentExecutor, + RequestContext, + RequestContextBuilder, + SimpleRequestContextBuilder, +) +from a2a.server.agent_execution.active_task import ( + INTERRUPTED_TASK_STATES, + TERMINAL_TASK_STATES, +) +from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry +from 
a2a.server.request_handlers.request_handler import ( + RequestHandler, + validate, + validate_request_params, +) +from a2a.types.a2a_pb2 import ( + AgentCard, + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + ListTasksResponse, + Message, + SendMessageRequest, + SubscribeToTaskRequest, + Task, + TaskPushNotificationConfig, + TaskStatusUpdateEvent, +) +from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, + InternalError, + InvalidParamsError, + PushNotificationNotSupportedError, + TaskNotCancelableError, + TaskNotFoundError, +) +from a2a.utils.task import ( + apply_history_length, + validate_history_length, + validate_page_size, +) +from a2a.utils.telemetry import SpanKind, trace_class + + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Awaitable, Callable + + from a2a.server.agent_execution.active_task import ActiveTask + from a2a.server.context import ServerCallContext + from a2a.server.events import Event + from a2a.server.tasks import ( + PushNotificationConfigStore, + PushNotificationSender, + TaskStore, + ) + + +logger = logging.getLogger(__name__) + + +# TODO: cleanup context_id management + + +@trace_class(kind=SpanKind.SERVER) +class DefaultRequestHandlerV2(RequestHandler): + """Default request handler for all incoming requests.""" + + _background_tasks: set[asyncio.Task] + + def __init__( # noqa: PLR0913 + self, + agent_executor: AgentExecutor, + task_store: TaskStore, + agent_card: AgentCard, + queue_manager: Any + | None = None, # Kept for backward compat in signature + push_config_store: PushNotificationConfigStore | None = None, + push_sender: PushNotificationSender | None = None, + request_context_builder: RequestContextBuilder | None = None, + extended_agent_card: AgentCard | None = None, + 
extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] + ] + | None = None, + ) -> None: + self.agent_executor = agent_executor + self.task_store = task_store + self._agent_card = agent_card + self._push_config_store = push_config_store + self._push_sender = push_sender + self.extended_agent_card = extended_agent_card + self.extended_card_modifier = extended_card_modifier + self._request_context_builder = ( + request_context_builder + or SimpleRequestContextBuilder( + should_populate_referred_tasks=False, task_store=self.task_store + ) + ) + self._active_task_registry = ActiveTaskRegistry( + agent_executor=self.agent_executor, + task_store=self.task_store, + push_sender=self._push_sender, + ) + self._background_tasks = set() + + @validate_request_params + async def on_get_task( # noqa: D102 + self, + params: GetTaskRequest, + context: ServerCallContext, + ) -> Task | None: + validate_history_length(params) + + task_id = params.id + task: Task | None = await self.task_store.get(task_id, context) + if not task: + raise TaskNotFoundError + + return apply_history_length(task, params) + + @validate_request_params + async def on_list_tasks( # noqa: D102 + self, + params: ListTasksRequest, + context: ServerCallContext, + ) -> ListTasksResponse: + validate_history_length(params) + if params.HasField('page_size'): + validate_page_size(params.page_size) + + page = await self.task_store.list(params, context) + for task in page.tasks: + if not params.include_artifacts: + task.ClearField('artifacts') + + updated_task = apply_history_length(task, params) + if updated_task is not task: + task.CopyFrom(updated_task) + + return page + + @validate_request_params + async def on_cancel_task( # noqa: D102 + self, + params: CancelTaskRequest, + context: ServerCallContext, + ) -> Task | None: + task_id = params.id + + try: + active_task = await self._active_task_registry.get_or_create( + task_id, call_context=context, create_task_if_missing=False + ) + 
result = await active_task.cancel(context) + except InvalidParamsError as e: + raise TaskNotCancelableError from e + + if isinstance(result, Message): + raise InternalError( + message='Cancellation returned a message instead of a task.' + ) + + return result + + def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: + if task_id != event_task_id: + logger.error( + 'Agent generated task_id=%s does not match the RequestContext task_id=%s.', + event_task_id, + task_id, + ) + raise InternalError(message='Task ID mismatch in agent response') + + async def _setup_active_task( + self, + params: SendMessageRequest, + call_context: ServerCallContext, + ) -> tuple[ActiveTask, RequestContext]: + validate_history_length(params.configuration) + + original_task_id = params.message.task_id or None + original_context_id = params.message.context_id or None + + if original_task_id: + task = await self.task_store.get(original_task_id, call_context) + if not task: + raise TaskNotFoundError(f'Task {original_task_id} not found') + + # Build context to resolve or generate missing IDs + request_context = await self._request_context_builder.build( + params=params, + task_id=original_task_id, + context_id=original_context_id, + # We will get the task when we have to process the request to avoid concurrent read/write issues. 
+ task=None, + context=call_context, + ) + + task_id = cast('str', request_context.task_id) + context_id = cast('str', request_context.context_id) + + if ( + self._push_config_store + and params.configuration + and params.configuration.task_push_notification_config + ): + await self._push_config_store.set_info( + task_id, + params.configuration.task_push_notification_config, + call_context, + ) + + active_task = await self._active_task_registry.get_or_create( + task_id, + context_id=context_id, + call_context=call_context, + create_task_if_missing=True, + ) + + return active_task, request_context + + @validate_request_params + async def on_message_send( # noqa: D102 + self, + params: SendMessageRequest, + context: ServerCallContext, + ) -> Message | Task: + active_task, request_context = await self._setup_active_task( + params, context + ) + task_id = cast('str', request_context.task_id) + + result: Message | Task | None = None + + async for raw_event in active_task.subscribe( + request=request_context, + include_initial_task=False, + replace_status_update_with_task=True, + ): + event = raw_event + logger.debug( + 'Processing[%s] event [%s] %s', + params.message.task_id, + type(event).__name__, + event, + ) + if isinstance(event, TaskStatusUpdateEvent): + self._validate_task_id_match(task_id, event.task_id) + event = await active_task.get_task() + logger.debug( + 'Replaced TaskStatusUpdateEvent with Task: %s', event + ) + + if isinstance(event, Task) and ( + params.configuration.return_immediately + or event.status.state + in (TERMINAL_TASK_STATES | INTERRUPTED_TASK_STATES) + ): + self._validate_task_id_match(task_id, event.id) + result = event + # DO break here as it's "return_immediately". + # AgentExecutor will continue to run in the background. + break + + if isinstance(event, Message): + result = event + # Do NOT break here as Message is supposed to be the only + # event in "Message-only" interaction. 
+ # ActiveTask consumer (see active_task.py) validates the event + # stream and raises InvalidAgentResponseError if more events are + # pushed after a Message. + + if result is None: + logger.debug('Missing result for task %s', request_context.task_id) + result = await active_task.get_task() + + if isinstance(result, Task): + result = apply_history_length(result, params.configuration) + + logger.debug( + 'Returning result for task %s: %s', + request_context.task_id, + result, + ) + return result + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_message_send_stream( # noqa: D102 + self, + params: SendMessageRequest, + context: ServerCallContext, + ) -> AsyncGenerator[Event, None]: + active_task, request_context = await self._setup_active_task( + params, context + ) + + task_id = cast('str', request_context.task_id) + + async for event in active_task.subscribe( + request=request_context, + include_initial_task=False, + ): + # Do NOT break here as we rely on AgentExecutor to yield control. 
+ # ActiveTask consumer (see active_task.py) validates the event + # stream and raises InvalidAgentResponseError on misbehaving agents: + # - an event after a Message + # - Message after entering task mode + # - an event after a terminal state + if isinstance(event, Task): + self._validate_task_id_match(task_id, event.id) + yield apply_history_length(event, params.configuration) + else: + yield event + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) + async def on_create_task_push_notification_config( # noqa: D102 + self, + params: TaskPushNotificationConfig, + context: ServerCallContext, + ) -> TaskPushNotificationConfig: + if not self._push_config_store: + raise PushNotificationNotSupportedError + + task_id = params.task_id + task: Task | None = await self.task_store.get(task_id, context) + if not task: + raise TaskNotFoundError + + await self._push_config_store.set_info( + task_id, + params, + context, + ) + + return params + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) + async def on_get_task_push_notification_config( # noqa: D102 + self, + params: GetTaskPushNotificationConfigRequest, + context: ServerCallContext, + ) -> TaskPushNotificationConfig: + if not self._push_config_store: + raise PushNotificationNotSupportedError + + task_id = params.task_id + config_id = params.id + task: Task | None = await self.task_store.get(task_id, context) + if not task: + raise TaskNotFoundError + + push_notification_configs: list[TaskPushNotificationConfig] = ( + await self._push_config_store.get_info(task_id, context) or [] + ) + + for config in push_notification_configs: + if config.id == config_id: + return config + + 
raise TaskNotFoundError + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_subscribe_to_task( # noqa: D102 + self, + params: SubscribeToTaskRequest, + context: ServerCallContext, + ) -> AsyncGenerator[Event, None]: + task_id = params.id + + active_task = await self._active_task_registry.get_or_create( + task_id, + call_context=context, + create_task_if_missing=False, + ) + + async for event in active_task.subscribe(include_initial_task=True): + yield event + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) + async def on_list_task_push_notification_configs( # noqa: D102 + self, + params: ListTaskPushNotificationConfigsRequest, + context: ServerCallContext, + ) -> ListTaskPushNotificationConfigsResponse: + if not self._push_config_store: + raise PushNotificationNotSupportedError + + task_id = params.task_id + task: Task | None = await self.task_store.get(task_id, context) + if not task: + raise TaskNotFoundError + + push_notification_config_list = await self._push_config_store.get_info( + task_id, context + ) + + return ListTaskPushNotificationConfigsResponse( + configs=push_notification_config_list + ) + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.push_notifications, + error_message='Push notifications are not supported by the agent', + error_type=PushNotificationNotSupportedError, + ) + async def on_delete_task_push_notification_config( # noqa: D102 + self, + params: DeleteTaskPushNotificationConfigRequest, + context: ServerCallContext, + ) -> None: + if not self._push_config_store: + raise PushNotificationNotSupportedError + + task_id = params.task_id + config_id = params.id + task: Task | None = await self.task_store.get(task_id, context) 
+ if not task: + raise TaskNotFoundError + + await self._push_config_store.delete_info(task_id, context, config_id) + + @validate_request_params + @validate( + lambda self: self._agent_card.capabilities.extended_agent_card, + error_message='The agent does not support authenticated extended cards', + ) + async def on_get_extended_agent_card( + self, + params: GetExtendedAgentCardRequest, + context: ServerCallContext, + ) -> AgentCard: + """Default handler for 'GetExtendedAgentCard'. + + Requires `capabilities.extended_agent_card` to be true. + """ + extended_card = self.extended_agent_card + if not extended_card: + raise ExtendedAgentCardNotConfiguredError + + if self.extended_card_modifier: + extended_card = await self.extended_card_modifier( + extended_card, context + ) + + return extended_card diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 60aa41d22..8cd421e93 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -32,10 +32,8 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import a2a_pb2 -from a2a.types.a2a_pb2 import AgentCard from a2a.utils import proto_utils from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError -from a2a.utils.helpers import maybe_await, validate from a2a.utils.proto_utils import validation_errors_to_bad_request @@ -109,30 +107,22 @@ class GrpcHandler(a2a_grpc.A2AServiceServicer): def __init__( self, - agent_card: AgentCard, request_handler: RequestHandler, context_builder: GrpcServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, ): """Initializes the GrpcHandler. Args: - agent_card: The AgentCard describing the agent's capabilities. request_handler: The underlying `RequestHandler` instance to delegate requests to. 
context_builder: The GrpcContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultGrpcContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. """ - self.agent_card = agent_card self.request_handler = request_handler self._context_builder = ( context_builder or DefaultGrpcServerCallContextBuilder() ) - self.card_modifier = card_modifier async def _handle_unary( self, @@ -145,7 +135,6 @@ async def _handle_unary( try: server_context = self._build_call_context(context, request) result = await handler_func(server_context) - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) else: @@ -163,7 +152,6 @@ async def _handle_stream( server_context = self._build_call_context(context, request) async for item in handler_func(server_context): yield item - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) @@ -195,10 +183,6 @@ async def SendStreamingMessage( ) -> AsyncIterable[a2a_pb2.StreamResponse]: """Handles the 'StreamMessage' gRPC method.""" - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> AsyncIterable[a2a_pb2.StreamResponse]: @@ -236,10 +220,6 @@ async def SubscribeToTask( ) -> AsyncIterable[a2a_pb2.StreamResponse]: """Handles the 'SubscribeToTask' gRPC method.""" - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> AsyncIterable[a2a_pb2.StreamResponse]: @@ -278,10 +258,6 @@ async def CreateTaskPushNotificationConfig( ) -> a2a_pb2.TaskPushNotificationConfig: """Handles the 'CreateTaskPushNotificationConfig' gRPC method.""" - @validate( - lambda _: 
self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def _handler( server_context: ServerCallContext, ) -> a2a_pb2.TaskPushNotificationConfig: @@ -376,10 +352,17 @@ async def GetExtendedAgentCard( context: grpc.aio.ServicerContext, ) -> a2a_pb2.AgentCard: """Get the extended agent card for the agent served.""" - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - return card_to_serve + + async def _handler( + server_context: ServerCallContext, + ) -> a2a_pb2.AgentCard: + return await self.request_handler.on_get_extended_agent_card( + request, server_context + ) + + return await self._handle_unary( + request, context, _handler, a2a_pb2.AgentCard() + ) async def abort_context( self, error: A2AError, context: grpc.aio.ServicerContext @@ -437,19 +420,6 @@ async def abort_context( f'Unknown error type: {error}', ) - def _set_extension_metadata( - self, - context: grpc.aio.ServicerContext, - server_context: ServerCallContext, - ) -> None: - if server_context.activated_extensions: - context.set_trailing_metadata( - [ - (HTTP_EXTENSION_HEADER.lower(), e) - for e in sorted(server_context.activated_extensions) - ] - ) - def _build_call_context( self, context: grpc.aio.ServicerContext, diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 23b0f2b95..6fb42098f 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -1,5 +1,6 @@ import functools import inspect +import logging from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Callable @@ -10,8 +11,10 @@ from a2a.server.context import ServerCallContext from a2a.server.events.event_queue import Event from a2a.types.a2a_pb2 import ( + AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, + 
GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, @@ -32,7 +35,7 @@ class RequestHandler(ABC): """A2A request handler interface. This interface defines the methods that an A2A server implementation must - provide to handle incoming JSON-RPC requests. + provide to handle incoming A2A requests from any transport (gRPC, REST, JSON-RPC). """ @abstractmethod @@ -59,7 +62,7 @@ async def on_list_tasks( ) -> ListTasksResponse: """Handles the tasks/list method. - Retrieves all task for an agent. Supports filtering, pagination, + Retrieves all tasks for an agent. Supports filtering, pagination, ordering, limiting the history length, excluding artifacts, etc. Args: @@ -124,10 +127,8 @@ async def on_message_send_stream( Yields: `Event` objects from the agent's execution. - - Raises: - UnsupportedOperationError: By default, if not implemented. """ + # This is needed for typechecker to recognise this method as an async generator. raise UnsupportedOperationError yield @@ -183,9 +184,6 @@ async def on_subscribe_to_task( Yields: `Event` objects from the agent's ongoing execution for the specified task. - - Raises: - UnsupportedOperationError: By default, if not implemented. """ raise UnsupportedOperationError yield @@ -226,6 +224,25 @@ async def on_delete_task_push_notification_config( None """ + @abstractmethod + async def on_get_extended_agent_card( + self, + params: GetExtendedAgentCardRequest, + context: ServerCallContext, + ) -> AgentCard: + """Handles the 'GetExtendedAgentCard' method. + + Retrieves the extended agent card for the agent. + + Args: + params: Parameters for the request. + context: Context provided by the server. + + Returns: + The `AgentCard` object representing the extended properties of the agent. 
+ + """ + def validate_request_params(method: Callable) -> Callable: """Decorator for RequestHandler methods to validate required fields on incoming requests.""" @@ -268,3 +285,128 @@ async def async_wrapper( return await method(self, params, context, *args, **kwargs) return async_wrapper + + +def validate( + expression: Callable[[Any], bool], + error_message: str | None = None, + error_type: type[Exception] = UnsupportedOperationError, +) -> Callable: + """Decorator that validates if a given expression evaluates to True. + + Typically used on class methods to check capabilities or configuration + before executing the method's logic. If the expression is False, + the specified `error_type` (defaults to `UnsupportedOperationError`) is raised. + + Args: + expression: A callable that takes the instance (`self`) as its argument + and returns a boolean. + error_message: An optional custom error message for the error raised. + If None, the string representation of the expression will be used. + error_type: The exception class to raise on validation failure. + Must take a `message` keyword argument (inherited from A2AError). + + Examples: + Demonstrating with an async method: + >>> import asyncio + >>> from a2a.utils.errors import UnsupportedOperationError + >>> + >>> class MyAgent: + ... def __init__(self, streaming_enabled: bool): + ... self.streaming_enabled = streaming_enabled + ... + ... @validate( + ... lambda self: self.streaming_enabled, + ... 'Streaming is not enabled for this agent', + ... ) + ... async def stream_response(self, message: str): + ... return f'Streaming: {message}' + >>> + >>> async def run_async_test(): + ... # Successful call + ... agent_ok = MyAgent(streaming_enabled=True) + ... result = await agent_ok.stream_response('hello') + ... print(result) + ... + ... # Call that fails validation + ... agent_fail = MyAgent(streaming_enabled=False) + ... try: + ... await agent_fail.stream_response('world') + ... 
except UnsupportedOperationError as e: + ... print(e.message) + >>> + >>> asyncio.run(run_async_test()) + Streaming: hello + Streaming is not enabled for this agent + + Demonstrating with a sync method: + >>> class SecureAgent: + ... def __init__(self): + ... self.auth_enabled = False + ... + ... @validate( + ... lambda self: self.auth_enabled, + ... 'Authentication must be enabled for this operation', + ... ) + ... def secure_operation(self, data: str): + ... return f'Processing secure data: {data}' + >>> + >>> # Error case example + >>> agent = SecureAgent() + >>> try: + ... agent.secure_operation('secret') + ... except UnsupportedOperationError as e: + ... print(e.message) + Authentication must be enabled for this operation + + Note: + This decorator works with both sync and async methods automatically. + """ + + def decorator(function: Callable) -> Callable: + if inspect.isasyncgenfunction(function): + + @functools.wraps(function) + async def async_gen_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logging.getLogger(__name__).error( + 'Validation failure: %s', final_message + ) + raise error_type(final_message) + inner = function(self, *args, **kwargs) + try: + async for item in inner: + yield item + finally: + await inner.aclose() + + return async_gen_wrapper + + if inspect.iscoroutinefunction(function): + + @functools.wraps(function) + async def async_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logging.getLogger(__name__).error( + 'Validation failure: %s', final_message + ) + raise error_type(final_message) + return await function(self, *args, **kwargs) + + return async_wrapper + + @functools.wraps(function) + def sync_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logging.getLogger(__name__).error( + 'Validation failure: %s', 
final_message + ) + raise error_type(final_message) + return function(self, *args, **kwargs) + + return sync_wrapper + + return decorator diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py index 9b850ff4f..924a3d9dc 100644 --- a/src/a2a/server/routes/agent_card_routes.py +++ b/src/a2a/server/routes/agent_card_routes.py @@ -26,13 +26,11 @@ from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.types.a2a_pb2 import AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH -from a2a.utils.helpers import maybe_await def create_agent_card_routes( agent_card: AgentCard, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard]] | None = None, card_url: str = AGENT_CARD_WELL_KNOWN_PATH, ) -> list['Route']: """Creates the Starlette Route for the A2A protocol agent card endpoint.""" @@ -45,7 +43,7 @@ def create_agent_card_routes( async def _get_agent_card(request: Request) -> Response: card_to_serve = agent_card if card_modifier: - card_to_serve = await maybe_await(card_modifier(card_to_serve)) + card_to_serve = await card_modifier(card_to_serve) return JSONResponse(agent_card_to_dict(card_to_serve)) return [ diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index e0f0042b0..cb4e93bf1 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -4,17 +4,15 @@ import logging import traceback -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any from google.protobuf.json_format import MessageToDict, ParseDict from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response from a2a.compat.v0_3.jsonrpc_adapter import JSONRPC03Adapter -from a2a.extensions.common import ( - HTTP_EXTENSION_HEADER, -) from 
a2a.server.context import ServerCallContext +from a2a.server.events import Event from a2a.server.jsonrpc_models import ( InternalError, InvalidParamsError, @@ -32,7 +30,6 @@ ServerCallContextBuilder, ) from a2a.types.a2a_pb2 import ( - AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, GetExtendedAgentCardRequest, @@ -49,12 +46,11 @@ from a2a.utils import constants, proto_utils from a2a.utils.errors import ( A2AError, - ExtendedAgentCardNotConfiguredError, TaskNotFoundError, UnsupportedOperationError, ) -from a2a.utils.helpers import maybe_await, validate, validate_version from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils.version_validator import validate_version INTERNAL_ERROR_CODE = -32603 @@ -130,36 +126,20 @@ class JsonRpcDispatcher: 'GetExtendedAgentCard': GetExtendedAgentCardRequest, } - def __init__( # noqa: PLR0913 + def __init__( self, - agent_card: AgentCard, request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, context_builder: ServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, enable_v0_3_compat: bool = False, ) -> None: """Initializes the JsonRpcDispatcher. Args: - agent_card: The AgentCard describing the agent's capabilities. request_handler: The handler instance responsible for processing A2A requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultServerCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. 
- extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. """ if not _package_starlette_installed: @@ -169,11 +149,7 @@ def __init__( # noqa: PLR0913 ' optional dependencies, `a2a-sdk[http-server]`.' ) - self.agent_card = agent_card self.request_handler = request_handler - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier self._context_builder = ( context_builder or DefaultServerCallContextBuilder() ) @@ -182,12 +158,8 @@ def __init__( # noqa: PLR0913 if self.enable_v0_3_compat: self._v03_adapter = JSONRPC03Adapter( - agent_card=agent_card, http_handler=request_handler, - extended_agent_card=extended_agent_card, context_builder=self._context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, ) def _generate_error_response( @@ -333,6 +305,9 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, call_context.state['request_id'] = request_id # Route streaming requests by method name + handler_result: ( + AsyncGenerator[dict[str, Any], None] | dict[str, Any] + ) if method in ('SendStreamingMessage', 'SubscribeToTask'): handler_result = await self._process_streaming_request( request_id, specific_request, call_context @@ -369,10 +344,6 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, ) @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _process_streaming_request( self, request_id: str | int | None, @@ -403,20 +374,32 @@ async def _process_streaming_request( if stream is None: raise UnsupportedOperationError(message='Stream not supported') + # Eagerly fetch the first event to trigger 
validation/upfront errors + try: + first_event = await anext(stream) + except StopAsyncIteration: + first_event = None + async def _wrap_stream( - st: AsyncGenerator, + st: AsyncGenerator, first_evt: Event | None ) -> AsyncGenerator[dict[str, Any], None]: + def _map_event(evt: Event) -> dict[str, Any]: + stream_response = proto_utils.to_stream_response(evt) + result = MessageToDict( + stream_response, preserving_proto_field_name=False + ) + return JSONRPC20Response(result=result, _id=request_id).data + try: + if first_evt is not None: + yield _map_event(first_evt) + async for event in st: - stream_response = proto_utils.to_stream_response(event) - result = MessageToDict( - stream_response, preserving_proto_field_name=False - ) - yield JSONRPC20Response(result=result, _id=request_id).data + yield _map_event(event) except A2AError as e: yield build_error_response(request_id, e) - return _wrap_stream(stream) + return _wrap_stream(stream, first_event) async def _handle_send_message( self, request_obj: SendMessageRequest, context: ServerCallContext @@ -456,10 +439,6 @@ async def _handle_list_tasks( always_print_fields_with_no_presence=True, ) - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def _handle_create_task_push_notification_config( self, request_obj: TaskPushNotificationConfig, @@ -512,20 +491,10 @@ async def _handle_get_extended_agent_card( request_obj: GetExtendedAgentCardRequest, context: ServerCallContext, ) -> dict[str, Any]: - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='The agent does not have an extended agent card configured' - ) - base_card = self.extended_agent_card or self.agent_card - card_to_serve = base_card - if self.extended_card_modifier and context: - card_to_serve = await maybe_await( - self.extended_card_modifier(base_card, context) - ) - elif self.card_modifier: - card_to_serve = 
await maybe_await(self.card_modifier(base_card)) - - return MessageToDict(card_to_serve, preserving_proto_field_name=False) + card = await self.request_handler.on_get_extended_agent_card( + request_obj, context + ) + return MessageToDict(card, preserving_proto_field_name=False) @validate_version(constants.PROTOCOL_VERSION_1_0) async def _process_non_streaming_request( # noqa: PLR0911 @@ -598,20 +567,37 @@ def _create_response( Returns: A Starlette JSONResponse or EventSourceResponse. """ - headers = {} - if exts := context.activated_extensions: - headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts)) if isinstance(handler_result, AsyncGenerator): # Result is a stream of dict objects async def event_generator( stream: AsyncGenerator[dict[str, Any]], ) -> AsyncGenerator[dict[str, str]]: - async for item in stream: - yield {'data': json.dumps(item)} + try: + async for item in stream: + event: dict[str, str] = { + 'data': json.dumps(item), + } + if 'error' in item: + event['event'] = 'error' + yield event + except Exception as e: + logger.exception( + 'Unhandled error during JSON-RPC SSE stream' + ) + rpc_error: A2AError | JSONRPCError = ( + e + if isinstance(e, A2AError | JSONRPCError) + else InternalError(message=str(e)) + ) + error_response = build_error_response( + context.state.get('request_id'), rpc_error + ) + yield { + 'event': 'error', + 'data': json.dumps(error_response), + } - return EventSourceResponse( - event_generator(handler_result), headers=headers - ) + return EventSourceResponse(event_generator(handler_result)) # handler_result is a dict (JSON-RPC response) - return JSONResponse(handler_result, headers=headers) + return JSONResponse(handler_result) diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index f19625379..a94d513ae 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -1,4 +1,5 @@ -from collections.abc import Awaitable, Callable +import logging + 
from typing import TYPE_CHECKING, Any @@ -16,26 +17,18 @@ _package_starlette_installed = False - -from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.routes.common import ServerCallContextBuilder from a2a.server.routes.jsonrpc_dispatcher import JsonRpcDispatcher -from a2a.types.a2a_pb2 import AgentCard -def create_jsonrpc_routes( # noqa: PLR0913 - agent_card: AgentCard, +logger = logging.getLogger(__name__) + + +def create_jsonrpc_routes( request_handler: RequestHandler, rpc_url: str, - extended_agent_card: AgentCard | None = None, context_builder: ServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, enable_v0_3_compat: bool = False, ) -> list['Route']: """Creates the Starlette Route for the A2A protocol JSON-RPC endpoint. @@ -45,20 +38,12 @@ def create_jsonrpc_routes( # noqa: PLR0913 (SSE). Args: - agent_card: The AgentCard describing the agent's capabilities. request_handler: The handler instance responsible for processing A2A requests via http. - rpc_url: The URL prefix for the RPC endpoints. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. + rpc_url: The URL prefix for the RPC endpoints. Should start with a leading slash '/'. context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultServerCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. 
enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. """ if not _package_starlette_installed: @@ -69,12 +54,8 @@ def create_jsonrpc_routes( # noqa: PLR0913 ) dispatcher = JsonRpcDispatcher( - agent_card=agent_card, request_handler=request_handler, - extended_agent_card=extended_agent_card, context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, enable_v0_3_compat=enable_v0_3_compat, ) diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py index 1f91dd573..adbdba96e 100644 --- a/src/a2a/server/routes/rest_dispatcher.py +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -14,26 +14,26 @@ ) from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( - AgentCard, CancelTaskRequest, GetTaskPushNotificationConfigRequest, SubscribeToTaskRequest, ) from a2a.utils import constants, proto_utils from a2a.utils.error_handlers import ( + build_rest_error_payload, rest_error_handler, rest_stream_error_handler, ) from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, InvalidRequestError, TaskNotFoundError, ) -from a2a.utils.helpers import maybe_await, validate, validate_version from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils.version_validator import validate_version if TYPE_CHECKING: + from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -41,6 +41,7 @@ _package_starlette_installed = True else: try: + from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -48,6 +49,7 @@ _package_starlette_installed = True except ImportError: EventSourceResponse = Any + ServerSentEvent = Any Request = Any JSONResponse = Any Response = Any @@ -66,34 
+68,18 @@ class RestDispatcher: Handles context building, routing to RequestHandler directly, and response formatting (JSON/SSE). """ - def __init__( # noqa: PLR0913 + def __init__( self, - agent_card: AgentCard, request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, context_builder: ServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, ) -> None: """Initializes the RestDispatcher. Args: - agent_card: The AgentCard describing the agent's capabilities. request_handler: The underlying `RequestHandler` instance to delegate requests to. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultServerCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. """ if not _package_starlette_installed: raise ImportError( @@ -102,10 +88,6 @@ def __init__( # noqa: PLR0913 'optional dependencies, `a2a-sdk[http-server]`.' 
) - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier self._context_builder = ( context_builder or DefaultServerCallContextBuilder() ) @@ -157,10 +139,17 @@ async def _handle_streaming( except StopAsyncIteration: return EventSourceResponse(iter([])) - async def event_generator() -> AsyncIterator[str]: - yield json.dumps(first_item) - async for item in stream: - yield json.dumps(item) + async def event_generator() -> AsyncIterator[ServerSentEvent]: + yield ServerSentEvent(data=json.dumps(first_item)) + try: + async for item in stream: + yield ServerSentEvent(data=json.dumps(item)) + except Exception as e: + logger.exception('Error during REST SSE stream') + yield ServerSentEvent( + data=json.dumps(build_rest_error_payload(e)), + event='error', + ) return EventSourceResponse(event_generator()) @@ -192,10 +181,6 @@ async def on_message_send_stream( """Handles the 'message/stream' REST method.""" @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( context: ServerCallContext, ) -> AsyncIterator[dict[str, Any]]: @@ -235,10 +220,6 @@ async def on_subscribe_to_task( task_id = request.path_params['id'] @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda _: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def _handler( context: ServerCallContext, ) -> AsyncIterator[dict[str, Any]]: @@ -312,10 +293,6 @@ async def set_push_notification(self, request: Request) -> Response: """Handles the 'tasks/pushNotificationConfig/set' REST method.""" @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda _: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def _handler( context: ServerCallContext, ) 
-> a2a_pb2.TaskPushNotificationConfig: @@ -371,23 +348,16 @@ async def _handler( async def handle_authenticated_agent_card( self, request: Request ) -> Response: - """Handles the 'extendedAgentCard' REST method.""" - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card or self.agent_card + """Handles the 'agentCard' REST method.""" - if self.extended_card_modifier: - context = self._build_call_context(request) - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, context) + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.AgentCard: + params = a2a_pb2.GetExtendedAgentCardRequest() + return await self.request_handler.on_get_extended_agent_card( + params, context ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - return JSONResponse( - content=MessageToDict( - card_to_serve, preserving_proto_field_name=True - ) - ) + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 20a899ca4..2ba8cecfc 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -1,16 +1,11 @@ import logging -from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any from a2a.compat.v0_3.rest_adapter import REST03Adapter -from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.routes.common import ServerCallContextBuilder from a2a.server.routes.rest_dispatcher import RestDispatcher -from a2a.types.a2a_pb2 import ( - AgentCard, -) if TYPE_CHECKING: @@ -32,36 +27,20 @@ logger = logging.getLogger(__name__) -def 
create_rest_routes( # noqa: PLR0913 - agent_card: AgentCard, +def create_rest_routes( request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, context_builder: ServerCallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, enable_v0_3_compat: bool = False, path_prefix: str = '', ) -> list['BaseRoute']: """Creates the Starlette Routes for the A2A protocol REST endpoint. Args: - agent_card: The AgentCard describing the agent's capabilities. request_handler: The handler instance responsible for processing A2A requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. context_builder: The ServerCallContextBuilder used to construct the ServerCallContext passed to the request_handler. If None the DefaultServerCallContextBuilder is used. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. enable_v0_3_compat: If True, mounts backward-compatible v0.3 protocol endpoints using REST03Adapter. path_prefix: The URL prefix for the REST endpoints. 
@@ -74,23 +53,15 @@ def create_rest_routes( # noqa: PLR0913 ) dispatcher = RestDispatcher( - agent_card=agent_card, request_handler=request_handler, - extended_agent_card=extended_agent_card, context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, ) routes: list[BaseRoute] = [] if enable_v0_3_compat: v03_adapter = REST03Adapter( - agent_card=agent_card, http_handler=request_handler, - extended_agent_card=extended_agent_card, context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, ) v03_routes = v03_adapter.routes() for (path, method), endpoint in v03_routes.items(): diff --git a/src/a2a/server/tasks/base_push_notification_sender.py b/src/a2a/server/tasks/base_push_notification_sender.py index 4a4929e8f..ff9ca3ce5 100644 --- a/src/a2a/server/tasks/base_push_notification_sender.py +++ b/src/a2a/server/tasks/base_push_notification_sender.py @@ -27,26 +27,39 @@ def __init__( self, httpx_client: httpx.AsyncClient, config_store: PushNotificationConfigStore, - context: ServerCallContext, + context: ServerCallContext | None = None, ) -> None: """Initializes the BasePushNotificationSender. Args: httpx_client: An async HTTP client instance to send notifications. - config_store: A PushNotificationConfigStore instance to retrieve configurations. - context: The `ServerCallContext` that this push notification is produced under. + config_store: A PushNotificationConfigStore instance to + retrieve configurations. + context: Deprecated and ignored. Accepted only for + backward compatibility with 1.0 callers that constructed + the sender with a (typically dummy) ServerCallContext. + Pass None (the default) in new code. A non-None + value logs a deprecation warning and is otherwise + ignored. 
""" + if context is not None: + logger.warning( + 'BasePushNotificationSender no longer uses the context ' + 'parameter; it is accepted only for backward compatibility ' + 'with 1.0 and will be removed in a future major version. ' + 'Push notifications now fan out across all owners via ' + 'PushNotificationConfigStore.get_info_for_dispatch; the ' + 'caller identity is not carried into dispatch. Drop the ' + 'context argument from the constructor call.' + ) self._client = httpx_client self._config_store = config_store - self._call_context: ServerCallContext = context async def send_notification( self, task_id: str, event: PushNotificationEvent ) -> None: """Sends a push notification for an event if configuration exists.""" - push_configs = await self._config_store.get_info( - task_id, self._call_context - ) + push_configs = await self._config_store.get_info_for_dispatch(task_id) if not push_configs: return diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index 31cd676c8..d050de7cc 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -7,7 +7,7 @@ try: - from sqlalchemy import Table, and_, delete, select + from sqlalchemy import ColumnElement, Table, and_, delete, select from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -304,21 +304,14 @@ async def set_info( owner, ) - async def get_info( + async def _select_configs( self, - task_id: str, - context: ServerCallContext, + *predicates: 'ColumnElement[bool]', ) -> list[TaskPushNotificationConfig]: - """Retrieves all push notification configurations for a task, for the given owner.""" + """Loads configs matching the given predicates and decodes them.""" await self._ensure_initialized() - owner = self.owner_resolver(context) async with self.async_session_maker() as session: - stmt = select(self.config_model).where( - 
and_( - self.config_model.task_id == task_id, - self.config_model.owner == owner, - ) - ) + stmt = select(self.config_model).where(and_(*predicates)) result = await session.execute(stmt) models = result.scalars().all() @@ -331,10 +324,37 @@ async def get_info( 'Could not deserialize push notification config for task %s, config %s, owner %s', model.task_id, model.config_id, - owner, + model.owner, ) return configs + async def get_info( + self, + task_id: str, + context: ServerCallContext, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task, for the given owner. + + Used by the user-callable read endpoints. + """ + owner = self.owner_resolver(context) + return await self._select_configs( + self.config_model.task_id == task_id, + self.config_model.owner == owner, + ) + + async def get_info_for_dispatch( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task, across all owners. + + Used by the push-notification dispatch path. + """ + return await self._select_configs( + self.config_model.task_id == task_id, + ) + async def delete_info( self, task_id: str, diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index d5b0a5b1f..19e35074a 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -72,12 +72,29 @@ async def get_info( task_id: str, context: ServerCallContext, ) -> list[TaskPushNotificationConfig]: - """Retrieves all push notification configurations for a task from memory, for the given owner.""" + """Retrieves all push notification configurations for a task from memory, for the given owner. + + Used by the user-callable read endpoints. 
+ """ owner = self.owner_resolver(context) async with self.lock: owner_infos = self._get_owner_push_notification_infos(owner) return list(owner_infos.get(task_id, [])) + async def get_info_for_dispatch( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task across all owners. + + Used by the push-notification dispatch path. + """ + async with self.lock: + results: list[TaskPushNotificationConfig] = [] + for all_configs in self._push_notification_infos.values(): + results.extend(all_configs.get(task_id, [])) + return results + async def delete_info( self, task_id: str, diff --git a/src/a2a/server/tasks/push_notification_config_store.py b/src/a2a/server/tasks/push_notification_config_store.py index 6b5b35245..e1e65c3fb 100644 --- a/src/a2a/server/tasks/push_notification_config_store.py +++ b/src/a2a/server/tasks/push_notification_config_store.py @@ -1,9 +1,14 @@ +import logging + from abc import ABC, abstractmethod from a2a.server.context import ServerCallContext from a2a.types.a2a_pb2 import TaskPushNotificationConfig +logger = logging.getLogger(__name__) + + class PushNotificationConfigStore(ABC): """Interface for storing and retrieving push notification configurations for tasks.""" @@ -22,7 +27,46 @@ async def get_info( task_id: str, context: ServerCallContext, ) -> list[TaskPushNotificationConfig]: - """Retrieves the push notification configuration for a task.""" + """Retrieves push notification configurations for a task, scoped to the caller. + + This is the user-callable read path. Implementations MUST return + only configurations owned by the caller (as resolved from + context). + """ + + async def get_info_for_dispatch( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task, across all owners. + + This is the internal read path used by the push-notification + dispatch loop. 
Implementations SHOULD override this method to + return every configuration registered for task_id regardless of + which user registered it. Authorization already happened at + registration time and the dispatch path fires every registered + webhook for the task. + + The default implementation falls back to calling get_info with + a synthetic empty ServerCallContext. This preserves 1.0 + behavior for subclasses that have not implemented the override + but is INCORRECT for any deployment with multiple owners: the + empty context resolves to the empty-string owner partition and + returns no configs (silently dropping every notification). A + warning is logged on every call to flag the misconfiguration. + Custom subclasses MUST override this method to deliver + notifications correctly in multi-owner deployments. + """ + logger.warning( + '%s does not override ' + 'PushNotificationConfigStore.get_info_for_dispatch; falling back ' + 'to a context-less get_info call which silently drops ' + 'notifications in any deployment with multiple owners. 
Override ' + 'get_info_for_dispatch to return all configs for task_id across ' + 'every owner.', + type(self).__name__, + ) + return await self.get_info(task_id, ServerCallContext()) @abstractmethod async def delete_info( diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index 905b11af3..e5d899c1e 100644 --- a/src/a2a/server/tasks/task_manager.py +++ b/src/a2a/server/tasks/task_manager.py @@ -4,6 +4,7 @@ from a2a.server.events.event_queue import Event from a2a.server.tasks.task_store import TaskStore from a2a.types.a2a_pb2 import ( + Artifact, Message, Task, TaskArtifactUpdateEvent, @@ -11,13 +12,77 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import append_artifact_to_task from a2a.utils.errors import InvalidParamsError +from a2a.utils.telemetry import trace_function logger = logging.getLogger(__name__) +@trace_function() +def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None: + """Helper method for updating a Task object with new artifact data from an event. + + Handles creating the artifacts list if it doesn't exist, adding new artifacts, + and appending parts to existing artifacts based on the `append` flag in the event. + + Args: + task: The `Task` object to modify. + event: The `TaskArtifactUpdateEvent` containing the artifact data. + """ + new_artifact_data: Artifact = event.artifact + artifact_id: str = new_artifact_data.artifact_id + append_parts: bool = event.append + + existing_artifact: Artifact | None = None + existing_artifact_list_index: int | None = None + + # Find existing artifact by its id + for i, art in enumerate(task.artifacts): + if art.artifact_id == artifact_id: + existing_artifact = art + existing_artifact_list_index = i + break + + if not append_parts: + # This represents the first chunk for this artifact index. 
+ if existing_artifact_list_index is not None: + # Replace the existing artifact entirely with the new data + logger.debug( + 'Replacing artifact at id %s for task %s', artifact_id, task.id + ) + task.artifacts[existing_artifact_list_index].CopyFrom( + new_artifact_data + ) + else: + # Append the new artifact since no artifact with this index exists yet + logger.debug( + 'Adding new artifact with id %s for task %s', + artifact_id, + task.id, + ) + task.artifacts.append(new_artifact_data) + elif existing_artifact: + # Append new parts to the existing artifact's part list + logger.debug( + 'Appending parts to artifact id %s for task %s', + artifact_id, + task.id, + ) + existing_artifact.parts.extend(new_artifact_data.parts) + existing_artifact.metadata.update( + dict(new_artifact_data.metadata.items()) + ) + else: + # We received a chunk to append, but we don't have an existing artifact. + # we will ignore this chunk + logger.warning( + 'Received append=True for nonexistent artifact index %s in task %s. Ignoring chunk.', + artifact_id, + task.id, + ) + + class TaskManager: """Helps manage a task's lifecycle during execution of a request. @@ -147,13 +212,12 @@ async def save_task_event( await self._save_task(task) return task - async def ensure_task( - self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent - ) -> Task: + async def ensure_task_id(self, task_id: str, context_id: str) -> Task: """Ensures a Task object exists in memory, loading from store or creating new if needed. Args: - event: The task-related event triggering the need for a Task object. + task_id: The ID for the new task. + context_id: The context ID for the new task. Returns: An existing or newly created `Task` object. @@ -168,16 +232,29 @@ async def ensure_task( if not task: logger.info( 'Task not found or task_id not set. 
Creating new task for event (task_id: %s, context_id: %s).', - event.task_id, - event.context_id, + task_id, + context_id, ) # streaming agent did not previously stream task object. # Create a task object with the available information and persist the event - task = self._init_task_obj(event.task_id, event.context_id) + task = self._init_task_obj(task_id, context_id) await self._save_task(task) return task + async def ensure_task( + self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ) -> Task: + """Ensures a Task object exists in memory, loading from store or creating new if needed. + + Args: + event: The task-related event triggering the need for a Task object. + + Returns: + An existing or newly created `Task` object. + """ + return await self.ensure_task_id(event.task_id, event.context_id) + async def process(self, event: Event) -> Event: """Processes an event, updates the task state if applicable, stores it, and returns the event. diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index a502bfb62..04693dd0b 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -1,60 +1,18 @@ """Utility functions for the A2A Python SDK.""" from a2a.utils import proto_utils -from a2a.utils.artifact import ( - get_artifact_text, - new_artifact, - new_data_artifact, - new_text_artifact, -) from a2a.utils.constants import ( AGENT_CARD_WELL_KNOWN_PATH, DEFAULT_RPC_URL, TransportProtocol, ) -from a2a.utils.helpers import ( - append_artifact_to_task, - are_modalities_compatible, - build_text_artifact, - create_task_obj, -) -from a2a.utils.message import ( - get_message_text, - new_agent_parts_message, - new_agent_text_message, -) -from a2a.utils.parts import ( - get_data_parts, - get_file_parts, - get_text_parts, -) from a2a.utils.proto_utils import to_stream_response -from a2a.utils.task import ( - completed_task, - new_task, -) __all__ = [ 'AGENT_CARD_WELL_KNOWN_PATH', 'DEFAULT_RPC_URL', 'TransportProtocol', - 'append_artifact_to_task', 
- 'are_modalities_compatible', - 'build_text_artifact', - 'completed_task', - 'create_task_obj', - 'get_artifact_text', - 'get_data_parts', - 'get_file_parts', - 'get_message_text', - 'get_text_parts', - 'new_agent_parts_message', - 'new_agent_text_message', - 'new_artifact', - 'new_data_artifact', - 'new_task', - 'new_text_artifact', 'proto_utils', 'to_stream_response', ] diff --git a/src/a2a/utils/artifact.py b/src/a2a/utils/artifact.py deleted file mode 100644 index ac14087dc..000000000 --- a/src/a2a/utils/artifact.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Utility functions for creating A2A Artifact objects.""" - -import uuid - -from typing import Any - -from google.protobuf.struct_pb2 import Struct, Value - -from a2a.types.a2a_pb2 import Artifact, Part -from a2a.utils.parts import get_text_parts - - -def new_artifact( - parts: list[Part], - name: str, - description: str | None = None, -) -> Artifact: - """Creates a new Artifact object. - - Args: - parts: The list of `Part` objects forming the artifact's content. - name: The human-readable name of the artifact. - description: An optional description of the artifact. - - Returns: - A new `Artifact` object with a generated artifact_id. - """ - return Artifact( - artifact_id=str(uuid.uuid4()), - parts=parts, - name=name, - description=description, - ) - - -def new_text_artifact( - name: str, - text: str, - description: str | None = None, -) -> Artifact: - """Creates a new Artifact object containing only a single text Part. - - Args: - name: The human-readable name of the artifact. - text: The text content of the artifact. - description: An optional description of the artifact. - - Returns: - A new `Artifact` object with a generated artifact_id. - """ - return new_artifact( - [Part(text=text)], - name, - description, - ) - - -def new_data_artifact( - name: str, - data: dict[str, Any], - description: str | None = None, -) -> Artifact: - """Creates a new Artifact object containing only a single data Part. 
- - Args: - name: The human-readable name of the artifact. - data: The structured data content of the artifact. - description: An optional description of the artifact. - - Returns: - A new `Artifact` object with a generated artifact_id. - """ - struct_data = Struct() - struct_data.update(data) - return new_artifact( - [Part(data=Value(struct_value=struct_data))], - name, - description, - ) - - -def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str: - """Extracts and joins all text content from an Artifact's parts. - - Args: - artifact: The `Artifact` object. - delimiter: The string to use when joining text from multiple TextParts. - - Returns: - A single string containing all text content, or an empty string if no text parts are found. - """ - return delimiter.join(get_text_parts(artifact.parts)) diff --git a/src/a2a/utils/error_handlers.py b/src/a2a/utils/error_handlers.py index d21a9e24c..ea544d79d 100644 --- a/src/a2a/utils/error_handlers.py +++ b/src/a2a/utils/error_handlers.py @@ -54,16 +54,43 @@ def _build_error_payload( return {'error': payload} -def _create_error_response(error: Exception) -> Response: - """Helper function to create a JSONResponse for an error.""" +def build_rest_error_payload(error: Exception) -> dict[str, Any]: + """Build a REST error payload dict from an exception. + + Returns: + A dict with the error payload in the standard REST error format. + """ if isinstance(error, A2AError): mapping = A2A_REST_ERROR_MAPPING.get( type(error), RestErrorMap(500, 'INTERNAL', 'INTERNAL_ERROR') ) - http_code = mapping.http_code - grpc_status = mapping.grpc_status - reason = mapping.reason + # SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response. 
+ metadata = getattr(error, 'data', None) or {} + return _build_error_payload( + code=mapping.http_code, + status=mapping.grpc_status, + message=getattr(error, 'message', str(error)), + reason=mapping.reason, + metadata=metadata, + ) + if isinstance(error, ParseError): + return _build_error_payload( + code=400, + status='INVALID_ARGUMENT', + message=str(error), + reason='INVALID_REQUEST', + metadata={}, + ) + return _build_error_payload( + code=500, + status='INTERNAL', + message='unknown exception', + ) + +def _create_error_response(error: Exception) -> Response: + """Helper function to create a JSONResponse for an error.""" + if isinstance(error, A2AError): log_level = ( logging.ERROR if isinstance(error, InternalError) @@ -76,42 +103,17 @@ def _create_error_response(error: Exception) -> Response: getattr(error, 'message', str(error)), f', Data={error.data}' if error.data else '', ) - - # SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response. 
- metadata = getattr(error, 'data', None) or {} - - return JSONResponse( - content=_build_error_payload( - code=http_code, - status=grpc_status, - message=getattr(error, 'message', str(error)), - reason=reason, - metadata=metadata, - ), - status_code=http_code, - media_type='application/json', - ) - if isinstance(error, ParseError): + elif isinstance(error, ParseError): logger.warning('Parse error: %s', str(error)) - return JSONResponse( - content=_build_error_payload( - code=400, - status='INVALID_ARGUMENT', - message=str(error), - reason='INVALID_REQUEST', - metadata={}, - ), - status_code=400, - media_type='application/json', - ) - logger.exception('Unknown error occurred') + else: + logger.exception('Unknown error occurred') + + payload = build_rest_error_payload(error) + # Extract HTTP status code from the payload + http_code = payload.get('error', {}).get('code', 500) return JSONResponse( - content=_build_error_payload( - code=500, - status='INTERNAL', - message='unknown exception', - ), - status_code=500, + content=payload, + status_code=http_code, media_type='application/json', ) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py deleted file mode 100644 index badfde180..000000000 --- a/src/a2a/utils/helpers.py +++ /dev/null @@ -1,409 +0,0 @@ -"""General utility functions for the A2A Python SDK.""" - -import functools -import inspect -import json -import logging - -from collections.abc import AsyncIterator, Awaitable, Callable -from typing import Any, TypeVar, cast -from uuid import uuid4 - -from google.protobuf.json_format import MessageToDict -from packaging.version import InvalidVersion, Version - -from a2a.server.context import ServerCallContext -from a2a.types.a2a_pb2 import ( - AgentCard, - Artifact, - Part, - SendMessageRequest, - Task, - TaskArtifactUpdateEvent, - TaskState, - TaskStatus, -) -from a2a.utils import constants -from a2a.utils.errors import UnsupportedOperationError, VersionNotSupportedError -from a2a.utils.telemetry 
import trace_function - - -T = TypeVar('T') -F = TypeVar('F', bound=Callable[..., Any]) - - -logger = logging.getLogger(__name__) - - -@trace_function() -def create_task_obj(message_send_params: SendMessageRequest) -> Task: - """Create a new task object from message send params. - - Generates UUIDs for task and context IDs if they are not already present in the message. - - Args: - message_send_params: The `SendMessageRequest` object containing the initial message. - - Returns: - A new `Task` object initialized with 'submitted' status and the input message in history. - """ - if not message_send_params.message.context_id: - message_send_params.message.context_id = str(uuid4()) - - task = Task( - id=str(uuid4()), - context_id=message_send_params.message.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - task.history.append(message_send_params.message) - return task - - -@trace_function() -def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None: - """Helper method for updating a Task object with new artifact data from an event. - - Handles creating the artifacts list if it doesn't exist, adding new artifacts, - and appending parts to existing artifacts based on the `append` flag in the event. - - Args: - task: The `Task` object to modify. - event: The `TaskArtifactUpdateEvent` containing the artifact data. - """ - new_artifact_data: Artifact = event.artifact - artifact_id: str = new_artifact_data.artifact_id - append_parts: bool = event.append - - existing_artifact: Artifact | None = None - existing_artifact_list_index: int | None = None - - # Find existing artifact by its id - for i, art in enumerate(task.artifacts): - if art.artifact_id == artifact_id: - existing_artifact = art - existing_artifact_list_index = i - break - - if not append_parts: - # This represents the first chunk for this artifact index. 
- if existing_artifact_list_index is not None: - # Replace the existing artifact entirely with the new data - logger.debug( - 'Replacing artifact at id %s for task %s', artifact_id, task.id - ) - task.artifacts[existing_artifact_list_index].CopyFrom( - new_artifact_data - ) - else: - # Append the new artifact since no artifact with this index exists yet - logger.debug( - 'Adding new artifact with id %s for task %s', - artifact_id, - task.id, - ) - task.artifacts.append(new_artifact_data) - elif existing_artifact: - # Append new parts to the existing artifact's part list - logger.debug( - 'Appending parts to artifact id %s for task %s', - artifact_id, - task.id, - ) - existing_artifact.parts.extend(new_artifact_data.parts) - else: - # We received a chunk to append, but we don't have an existing artifact. - # we will ignore this chunk - logger.warning( - 'Received append=True for nonexistent artifact index %s in task %s. Ignoring chunk.', - artifact_id, - task.id, - ) - - -def build_text_artifact(text: str, artifact_id: str) -> Artifact: - """Helper to create a text artifact. - - Args: - text: The text content for the artifact. - artifact_id: The ID for the artifact. - - Returns: - An `Artifact` object containing a single text Part. - """ - part = Part(text=text) - return Artifact(parts=[part], artifact_id=artifact_id) - - -def validate( - expression: Callable[[Any], bool], error_message: str | None = None -) -> Callable: - """Decorator that validates if a given expression evaluates to True. - - Typically used on class methods to check capabilities or configuration - before executing the method's logic. If the expression is False, - an `UnsupportedOperationError` is raised. - - Args: - expression: A callable that takes the instance (`self`) as its argument - and returns a boolean. - error_message: An optional custom error message for the `UnsupportedOperationError`. - If None, the string representation of the expression will be used. 
- - Examples: - Demonstrating with an async method: - >>> import asyncio - >>> from a2a.utils.errors import UnsupportedOperationError - >>> - >>> class MyAgent: - ... def __init__(self, streaming_enabled: bool): - ... self.streaming_enabled = streaming_enabled - ... - ... @validate( - ... lambda self: self.streaming_enabled, - ... 'Streaming is not enabled for this agent', - ... ) - ... async def stream_response(self, message: str): - ... return f'Streaming: {message}' - >>> - >>> async def run_async_test(): - ... # Successful call - ... agent_ok = MyAgent(streaming_enabled=True) - ... result = await agent_ok.stream_response('hello') - ... print(result) - ... - ... # Call that fails validation - ... agent_fail = MyAgent(streaming_enabled=False) - ... try: - ... await agent_fail.stream_response('world') - ... except UnsupportedOperationError as e: - ... print(e.message) - >>> - >>> asyncio.run(run_async_test()) - Streaming: hello - Streaming is not enabled for this agent - - Demonstrating with a sync method: - >>> class SecureAgent: - ... def __init__(self): - ... self.auth_enabled = False - ... - ... @validate( - ... lambda self: self.auth_enabled, - ... 'Authentication must be enabled for this operation', - ... ) - ... def secure_operation(self, data: str): - ... return f'Processing secure data: {data}' - >>> - >>> # Error case example - >>> agent = SecureAgent() - >>> try: - ... agent.secure_operation('secret') - ... except UnsupportedOperationError as e: - ... print(e.message) - Authentication must be enabled for this operation - - Note: - This decorator works with both sync and async methods automatically. 
- """ - - def decorator(function: Callable) -> Callable: - if inspect.iscoroutinefunction(function): - - @functools.wraps(function) - async def async_wrapper(self: Any, *args, **kwargs) -> Any: - if not expression(self): - final_message = error_message or str(expression) - logger.error('Unsupported Operation: %s', final_message) - raise UnsupportedOperationError(message=final_message) - return await function(self, *args, **kwargs) - - return async_wrapper - - @functools.wraps(function) - def sync_wrapper(self: Any, *args, **kwargs) -> Any: - if not expression(self): - final_message = error_message or str(expression) - logger.error('Unsupported Operation: %s', final_message) - raise UnsupportedOperationError(message=final_message) - return function(self, *args, **kwargs) - - return sync_wrapper - - return decorator - - -def are_modalities_compatible( - server_output_modes: list[str] | None, client_output_modes: list[str] | None -) -> bool: - """Checks if server and client output modalities (MIME types) are compatible. - - Modalities are compatible if: - 1. The client specifies no preferred output modes (client_output_modes is None or empty). - 2. The server specifies no supported output modes (server_output_modes is None or empty). - 3. There is at least one common modality between the server's supported list and the client's preferred list. - - Args: - server_output_modes: A list of MIME types supported by the server/agent for output. - Can be None or empty if the server doesn't specify. - client_output_modes: A list of MIME types preferred by the client for output. - Can be None or empty if the client accepts any. - - Returns: - True if the modalities are compatible, False otherwise. 
- """ - if client_output_modes is None or len(client_output_modes) == 0: - return True - - if server_output_modes is None or len(server_output_modes) == 0: - return True - - return any(x in server_output_modes for x in client_output_modes) - - -def _clean_empty(d: Any) -> Any: - """Recursively remove empty strings, lists and dicts from a dictionary.""" - if isinstance(d, dict): - cleaned_dict = { - k: cleaned_v - for k, v in d.items() - if (cleaned_v := _clean_empty(v)) is not None - } - return cleaned_dict or None - if isinstance(d, list): - cleaned_list = [ - cleaned_v for v in d if (cleaned_v := _clean_empty(v)) is not None - ] - return cleaned_list or None - if isinstance(d, str) and not d: - return None - return d - - -def canonicalize_agent_card(agent_card: AgentCard) -> str: - """Canonicalizes the Agent Card JSON according to RFC 8785 (JCS).""" - card_dict = MessageToDict( - agent_card, - ) - # Remove signatures field if present - card_dict.pop('signatures', None) - - # Recursively remove empty values - cleaned_dict = _clean_empty(card_dict) - return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) - - -async def maybe_await(value: T | Awaitable[T]) -> T: - """Awaits a value if it's awaitable, otherwise simply provides it back.""" - if inspect.isawaitable(value): - return await value - return value - - -def validate_version(expected_version: str) -> Callable[[F], F]: - """Decorator that validates the A2A-Version header in the request context. - - The header name is defined by `constants.VERSION_HEADER` ('A2A-Version'). - If the header is missing or empty, it is interpreted as `constants.PROTOCOL_VERSION_0_3` ('0.3'). - If the version in the header does not match the `expected_version` (major and minor parts), - a `VersionNotSupportedError` is raised. Patch version is ignored. - - This decorator supports both async methods and async generator methods. 
It - expects a `ServerCallContext` to be present either in the arguments or - keyword arguments of the decorated method. - - Args: - expected_version: The A2A protocol version string expected by the method. - - Returns: - The decorated function. - - Raises: - VersionNotSupportedError: If the version in the request does not match `expected_version`. - """ - try: - expected_v = Version(expected_version) - except InvalidVersion: - # If the expected version is not a valid semver, we can't do major/minor comparison. - # This shouldn't happen with our constants. - expected_v = None - - def decorator(func: F) -> F: - def _get_actual_version( - args: tuple[Any, ...], kwargs: dict[str, Any] - ) -> str: - context = kwargs.get('context') - if context is None: - for arg in args: - if isinstance(arg, ServerCallContext): - context = arg - break - - if context is None: - # If no context is found, we can't validate the version. - # In a real scenario, this shouldn't happen for properly routed requests. - # We default to the expected version to allow test call to proceed. - return expected_version - - headers = context.state.get('headers', {}) - # Header names are usually case-insensitive in most frameworks, but dict lookup is case-sensitive. - # We check both standard and lowercase versions. 
- actual_version = headers.get( - constants.VERSION_HEADER - ) or headers.get(constants.VERSION_HEADER.lower()) - - if not actual_version: - return constants.PROTOCOL_VERSION_0_3 - - return str(actual_version) - - def _is_version_compatible(actual: str) -> bool: - if actual == expected_version: - return True - if not expected_v: - return False - try: - actual_v = Version(actual) - except InvalidVersion: - return False - else: - return actual_v.major == expected_v.major - - if inspect.isasyncgenfunction(inspect.unwrap(func)): - - @functools.wraps(func) - def async_gen_wrapper( - *args: Any, **kwargs: Any - ) -> AsyncIterator[Any]: - actual_version = _get_actual_version(args, kwargs) - if not _is_version_compatible(actual_version): - logger.warning( - "Version mismatch: actual='%s', expected='%s'", - actual_version, - expected_version, - ) - raise VersionNotSupportedError( - message=f"A2A version '{actual_version}' is not supported by this handler. " - f"Expected version '{expected_version}'." - ) - return func(*args, **kwargs) - - return cast('F', async_gen_wrapper) - - @functools.wraps(func) - async def async_wrapper(*args: Any, **kwargs: Any) -> Any: - actual_version = _get_actual_version(args, kwargs) - if not _is_version_compatible(actual_version): - logger.warning( - "Version mismatch: actual='%s', expected='%s'", - actual_version, - expected_version, - ) - raise VersionNotSupportedError( - message=f"A2A version '{actual_version}' is not supported by this handler. " - f"Expected version '{expected_version}'." 
- ) - return await func(*args, **kwargs) - - return cast('F', async_wrapper) - - return decorator diff --git a/src/a2a/utils/message.py b/src/a2a/utils/message.py deleted file mode 100644 index 528d952f4..000000000 --- a/src/a2a/utils/message.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Utility functions for creating and handling A2A Message objects.""" - -import uuid - -from a2a.types.a2a_pb2 import ( - Message, - Part, - Role, -) -from a2a.utils.parts import get_text_parts - - -def new_agent_text_message( - text: str, - context_id: str | None = None, - task_id: str | None = None, -) -> Message: - """Creates a new agent message containing a single text Part. - - Args: - text: The text content of the message. - context_id: The context ID for the message. - task_id: The task ID for the message. - - Returns: - A new `Message` object with role 'agent'. - """ - return Message( - role=Role.ROLE_AGENT, - parts=[Part(text=text)], - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) - - -def new_agent_parts_message( - parts: list[Part], - context_id: str | None = None, - task_id: str | None = None, -) -> Message: - """Creates a new agent message containing a list of Parts. - - Args: - parts: The list of `Part` objects for the message content. - context_id: The context ID for the message. - task_id: The task ID for the message. - - Returns: - A new `Message` object with role 'agent'. - """ - return Message( - role=Role.ROLE_AGENT, - parts=parts, - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) - - -def get_message_text(message: Message, delimiter: str = '\n') -> str: - """Extracts and joins all text content from a Message's parts. - - Args: - message: The `Message` object. - delimiter: The string to use when joining text from multiple text Parts. - - Returns: - A single string containing all text content, or an empty string if no text parts are found. 
- """ - return delimiter.join(get_text_parts(message.parts)) diff --git a/src/a2a/utils/parts.py b/src/a2a/utils/parts.py deleted file mode 100644 index c9b964540..000000000 --- a/src/a2a/utils/parts.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Utility functions for creating and handling A2A Parts objects.""" - -from collections.abc import Sequence -from typing import Any - -from google.protobuf.json_format import MessageToDict - -from a2a.types.a2a_pb2 import ( - Part, -) - - -def get_text_parts(parts: Sequence[Part]) -> list[str]: - """Extracts text content from all text Parts. - - Args: - parts: A sequence of `Part` objects. - - Returns: - A list of strings containing the text content from any text Parts found. - """ - return [part.text for part in parts if part.HasField('text')] - - -def get_data_parts(parts: Sequence[Part]) -> list[Any]: - """Extracts data from all data Parts in a list of Parts. - - Args: - parts: A sequence of `Part` objects. - - Returns: - A list of values containing the data from any data Parts found. - """ - return [MessageToDict(part.data) for part in parts if part.HasField('data')] - - -def get_file_parts(parts: Sequence[Part]) -> list[Part]: - """Extracts file parts from a list of Parts. - - Args: - parts: A sequence of `Part` objects. - - Returns: - A list of `Part` objects containing file data (raw or url). - """ - return [part for part in parts if part.raw or part.url] diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index f77593297..b191f98e0 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -174,7 +174,10 @@ def parse_params(params: QueryParams, message: ProtobufMessage) -> None: field = fields[k] v_list = params.getlist(k) - if field.label == field.LABEL_REPEATED: + # TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace + # deprecated `field.label` with `field.is_repeated` once the minimum + # protobuf version requirement is bumped. 
+ if field.label == FieldDescriptor.LABEL_REPEATED: accumulated: list[Any] = [] for v in v_list: if not v: @@ -208,7 +211,10 @@ def _check_required_field_violation( ) -> ValidationDetail | None: """Check if a required field is missing or invalid.""" val = getattr(msg, field.name) - if field.is_repeated: + # TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace + # deprecated `field.label` with `field.is_repeated` once the minimum + # protobuf version requirement is bumped. + if field.label == FieldDescriptor.LABEL_REPEATED: if not val: return ValidationDetail( field=field.name, @@ -249,7 +255,10 @@ def _recurse_validation( return errors val = getattr(msg, field.name) - if not field.is_repeated: + # TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace + # deprecated `field.label` with `field.is_repeated` once the minimum + # protobuf version requirement is bumped. + if field.label != FieldDescriptor.LABEL_REPEATED: if msg.HasField(field.name): sub_errs = _validate_proto_required_fields_internal(val) _append_nested_errors(errors, field.name, sub_errs) diff --git a/src/a2a/utils/signing.py b/src/a2a/utils/signing.py index 68924c8a0..aa720d159 100644 --- a/src/a2a/utils/signing.py +++ b/src/a2a/utils/signing.py @@ -3,7 +3,7 @@ from collections.abc import Callable from typing import Any, TypedDict -from a2a.utils.helpers import canonicalize_agent_card +from google.protobuf.json_format import MessageToDict try: @@ -68,7 +68,7 @@ def create_agent_card_signer( def agent_card_signer(agent_card: AgentCard) -> AgentCard: """Signs agent card.""" - canonical_payload = canonicalize_agent_card(agent_card) + canonical_payload = _canonicalize_agent_card(agent_card) payload_dict = json.loads(canonical_payload) jws_string = jwt.encode( @@ -128,7 +128,7 @@ def signature_verifier( jku = protected_header.get('jku') verification_key = key_provider(kid, jku) - canonical_payload = canonicalize_agent_card(agent_card) + canonical_payload = 
_canonicalize_agent_card(agent_card) encoded_payload = base64url_encode( canonical_payload.encode('utf-8') ).decode('utf-8') @@ -148,3 +148,35 @@ def signature_verifier( raise InvalidSignaturesError('No valid signature found') return signature_verifier + + +def _clean_empty(d: Any) -> Any: + """Recursively remove empty strings, lists and dicts from a dictionary.""" + if isinstance(d, dict): + cleaned_dict = { + k: cleaned_v + for k, v in d.items() + if (cleaned_v := _clean_empty(v)) is not None + } + return cleaned_dict or None + if isinstance(d, list): + cleaned_list = [ + cleaned_v for v in d if (cleaned_v := _clean_empty(v)) is not None + ] + return cleaned_list or None + if isinstance(d, str) and not d: + return None + return d + + +def _canonicalize_agent_card(agent_card: AgentCard) -> str: + """Canonicalizes the Agent Card JSON according to RFC 8785 (JCS).""" + card_dict = MessageToDict( + agent_card, + ) + # Remove signatures field if present + card_dict.pop('signatures', None) + + # Recursively remove empty values + cleaned_dict = _clean_empty(card_dict) + return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index 6ff716a30..4acf54e46 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -1,89 +1,15 @@ """Utility functions for creating A2A Task objects.""" import binascii -import uuid from base64 import b64decode, b64encode from typing import Literal, Protocol, runtime_checkable -from a2a.types.a2a_pb2 import ( - Artifact, - Message, - Task, - TaskState, - TaskStatus, -) +from a2a.types.a2a_pb2 import Task from a2a.utils.constants import MAX_LIST_TASKS_PAGE_SIZE from a2a.utils.errors import InvalidParamsError -def new_task(request: Message) -> Task: - """Creates a new Task object from an initial user message. - - Generates task and context IDs if not provided in the message. - - Args: - request: The initial `Message` object from the user. 
- - Returns: - A new `Task` object initialized with 'submitted' status and the input message in history. - - Raises: - TypeError: If the message role is None. - ValueError: If the message parts are empty, if any part has empty content, or if the provided context_id is invalid. - """ - if not request.role: - raise TypeError('Message role cannot be None') - if not request.parts: - raise ValueError('Message parts cannot be empty') - for part in request.parts: - if part.HasField('text') and not part.text: - raise ValueError('Message.text cannot be empty') - - return Task( - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - id=request.task_id or str(uuid.uuid4()), - context_id=request.context_id or str(uuid.uuid4()), - history=[request], - ) - - -def completed_task( - task_id: str, - context_id: str, - artifacts: list[Artifact], - history: list[Message] | None = None, -) -> Task: - """Creates a Task object in the 'completed' state. - - Useful for constructing a final Task representation when the agent - finishes and produces artifacts. - - Args: - task_id: The ID of the task. - context_id: The context ID of the task. - artifacts: A list of `Artifact` objects produced by the task. - history: An optional list of `Message` objects representing the task history. - - Returns: - A `Task` object with status set to 'completed'. 
- """ - if not artifacts or not all(isinstance(a, Artifact) for a in artifacts): - raise ValueError( - 'artifacts must be a non-empty list of Artifact objects' - ) - - if history is None: - history = [] - return Task( - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - id=task_id, - context_id=context_id, - artifacts=artifacts, - history=history, - ) - - @runtime_checkable class HistoryLengthConfig(Protocol): """Protocol for configuration arguments containing history_length field.""" diff --git a/src/a2a/utils/version_validator.py b/src/a2a/utils/version_validator.py new file mode 100644 index 000000000..4a776c27e --- /dev/null +++ b/src/a2a/utils/version_validator.py @@ -0,0 +1,130 @@ +"""General utility functions for the A2A Python SDK.""" + +import functools +import inspect +import logging + +from collections.abc import AsyncIterator, Callable +from typing import Any, TypeVar, cast + +from packaging.version import InvalidVersion, Version + +from a2a.server.context import ServerCallContext +from a2a.utils import constants +from a2a.utils.errors import VersionNotSupportedError + + +F = TypeVar('F', bound=Callable[..., Any]) + + +logger = logging.getLogger(__name__) + + +def validate_version(expected_version: str) -> Callable[[F], F]: + """Decorator that validates the A2A-Version header in the request context. + + The header name is defined by `constants.VERSION_HEADER` ('A2A-Version'). + If the header is missing or empty, it is interpreted as `constants.PROTOCOL_VERSION_0_3` ('0.3'). + If the version in the header does not match the `expected_version` (major and minor parts), + a `VersionNotSupportedError` is raised. Patch version is ignored. + + This decorator supports both async methods and async generator methods. It + expects a `ServerCallContext` to be present either in the arguments or + keyword arguments of the decorated method. + + Args: + expected_version: The A2A protocol version string expected by the method. 
+ + Returns: + The decorated function. + + Raises: + VersionNotSupportedError: If the version in the request does not match `expected_version`. + """ + try: + expected_v = Version(expected_version) + except InvalidVersion: + # If the expected version is not a valid semver, we can't do major/minor comparison. + # This shouldn't happen with our constants. + expected_v = None + + def decorator(func: F) -> F: + def _get_actual_version( + args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> str: + context = kwargs.get('context') + if context is None: + for arg in args: + if isinstance(arg, ServerCallContext): + context = arg + break + + if context is None: + # If no context is found, we can't validate the version. + # In a real scenario, this shouldn't happen for properly routed requests. + # We default to the expected version to allow test call to proceed. + return expected_version + + headers = context.state.get('headers', {}) + # Header names are usually case-insensitive in most frameworks, but dict lookup is case-sensitive. + # We check both standard and lowercase versions. 
+ actual_version = headers.get( + constants.VERSION_HEADER + ) or headers.get(constants.VERSION_HEADER.lower()) + + if not actual_version: + return constants.PROTOCOL_VERSION_0_3 + + return str(actual_version) + + def _is_version_compatible(actual: str) -> bool: + if actual == expected_version: + return True + if not expected_v: + return False + try: + actual_v = Version(actual) + except InvalidVersion: + return False + else: + return actual_v.major == expected_v.major + + if inspect.isasyncgenfunction(inspect.unwrap(func)): + + @functools.wraps(func) + def async_gen_wrapper( + *args: Any, **kwargs: Any + ) -> AsyncIterator[Any]: + actual_version = _get_actual_version(args, kwargs) + if not _is_version_compatible(actual_version): + logger.warning( + "Version mismatch: actual='%s', expected='%s'", + actual_version, + expected_version, + ) + raise VersionNotSupportedError( + message=f"A2A version '{actual_version}' is not supported by this handler. " + f"Expected version '{expected_version}'." + ) + return func(*args, **kwargs) + + return cast('F', async_gen_wrapper) + + @functools.wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + actual_version = _get_actual_version(args, kwargs) + if not _is_version_compatible(actual_version): + logger.warning( + "Version mismatch: actual='%s', expected='%s'", + actual_version, + expected_version, + ) + raise VersionNotSupportedError( + message=f"A2A version '{actual_version}' is not supported by this handler. " + f"Expected version '{expected_version}'." 
+ ) + return await func(*args, **kwargs) + + return cast('F', async_wrapper) + + return decorator diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 259b16a5d..0ca3a1450 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -17,9 +17,7 @@ from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue -from a2a.server.request_handlers.default_request_handler import ( - DefaultRequestHandler, -) +from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.request_handlers.grpc_handler import GrpcHandler from a2a.server.routes import ( create_agent_card_routes, @@ -193,13 +191,13 @@ def serve(task_store: TaskStore) -> None: ) request_handler = DefaultRequestHandler( + agent_card=agent_card, agent_executor=SUTAgentExecutor(), task_store=task_store, ) # JSONRPC jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=request_handler, rpc_url=JSONRPC_URL, ) @@ -209,7 +207,6 @@ def serve(task_store: TaskStore) -> None: ) # REST rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=request_handler, path_prefix=REST_URL, ) @@ -229,8 +226,8 @@ def serve(task_store: TaskStore) -> None: # GRPC grpc_server = grpc.aio.server() grpc_server.add_insecure_port(f'[::]:{grpc_port}') - servicer = GrpcHandler(agent_card, request_handler) - compat_servicer = CompatGrpcHandler(agent_card, request_handler) + servicer = GrpcHandler(request_handler) + compat_servicer = CompatGrpcHandler(request_handler) a2a_grpc.add_A2AServiceServicer_to_server(servicer, grpc_server) a2a_v0_3_grpc.add_A2AServiceServicer_to_server(compat_servicer, grpc_server) diff --git a/tests/client/test_auth_interceptor.py b/tests/client/test_auth_interceptor.py index 11d932090..560751fa8 100644 --- a/tests/client/test_auth_interceptor.py +++ b/tests/client/test_auth_interceptor.py @@ -240,7 +240,6 @@ class AuthTestCase: ) 
-@pytest.mark.skip(reason='Interceptors disabled by user request') @pytest.mark.asyncio @pytest.mark.parametrize( 'test_case', diff --git a/tests/client/test_card_resolver.py b/tests/client/test_card_resolver.py index 9a684a4ac..ff60632ad 100644 --- a/tests/client/test_card_resolver.py +++ b/tests/client/test_card_resolver.py @@ -1,13 +1,35 @@ +import copy +import difflib import json import logging - from unittest.mock import AsyncMock, MagicMock, Mock +from google.protobuf.json_format import MessageToDict import httpx import pytest from a2a.client import A2ACardResolver, AgentCardResolutionError +from a2a.client.card_resolver import parse_agent_card +from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.types import AgentCard +from a2a.types.a2a_pb2 import ( + APIKeySecurityScheme, + AgentCapabilities, + AgentCardSignature, + AgentInterface, + AgentProvider, + AgentSkill, + AuthorizationCodeOAuthFlow, + HTTPAuthSecurityScheme, + MutualTlsSecurityScheme, + OAuth2SecurityScheme, + OAuthFlows, + OpenIdConnectSecurityScheme, + Role, + SecurityRequirement, + SecurityScheme, + StringList, +) from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH @@ -388,3 +410,680 @@ async def test_get_agent_card_with_signature_verifier( ) mock_verifier.assert_called_once_with(agent_card) + + +class TestParseAgentCard: + """Tests for parse_agent_card function.""" + + @staticmethod + def _assert_agent_card_diff( + original_data: dict, serialized_data: dict + ) -> None: + """Helper to assert that the re-serialized 1.0.0 JSON payload contains all original 0.3.0 data (no dropped fields).""" + original_json_str = json.dumps(original_data, indent=2, sort_keys=True) + serialized_json_str = json.dumps( + serialized_data, indent=2, sort_keys=True + ) + + diff_lines = list( + difflib.unified_diff( + original_json_str.splitlines(), + serialized_json_str.splitlines(), + lineterm='', + ) + ) + + removed_lines = [] + for line in diff_lines: + if line.startswith('-') and not 
line.startswith('---'): + removed_lines.append(line) + + if removed_lines: + error_msg = ( + 'Re-serialization dropped fields from the original payload:\n' + + '\n'.join(removed_lines) + ) + raise AssertionError(error_msg) + + def test_parse_agent_card_legacy_support(self) -> None: + data = { + 'name': 'Legacy Agent', + 'description': 'Legacy Description', + 'version': '1.0', + 'supportsAuthenticatedExtendedCard': True, + } + card = parse_agent_card(data) + assert card.name == 'Legacy Agent' + assert card.capabilities.extended_agent_card is True + # Ensure it's popped from the dict + assert 'supportsAuthenticatedExtendedCard' not in data + + def test_parse_agent_card_new_support(self) -> None: + data = { + 'name': 'New Agent', + 'description': 'New Description', + 'version': '1.0', + 'capabilities': {'extendedAgentCard': True}, + } + card = parse_agent_card(data) + assert card.name == 'New Agent' + assert card.capabilities.extended_agent_card is True + + def test_parse_agent_card_no_support(self) -> None: + data = { + 'name': 'No Support Agent', + 'description': 'No Support Description', + 'version': '1.0', + 'capabilities': {'extendedAgentCard': False}, + } + card = parse_agent_card(data) + assert card.name == 'No Support Agent' + assert card.capabilities.extended_agent_card is False + + def test_parse_agent_card_both_legacy_and_new(self) -> None: + data = { + 'name': 'Mixed Agent', + 'description': 'Mixed Description', + 'version': '1.0', + 'supportsAuthenticatedExtendedCard': True, + 'capabilities': {'streaming': True}, + } + card = parse_agent_card(data) + assert card.name == 'Mixed Agent' + assert card.capabilities.streaming is True + assert card.capabilities.extended_agent_card is True + + def test_parse_typical_030_agent_card(self) -> None: + data = { + 'additionalInterfaces': [ + { + 'transport': 'GRPC', + 'url': 'http://agent.example.com/api/grpc', + } + ], + 'capabilities': {'streaming': True}, + 'defaultInputModes': ['text/plain'], + 
'defaultOutputModes': ['application/json'], + 'description': 'A typical agent from 0.3.0', + 'name': 'Typical Agent 0.3', + 'preferredTransport': 'JSONRPC', + 'protocolVersion': '0.3.0', + 'security': [{'test_oauth': ['read', 'write']}], + 'securitySchemes': { + 'test_oauth': { + 'description': 'OAuth2 authentication', + 'flows': { + 'authorizationCode': { + 'authorizationUrl': 'http://auth.example.com', + 'scopes': { + 'read': 'Read access', + 'write': 'Write access', + }, + 'tokenUrl': 'http://token.example.com', + } + }, + 'type': 'oauth2', + } + }, + 'skills': [ + { + 'description': 'The first skill', + 'id': 'skill-1', + 'name': 'Skill 1', + 'security': [{'test_oauth': ['read']}], + 'tags': ['example'], + } + ], + 'supportsAuthenticatedExtendedCard': True, + 'url': 'http://agent.example.com/api', + 'version': '1.0', + } + original_data = copy.deepcopy(data) + card = parse_agent_card(data) + + expected_card = AgentCard( + name='Typical Agent 0.3', + description='A typical agent from 0.3.0', + version='1.0', + capabilities=AgentCapabilities( + extended_agent_card=True, streaming=True + ), + default_input_modes=['text/plain'], + default_output_modes=['application/json'], + supported_interfaces=[ + AgentInterface( + url='http://agent.example.com/api', + protocol_binding='JSONRPC', + protocol_version='0.3.0', + ), + AgentInterface( + url='http://agent.example.com/api/grpc', + protocol_binding='GRPC', + protocol_version='0.3.0', + ), + ], + security_requirements=[ + SecurityRequirement( + schemes={'test_oauth': StringList(list=['read', 'write'])} + ) + ], + security_schemes={ + 'test_oauth': SecurityScheme( + oauth2_security_scheme=OAuth2SecurityScheme( + description='OAuth2 authentication', + flows=OAuthFlows( + authorization_code=AuthorizationCodeOAuthFlow( + authorization_url='http://auth.example.com', + token_url='http://token.example.com', + scopes={ + 'read': 'Read access', + 'write': 'Write access', + }, + ) + ), + ) + ) + }, + skills=[ + AgentSkill( + 
id='skill-1', + name='Skill 1', + description='The first skill', + tags=['example'], + security_requirements=[ + SecurityRequirement( + schemes={'test_oauth': StringList(list=['read'])} + ) + ], + ) + ], + ) + + assert card == expected_card + + # Serialize back to JSON and compare + serialized_data = agent_card_to_dict(card) + + self._assert_agent_card_diff(original_data, serialized_data) + assert 'preferredTransport' in serialized_data + + # Re-parse from the serialized payload and verify identical to original parsing + re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) + assert re_parsed_card == card + + def test_parse_agent_card_security_scheme_without_in(self) -> None: + data = { + 'name': 'API Key Agent', + 'description': 'API Key without in param', + 'version': '1.0', + 'securitySchemes': { + 'test_api_key': {'type': 'apiKey', 'name': 'X-API-KEY'} + }, + } + card = parse_agent_card(data) + assert 'test_api_key' in card.security_schemes + assert ( + card.security_schemes['test_api_key'].api_key_security_scheme.name + == 'X-API-KEY' + ) + assert ( + card.security_schemes[ + 'test_api_key' + ].api_key_security_scheme.location + == '' + ) + + def test_parse_agent_card_security_scheme_unknown_type(self) -> None: + data = { + 'name': 'Unknown Scheme Agent', + 'description': 'Has unknown scheme type', + 'version': '1.0', + 'securitySchemes': { + 'test_unknown': { + 'type': 'someFutureType', + 'future_prop': 'value', + }, + 'test_missing_type': {'prop': 'value'}, + }, + } + card = parse_agent_card(data) + assert 'test_unknown' in card.security_schemes + assert not card.security_schemes['test_unknown'].WhichOneof('scheme') + + assert 'test_missing_type' in card.security_schemes + assert not card.security_schemes['test_missing_type'].WhichOneof( + 'scheme' + ) + + def test_parse_030_agent_card_route_planner(self) -> None: + data = { + 'protocolVersion': '0.3', + 'name': 'GeoSpatial Route Planner Agent', + 'description': 'Provides advanced route 
planning.', + 'url': 'https://georoute-agent.example.com/a2a/v1', + 'preferredTransport': 'JSONRPC', + 'additionalInterfaces': [ + { + 'url': 'https://georoute-agent.example.com/a2a/v1', + 'transport': 'JSONRPC', + }, + { + 'url': 'https://georoute-agent.example.com/a2a/grpc', + 'transport': 'GRPC', + }, + { + 'url': 'https://georoute-agent.example.com/a2a/json', + 'transport': 'HTTP+JSON', + }, + ], + 'provider': { + 'organization': 'Example Geo Services Inc.', + 'url': 'https://www.examplegeoservices.com', + }, + 'iconUrl': 'https://georoute-agent.example.com/icon.png', + 'version': '1.2.0', + 'documentationUrl': 'https://docs.examplegeoservices.com/georoute-agent/api', + 'supportsAuthenticatedExtendedCard': True, + 'capabilities': { + 'streaming': True, + 'pushNotifications': True, + 'stateTransitionHistory': False, + }, + 'securitySchemes': { + 'google': { + 'type': 'openIdConnect', + 'openIdConnectUrl': 'https://accounts.google.com/.well-known/openid-configuration', + } + }, + 'security': [{'google': ['openid', 'profile', 'email']}], + 'defaultInputModes': ['application/json', 'text/plain'], + 'defaultOutputModes': ['application/json', 'image/png'], + 'skills': [ + { + 'id': 'route-optimizer-traffic', + 'name': 'Traffic-Aware Route Optimizer', + 'description': 'Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).', + 'tags': [ + 'maps', + 'routing', + 'navigation', + 'directions', + 'traffic', + ], + 'examples': [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + '{"origin": {"lat": 37.422, "lng": -122.084}, "destination": {"lat": 37.7749, "lng": -122.4194}, "preferences": ["avoid_ferries"]}', + ], + 'inputModes': ['application/json', 'text/plain'], + 'outputModes': [ + 'application/json', + 'application/vnd.geo+json', + 'text/html', + ], + 
'security': [ + {'example': []}, + {'google': ['openid', 'profile', 'email']}, + ], + }, + { + 'id': 'custom-map-generator', + 'name': 'Personalized Map Generator', + 'description': 'Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.', + 'tags': [ + 'maps', + 'customization', + 'visualization', + 'cartography', + ], + 'examples': [ + 'Generate a map of my upcoming road trip with all planned stops highlighted.', + 'Show me a map visualizing all coffee shops within a 1-mile radius of my current location.', + ], + 'inputModes': ['application/json'], + 'outputModes': [ + 'image/png', + 'image/jpeg', + 'application/json', + 'text/html', + ], + }, + ], + 'signatures': [ + { + 'protected': 'eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0', + 'signature': 'QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ', + } + ], + } + + original_data = copy.deepcopy(data) + card = parse_agent_card(data) + + expected_card = AgentCard( + name='GeoSpatial Route Planner Agent', + description='Provides advanced route planning.', + version='1.2.0', + documentation_url='https://docs.examplegeoservices.com/georoute-agent/api', + icon_url='https://georoute-agent.example.com/icon.png', + provider=AgentProvider( + organization='Example Geo Services Inc.', + url='https://www.examplegeoservices.com', + ), + capabilities=AgentCapabilities( + extended_agent_card=True, + streaming=True, + push_notifications=True, + ), + default_input_modes=['application/json', 'text/plain'], + default_output_modes=['application/json', 'image/png'], + supported_interfaces=[ + AgentInterface( + url='https://georoute-agent.example.com/a2a/v1', + protocol_binding='JSONRPC', + protocol_version='0.3', + ), + AgentInterface( + url='https://georoute-agent.example.com/a2a/v1', + protocol_binding='JSONRPC', + 
protocol_version='0.3', + ), + AgentInterface( + url='https://georoute-agent.example.com/a2a/grpc', + protocol_binding='GRPC', + protocol_version='0.3', + ), + AgentInterface( + url='https://georoute-agent.example.com/a2a/json', + protocol_binding='HTTP+JSON', + protocol_version='0.3', + ), + ], + security_requirements=[ + SecurityRequirement( + schemes={ + 'google': StringList( + list=['openid', 'profile', 'email'] + ) + } + ) + ], + security_schemes={ + 'google': SecurityScheme( + open_id_connect_security_scheme=OpenIdConnectSecurityScheme( + open_id_connect_url='https://accounts.google.com/.well-known/openid-configuration' + ) + ) + }, + skills=[ + AgentSkill( + id='route-optimizer-traffic', + name='Traffic-Aware Route Optimizer', + description='Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).', + tags=[ + 'maps', + 'routing', + 'navigation', + 'directions', + 'traffic', + ], + examples=[ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + '{"origin": {"lat": 37.422, "lng": -122.084}, "destination": {"lat": 37.7749, "lng": -122.4194}, "preferences": ["avoid_ferries"]}', + ], + input_modes=['application/json', 'text/plain'], + output_modes=[ + 'application/json', + 'application/vnd.geo+json', + 'text/html', + ], + security_requirements=[ + SecurityRequirement(schemes={'example': StringList()}), + SecurityRequirement( + schemes={ + 'google': StringList( + list=['openid', 'profile', 'email'] + ) + } + ), + ], + ), + AgentSkill( + id='custom-map-generator', + name='Personalized Map Generator', + description='Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. 
Can overlay data layers.', + tags=[ + 'maps', + 'customization', + 'visualization', + 'cartography', + ], + examples=[ + 'Generate a map of my upcoming road trip with all planned stops highlighted.', + 'Show me a map visualizing all coffee shops within a 1-mile radius of my current location.', + ], + input_modes=['application/json'], + output_modes=[ + 'image/png', + 'image/jpeg', + 'application/json', + 'text/html', + ], + ), + ], + signatures=[ + AgentCardSignature( + protected='eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0', + signature='QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ', + ) + ], + ) + + assert card == expected_card + serialized_data = agent_card_to_dict(card) + del original_data['capabilities']['stateTransitionHistory'] + self._assert_agent_card_diff(original_data, serialized_data) + re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) + assert re_parsed_card == card + + def test_parse_complex_030_agent_card(self) -> None: + data = { + 'additionalInterfaces': [ + { + 'transport': 'GRPC', + 'url': 'http://complex.agent.example.com/grpc', + }, + { + 'transport': 'JSONRPC', + 'url': 'http://complex.agent.example.com/jsonrpc', + }, + ], + 'capabilities': {'pushNotifications': True, 'streaming': True}, + 'defaultInputModes': ['text/plain', 'application/json'], + 'defaultOutputModes': ['application/json', 'image/png'], + 'description': 'A very complex agent from 0.3.0', + 'name': 'Complex Agent 0.3', + 'preferredTransport': 'HTTP+JSON', + 'protocolVersion': '0.3.0', + 'security': [ + {'test_oauth': ['read', 'write'], 'test_api_key': []}, + {'test_http': []}, + {'test_oidc': ['openid', 'profile']}, + {'test_mtls': []}, + ], + 'securitySchemes': { + 'test_oauth': { + 'description': 'OAuth2 authentication', + 'flows': { + 'authorizationCode': { + 'authorizationUrl': 'http://auth.example.com', + 'scopes': { + 'read': 'Read access', + 
'write': 'Write access', + }, + 'tokenUrl': 'http://token.example.com', + } + }, + 'type': 'oauth2', + }, + 'test_api_key': { + 'description': 'API Key auth', + 'in': 'header', + 'name': 'X-API-KEY', + 'type': 'apiKey', + }, + 'test_http': { + 'bearerFormat': 'JWT', + 'description': 'HTTP Basic auth', + 'scheme': 'basic', + 'type': 'http', + }, + 'test_oidc': { + 'description': 'OIDC Auth', + 'openIdConnectUrl': 'https://example.com/.well-known/openid-configuration', + 'type': 'openIdConnect', + }, + 'test_mtls': {'description': 'mTLS Auth', 'type': 'mutualTLS'}, + }, + 'skills': [ + { + 'description': 'The first complex skill', + 'id': 'skill-1', + 'inputModes': ['application/json'], + 'name': 'Complex Skill 1', + 'outputModes': ['application/json'], + 'security': [{'test_api_key': []}], + 'tags': ['example', 'complex'], + }, + { + 'description': 'The second complex skill', + 'id': 'skill-2', + 'name': 'Complex Skill 2', + 'security': [{'test_oidc': ['openid']}], + 'tags': ['example2'], + }, + ], + 'supportsAuthenticatedExtendedCard': True, + 'url': 'http://complex.agent.example.com/api', + 'version': '1.5.2', + } + original_data = copy.deepcopy(data) + card = parse_agent_card(data) + + expected_card = AgentCard( + name='Complex Agent 0.3', + description='A very complex agent from 0.3.0', + version='1.5.2', + capabilities=AgentCapabilities( + extended_agent_card=True, + streaming=True, + push_notifications=True, + ), + default_input_modes=['text/plain', 'application/json'], + default_output_modes=['application/json', 'image/png'], + supported_interfaces=[ + AgentInterface( + url='http://complex.agent.example.com/api', + protocol_binding='HTTP+JSON', + protocol_version='0.3.0', + ), + AgentInterface( + url='http://complex.agent.example.com/grpc', + protocol_binding='GRPC', + protocol_version='0.3.0', + ), + AgentInterface( + url='http://complex.agent.example.com/jsonrpc', + protocol_binding='JSONRPC', + protocol_version='0.3.0', + ), + ], + security_requirements=[ 
+ SecurityRequirement( + schemes={ + 'test_oauth': StringList(list=['read', 'write']), + 'test_api_key': StringList(), + } + ), + SecurityRequirement(schemes={'test_http': StringList()}), + SecurityRequirement( + schemes={ + 'test_oidc': StringList(list=['openid', 'profile']) + } + ), + SecurityRequirement(schemes={'test_mtls': StringList()}), + ], + security_schemes={ + 'test_oauth': SecurityScheme( + oauth2_security_scheme=OAuth2SecurityScheme( + description='OAuth2 authentication', + flows=OAuthFlows( + authorization_code=AuthorizationCodeOAuthFlow( + authorization_url='http://auth.example.com', + token_url='http://token.example.com', + scopes={ + 'read': 'Read access', + 'write': 'Write access', + }, + ) + ), + ) + ), + 'test_api_key': SecurityScheme( + api_key_security_scheme=APIKeySecurityScheme( + description='API Key auth', + location='header', + name='X-API-KEY', + ) + ), + 'test_http': SecurityScheme( + http_auth_security_scheme=HTTPAuthSecurityScheme( + description='HTTP Basic auth', + scheme='basic', + bearer_format='JWT', + ) + ), + 'test_oidc': SecurityScheme( + open_id_connect_security_scheme=OpenIdConnectSecurityScheme( + description='OIDC Auth', + open_id_connect_url='https://example.com/.well-known/openid-configuration', + ) + ), + 'test_mtls': SecurityScheme( + mtls_security_scheme=MutualTlsSecurityScheme( + description='mTLS Auth' + ) + ), + }, + skills=[ + AgentSkill( + id='skill-1', + name='Complex Skill 1', + description='The first complex skill', + tags=['example', 'complex'], + input_modes=['application/json'], + output_modes=['application/json'], + security_requirements=[ + SecurityRequirement( + schemes={'test_api_key': StringList()} + ) + ], + ), + AgentSkill( + id='skill-2', + name='Complex Skill 2', + description='The second complex skill', + tags=['example2'], + security_requirements=[ + SecurityRequirement( + schemes={'test_oidc': StringList(list=['openid'])} + ) + ], + ), + ], + ) + + assert card == expected_card + serialized_data = 
agent_card_to_dict(card) + self._assert_agent_card_diff(original_data, serialized_data) + re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) + assert re_parsed_card == card diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index a5366e0d3..b30d57d12 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -1,18 +1,16 @@ """Tests for the ClientFactory.""" -from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import typing import httpx import pytest -from a2a.client import ClientConfig, ClientFactory +from a2a.client import ClientConfig, ClientFactory, create_client from a2a.client.client_factory import TransportProducer from a2a.client.transports import ( JsonRpcTransport, RestTransport, - ClientTransport, ) from a2a.client.transports.tenant_decorator import TenantTransportDecorator from a2a.types.a2a_pb2 import ( @@ -127,26 +125,27 @@ def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): factory.create(base_agent_card) -@pytest.mark.asyncio -async def test_client_factory_connect_with_agent_card( +def test_client_factory_create_with_default_config( base_agent_card: AgentCard, ): - """Verify that connect works correctly when provided with an AgentCard.""" - client = await ClientFactory.connect(base_agent_card) + """Verify that create works correctly with a default ClientConfig.""" + factory = ClientFactory() + client = factory.create(base_agent_card) assert isinstance(client._transport, JsonRpcTransport) # type: ignore[attr-defined] assert client._transport.url == 'http://primary-url.com' # type: ignore[attr-defined] @pytest.mark.asyncio -async def test_client_factory_connect_with_url(base_agent_card: AgentCard): - """Verify that connect works correctly when provided with a URL.""" +async def test_client_factory_create_from_url(base_agent_card: AgentCard): + """Verify that create_from_url resolves the card and 
creates a client.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card ) agent_url = 'http://example.com' - client = await ClientFactory.connect(agent_url) + factory = ClientFactory() + client = await factory.create_from_url(agent_url) mock_resolver.assert_called_once() assert mock_resolver.call_args[0][1] == agent_url @@ -157,10 +156,10 @@ async def test_client_factory_connect_with_url(base_agent_card: AgentCard): @pytest.mark.asyncio -async def test_client_factory_connect_with_url_and_client_config( +async def test_client_factory_create_from_url_uses_factory_httpx_client( base_agent_card: AgentCard, ): - """Verify connect with a URL and a pre-configured httpx client.""" + """Verify create_from_url uses the factory's configured httpx client.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card @@ -170,7 +169,8 @@ async def test_client_factory_connect_with_url_and_client_config( mock_httpx_client = httpx.AsyncClient() config = ClientConfig(httpx_client=mock_httpx_client) - client = await ClientFactory.connect(agent_url, client_config=config) + factory = ClientFactory(config) + client = await factory.create_from_url(agent_url) mock_resolver.assert_called_once_with(mock_httpx_client, agent_url) mock_resolver.return_value.get_agent_card.assert_awaited_once() @@ -180,10 +180,10 @@ async def test_client_factory_connect_with_url_and_client_config( @pytest.mark.asyncio -async def test_client_factory_connect_with_resolver_args( +async def test_client_factory_create_from_url_passes_resolver_args( base_agent_card: AgentCard, ): - """Verify connect passes resolver arguments correctly.""" + """Verify create_from_url passes resolver arguments correctly.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: 
mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card @@ -193,12 +193,11 @@ async def test_client_factory_connect_with_resolver_args( relative_path = '/extendedAgentCard' http_kwargs = {'headers': {'X-Test': 'true'}} - # The resolver args are only passed if an httpx_client is provided in config config = ClientConfig(httpx_client=httpx.AsyncClient()) + factory = ClientFactory(config) - await ClientFactory.connect( + await factory.create_from_url( agent_url, - client_config=config, relative_card_path=relative_path, resolver_http_kwargs=http_kwargs, ) @@ -211,10 +210,10 @@ async def test_client_factory_connect_with_resolver_args( @pytest.mark.asyncio -async def test_client_factory_connect_resolver_args_without_client( +async def test_client_factory_create_from_url_with_default_config( base_agent_card: AgentCard, ): - """Verify resolver args are ignored if no httpx_client is provided.""" + """Verify create_from_url works with a default ClientConfig.""" with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: mock_resolver.return_value.get_agent_card = AsyncMock( return_value=base_agent_card @@ -224,12 +223,16 @@ async def test_client_factory_connect_resolver_args_without_client( relative_path = '/extendedAgentCard' http_kwargs = {'headers': {'X-Test': 'true'}} - await ClientFactory.connect( + factory = ClientFactory() + + await factory.create_from_url( agent_url, relative_card_path=relative_path, resolver_http_kwargs=http_kwargs, ) + # Factory always creates an httpx client, so resolver gets it + mock_resolver.assert_called_once() mock_resolver.return_value.get_agent_card.assert_awaited_once_with( relative_card_path=relative_path, http_kwargs=http_kwargs, @@ -237,16 +240,17 @@ async def test_client_factory_connect_resolver_args_without_client( ) -@pytest.mark.asyncio -async def test_client_factory_connect_with_extra_transports( +def test_client_factory_register_and_create_custom_transport( base_agent_card: AgentCard, ): - 
"""Verify that connect can register and use extra transports.""" + """Verify that register() + create() uses custom transports.""" class CustomTransport: pass - def custom_transport_producer(*args, **kwargs): + def custom_transport_producer( + *args: typing.Any, **kwargs: typing.Any + ) -> CustomTransport: return CustomTransport() base_agent_card.supported_interfaces.insert( @@ -255,27 +259,60 @@ def custom_transport_producer(*args, **kwargs): ) config = ClientConfig(supported_protocol_bindings=['custom']) - - client = await ClientFactory.connect( - base_agent_card, - client_config=config, - extra_transports=typing.cast( - dict[str, TransportProducer], {'custom': custom_transport_producer} - ), + factory = ClientFactory(config) + factory.register( + 'custom', + typing.cast(TransportProducer, custom_transport_producer), ) + client = factory.create(base_agent_card) assert isinstance(client._transport, CustomTransport) # type: ignore[attr-defined] @pytest.mark.asyncio -async def test_client_factory_connect_with_interceptors( +async def test_client_factory_create_from_url_uses_registered_transports( + base_agent_card: AgentCard, +): + """Verify that create_from_url() respects custom transports from register().""" + + class CustomTransport: + pass + + def custom_transport_producer( + *args: typing.Any, **kwargs: typing.Any + ) -> CustomTransport: + return CustomTransport() + + base_agent_card.supported_interfaces.insert( + 0, + AgentInterface(protocol_binding='custom', url='custom://foo'), + ) + + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: + mock_resolver.return_value.get_agent_card = AsyncMock( + return_value=base_agent_card + ) + + config = ClientConfig(supported_protocol_bindings=['custom']) + factory = ClientFactory(config) + factory.register( + 'custom', + typing.cast(TransportProducer, custom_transport_producer), + ) + + client = await factory.create_from_url('http://example.com') + assert isinstance(client._transport, 
CustomTransport) # type: ignore[attr-defined] + + +def test_client_factory_create_with_interceptors( base_agent_card: AgentCard, ): """Verify interceptors are passed through correctly.""" interceptor1 = MagicMock() with patch('a2a.client.client_factory.BaseClient') as mock_base_client: - await ClientFactory.connect( + factory = ClientFactory() + factory.create( base_agent_card, interceptors=[interceptor1], ) @@ -298,3 +335,44 @@ def test_client_factory_applies_tenant_decorator(base_agent_card: AgentCard): assert isinstance(client._transport, TenantTransportDecorator) # type: ignore[attr-defined] assert client._transport._tenant == 'my-tenant' # type: ignore[attr-defined] assert isinstance(client._transport._base, JsonRpcTransport) # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_create_client_with_agent_card(base_agent_card: AgentCard): + """Verify create_client works when given an AgentCard directly.""" + client = await create_client(base_agent_card) + assert isinstance(client._transport, JsonRpcTransport) # type: ignore[attr-defined] + assert client._transport.url == 'http://primary-url.com' # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_create_client_with_url(base_agent_card: AgentCard): + """Verify create_client resolves a URL and creates a client.""" + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: + mock_resolver.return_value.get_agent_card = AsyncMock( + return_value=base_agent_card + ) + + client = await create_client('http://example.com') + + mock_resolver.assert_called_once() + assert mock_resolver.call_args[0][1] == 'http://example.com' + assert isinstance(client._transport, JsonRpcTransport) # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_create_client_with_url_and_config(base_agent_card: AgentCard): + """Verify create_client passes client_config to the factory.""" + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: + 
mock_resolver.return_value.get_agent_card = AsyncMock( + return_value=base_agent_card + ) + + mock_httpx_client = httpx.AsyncClient() + config = ClientConfig(httpx_client=mock_httpx_client) + + await create_client('http://example.com', client_config=config) + + mock_resolver.assert_called_once_with( + mock_httpx_client, 'http://example.com' + ) diff --git a/tests/client/test_client_helpers.py b/tests/client/test_client_helpers.py deleted file mode 100644 index 8963eefce..000000000 --- a/tests/client/test_client_helpers.py +++ /dev/null @@ -1,695 +0,0 @@ -import copy -import difflib -import json -from google.protobuf.json_format import MessageToDict - -from a2a.client.helpers import create_text_message_object, parse_agent_card -from a2a.server.request_handlers.response_helpers import agent_card_to_dict -from a2a.types.a2a_pb2 import ( - APIKeySecurityScheme, - AgentCapabilities, - AgentCard, - AgentCardSignature, - AgentInterface, - AgentProvider, - AgentSkill, - AuthorizationCodeOAuthFlow, - HTTPAuthSecurityScheme, - MutualTlsSecurityScheme, - OAuth2SecurityScheme, - OAuthFlows, - OpenIdConnectSecurityScheme, - Role, - SecurityRequirement, - SecurityScheme, - StringList, -) - - -def test_parse_agent_card_legacy_support() -> None: - data = { - 'name': 'Legacy Agent', - 'description': 'Legacy Description', - 'version': '1.0', - 'supportsAuthenticatedExtendedCard': True, - } - card = parse_agent_card(data) - assert card.name == 'Legacy Agent' - assert card.capabilities.extended_agent_card is True - # Ensure it's popped from the dict - assert 'supportsAuthenticatedExtendedCard' not in data - - -def test_parse_agent_card_new_support() -> None: - data = { - 'name': 'New Agent', - 'description': 'New Description', - 'version': '1.0', - 'capabilities': {'extendedAgentCard': True}, - } - card = parse_agent_card(data) - assert card.name == 'New Agent' - assert card.capabilities.extended_agent_card is True - - -def test_parse_agent_card_no_support() -> None: - data = { - 
'name': 'No Support Agent', - 'description': 'No Support Description', - 'version': '1.0', - 'capabilities': {'extendedAgentCard': False}, - } - card = parse_agent_card(data) - assert card.name == 'No Support Agent' - assert card.capabilities.extended_agent_card is False - - -def test_parse_agent_card_both_legacy_and_new() -> None: - data = { - 'name': 'Mixed Agent', - 'description': 'Mixed Description', - 'version': '1.0', - 'supportsAuthenticatedExtendedCard': True, - 'capabilities': {'streaming': True}, - } - card = parse_agent_card(data) - assert card.name == 'Mixed Agent' - assert card.capabilities.streaming is True - assert card.capabilities.extended_agent_card is True - - -def _assert_agent_card_diff(original_data: dict, serialized_data: dict) -> None: - """Helper to assert that the re-serialized 1.0.0 JSON payload contains all original 0.3.0 data (no dropped fields).""" - original_json_str = json.dumps(original_data, indent=2, sort_keys=True) - serialized_json_str = json.dumps(serialized_data, indent=2, sort_keys=True) - - diff_lines = list( - difflib.unified_diff( - original_json_str.splitlines(), - serialized_json_str.splitlines(), - lineterm='', - ) - ) - - removed_lines = [] - for line in diff_lines: - if line.startswith('-') and not line.startswith('---'): - removed_lines.append(line) - - if removed_lines: - error_msg = ( - 'Re-serialization dropped fields from the original payload:\n' - + '\n'.join(removed_lines) - ) - raise AssertionError(error_msg) - - -def test_parse_typical_030_agent_card() -> None: - data = { - 'additionalInterfaces': [ - {'transport': 'GRPC', 'url': 'http://agent.example.com/api/grpc'} - ], - 'capabilities': {'streaming': True}, - 'defaultInputModes': ['text/plain'], - 'defaultOutputModes': ['application/json'], - 'description': 'A typical agent from 0.3.0', - 'name': 'Typical Agent 0.3', - 'preferredTransport': 'JSONRPC', - 'protocolVersion': '0.3.0', - 'security': [{'test_oauth': ['read', 'write']}], - 'securitySchemes': { - 
'test_oauth': { - 'description': 'OAuth2 authentication', - 'flows': { - 'authorizationCode': { - 'authorizationUrl': 'http://auth.example.com', - 'scopes': { - 'read': 'Read access', - 'write': 'Write access', - }, - 'tokenUrl': 'http://token.example.com', - } - }, - 'type': 'oauth2', - } - }, - 'skills': [ - { - 'description': 'The first skill', - 'id': 'skill-1', - 'name': 'Skill 1', - 'security': [{'test_oauth': ['read']}], - 'tags': ['example'], - } - ], - 'supportsAuthenticatedExtendedCard': True, - 'url': 'http://agent.example.com/api', - 'version': '1.0', - } - original_data = copy.deepcopy(data) - card = parse_agent_card(data) - - expected_card = AgentCard( - name='Typical Agent 0.3', - description='A typical agent from 0.3.0', - version='1.0', - capabilities=AgentCapabilities( - extended_agent_card=True, streaming=True - ), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - supported_interfaces=[ - AgentInterface( - url='http://agent.example.com/api', - protocol_binding='JSONRPC', - protocol_version='0.3.0', - ), - AgentInterface( - url='http://agent.example.com/api/grpc', - protocol_binding='GRPC', - protocol_version='0.3.0', - ), - ], - security_requirements=[ - SecurityRequirement( - schemes={'test_oauth': StringList(list=['read', 'write'])} - ) - ], - security_schemes={ - 'test_oauth': SecurityScheme( - oauth2_security_scheme=OAuth2SecurityScheme( - description='OAuth2 authentication', - flows=OAuthFlows( - authorization_code=AuthorizationCodeOAuthFlow( - authorization_url='http://auth.example.com', - token_url='http://token.example.com', - scopes={ - 'read': 'Read access', - 'write': 'Write access', - }, - ) - ), - ) - ) - }, - skills=[ - AgentSkill( - id='skill-1', - name='Skill 1', - description='The first skill', - tags=['example'], - security_requirements=[ - SecurityRequirement( - schemes={'test_oauth': StringList(list=['read'])} - ) - ], - ) - ], - ) - - assert card == expected_card - - # Serialize back to JSON 
and compare - serialized_data = agent_card_to_dict(card) - - _assert_agent_card_diff(original_data, serialized_data) - assert 'preferredTransport' in serialized_data - - # Re-parse from the serialized payload and verify identical to original parsing - re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) - assert re_parsed_card == card - - -def test_parse_agent_card_security_scheme_without_in() -> None: - data = { - 'name': 'API Key Agent', - 'description': 'API Key without in param', - 'version': '1.0', - 'securitySchemes': { - 'test_api_key': {'type': 'apiKey', 'name': 'X-API-KEY'} - }, - } - card = parse_agent_card(data) - assert 'test_api_key' in card.security_schemes - assert ( - card.security_schemes['test_api_key'].api_key_security_scheme.name - == 'X-API-KEY' - ) - assert ( - card.security_schemes['test_api_key'].api_key_security_scheme.location - == '' - ) - - -def test_parse_agent_card_security_scheme_unknown_type() -> None: - data = { - 'name': 'Unknown Scheme Agent', - 'description': 'Has unknown scheme type', - 'version': '1.0', - 'securitySchemes': { - 'test_unknown': {'type': 'someFutureType', 'future_prop': 'value'}, - 'test_missing_type': {'prop': 'value'}, - }, - } - card = parse_agent_card(data) - # the ParseDict ignore_unknown_fields=True handles the unknown fields. - # Because there is no mapping logic for 'someFutureType', the Protobuf - # creates an empty SecurityScheme message under those keys. 
- assert 'test_unknown' in card.security_schemes - assert not card.security_schemes['test_unknown'].WhichOneof('scheme') - - assert 'test_missing_type' in card.security_schemes - assert not card.security_schemes['test_missing_type'].WhichOneof('scheme') - - -def test_create_text_message_object() -> None: - msg = create_text_message_object(role=Role.ROLE_AGENT, content='Hello') - assert msg.role == Role.ROLE_AGENT - assert len(msg.parts) == 1 - assert msg.parts[0].text == 'Hello' - assert msg.message_id != '' - - -def test_parse_030_agent_card_route_planner() -> None: - data = { - 'protocolVersion': '0.3', - 'name': 'GeoSpatial Route Planner Agent', - 'description': 'Provides advanced route planning.', - 'url': 'https://georoute-agent.example.com/a2a/v1', - 'preferredTransport': 'JSONRPC', - 'additionalInterfaces': [ - { - 'url': 'https://georoute-agent.example.com/a2a/v1', - 'transport': 'JSONRPC', - }, - { - 'url': 'https://georoute-agent.example.com/a2a/grpc', - 'transport': 'GRPC', - }, - { - 'url': 'https://georoute-agent.example.com/a2a/json', - 'transport': 'HTTP+JSON', - }, - ], - 'provider': { - 'organization': 'Example Geo Services Inc.', - 'url': 'https://www.examplegeoservices.com', - }, - 'iconUrl': 'https://georoute-agent.example.com/icon.png', - 'version': '1.2.0', - 'documentationUrl': 'https://docs.examplegeoservices.com/georoute-agent/api', - 'supportsAuthenticatedExtendedCard': True, - 'capabilities': { - 'streaming': True, - 'pushNotifications': True, - 'stateTransitionHistory': False, - }, - 'securitySchemes': { - 'google': { - 'type': 'openIdConnect', - 'openIdConnectUrl': 'https://accounts.google.com/.well-known/openid-configuration', - } - }, - 'security': [{'google': ['openid', 'profile', 'email']}], - 'defaultInputModes': ['application/json', 'text/plain'], - 'defaultOutputModes': ['application/json', 'image/png'], - 'skills': [ - { - 'id': 'route-optimizer-traffic', - 'name': 'Traffic-Aware Route Optimizer', - 'description': 'Calculates 
the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).', - 'tags': [ - 'maps', - 'routing', - 'navigation', - 'directions', - 'traffic', - ], - 'examples': [ - "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", - '{"origin": {"lat": 37.422, "lng": -122.084}, "destination": {"lat": 37.7749, "lng": -122.4194}, "preferences": ["avoid_ferries"]}', - ], - 'inputModes': ['application/json', 'text/plain'], - 'outputModes': [ - 'application/json', - 'application/vnd.geo+json', - 'text/html', - ], - 'security': [ - {'example': []}, - {'google': ['openid', 'profile', 'email']}, - ], - }, - { - 'id': 'custom-map-generator', - 'name': 'Personalized Map Generator', - 'description': 'Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. 
Can overlay data layers.', - 'tags': [ - 'maps', - 'customization', - 'visualization', - 'cartography', - ], - 'examples': [ - 'Generate a map of my upcoming road trip with all planned stops highlighted.', - 'Show me a map visualizing all coffee shops within a 1-mile radius of my current location.', - ], - 'inputModes': ['application/json'], - 'outputModes': [ - 'image/png', - 'image/jpeg', - 'application/json', - 'text/html', - ], - }, - ], - 'signatures': [ - { - 'protected': 'eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0', - 'signature': 'QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ', - } - ], - } - - original_data = copy.deepcopy(data) - card = parse_agent_card(data) - - expected_card = AgentCard( - name='GeoSpatial Route Planner Agent', - description='Provides advanced route planning.', - version='1.2.0', - documentation_url='https://docs.examplegeoservices.com/georoute-agent/api', - icon_url='https://georoute-agent.example.com/icon.png', - provider=AgentProvider( - organization='Example Geo Services Inc.', - url='https://www.examplegeoservices.com', - ), - capabilities=AgentCapabilities( - extended_agent_card=True, streaming=True, push_notifications=True - ), - default_input_modes=['application/json', 'text/plain'], - default_output_modes=['application/json', 'image/png'], - supported_interfaces=[ - AgentInterface( - url='https://georoute-agent.example.com/a2a/v1', - protocol_binding='JSONRPC', - protocol_version='0.3', - ), - AgentInterface( - url='https://georoute-agent.example.com/a2a/v1', - protocol_binding='JSONRPC', - protocol_version='0.3', - ), - AgentInterface( - url='https://georoute-agent.example.com/a2a/grpc', - protocol_binding='GRPC', - protocol_version='0.3', - ), - AgentInterface( - url='https://georoute-agent.example.com/a2a/json', - protocol_binding='HTTP+JSON', - protocol_version='0.3', - ), - ], - security_requirements=[ - 
SecurityRequirement( - schemes={ - 'google': StringList(list=['openid', 'profile', 'email']) - } - ) - ], - security_schemes={ - 'google': SecurityScheme( - open_id_connect_security_scheme=OpenIdConnectSecurityScheme( - open_id_connect_url='https://accounts.google.com/.well-known/openid-configuration' - ) - ) - }, - skills=[ - AgentSkill( - id='route-optimizer-traffic', - name='Traffic-Aware Route Optimizer', - description='Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).', - tags=['maps', 'routing', 'navigation', 'directions', 'traffic'], - examples=[ - "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", - '{"origin": {"lat": 37.422, "lng": -122.084}, "destination": {"lat": 37.7749, "lng": -122.4194}, "preferences": ["avoid_ferries"]}', - ], - input_modes=['application/json', 'text/plain'], - output_modes=[ - 'application/json', - 'application/vnd.geo+json', - 'text/html', - ], - security_requirements=[ - SecurityRequirement(schemes={'example': StringList()}), - SecurityRequirement( - schemes={ - 'google': StringList( - list=['openid', 'profile', 'email'] - ) - } - ), - ], - ), - AgentSkill( - id='custom-map-generator', - name='Personalized Map Generator', - description='Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. 
Can overlay data layers.', - tags=['maps', 'customization', 'visualization', 'cartography'], - examples=[ - 'Generate a map of my upcoming road trip with all planned stops highlighted.', - 'Show me a map visualizing all coffee shops within a 1-mile radius of my current location.', - ], - input_modes=['application/json'], - output_modes=[ - 'image/png', - 'image/jpeg', - 'application/json', - 'text/html', - ], - ), - ], - signatures=[ - AgentCardSignature( - protected='eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0', - signature='QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ', - ) - ], - ) - - assert card == expected_card - - # Serialize back to JSON and compare - serialized_data = agent_card_to_dict(card) - - # Remove deprecated stateTransitionHistory before diffing - del original_data['capabilities']['stateTransitionHistory'] - - _assert_agent_card_diff(original_data, serialized_data) - - # Re-parse from the serialized payload and verify identical to original parsing - re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) - assert re_parsed_card == card - - -def test_parse_complex_030_agent_card() -> None: - data = { - 'additionalInterfaces': [ - { - 'transport': 'GRPC', - 'url': 'http://complex.agent.example.com/grpc', - }, - { - 'transport': 'JSONRPC', - 'url': 'http://complex.agent.example.com/jsonrpc', - }, - ], - 'capabilities': {'pushNotifications': True, 'streaming': True}, - 'defaultInputModes': ['text/plain', 'application/json'], - 'defaultOutputModes': ['application/json', 'image/png'], - 'description': 'A very complex agent from 0.3.0', - 'name': 'Complex Agent 0.3', - 'preferredTransport': 'HTTP+JSON', - 'protocolVersion': '0.3.0', - 'security': [ - {'test_oauth': ['read', 'write'], 'test_api_key': []}, - {'test_http': []}, - {'test_oidc': ['openid', 'profile']}, - {'test_mtls': []}, - ], - 'securitySchemes': { - 'test_oauth': { - 
'description': 'OAuth2 authentication', - 'flows': { - 'authorizationCode': { - 'authorizationUrl': 'http://auth.example.com', - 'scopes': { - 'read': 'Read access', - 'write': 'Write access', - }, - 'tokenUrl': 'http://token.example.com', - } - }, - 'type': 'oauth2', - }, - 'test_api_key': { - 'description': 'API Key auth', - 'in': 'header', - 'name': 'X-API-KEY', - 'type': 'apiKey', - }, - 'test_http': { - 'bearerFormat': 'JWT', - 'description': 'HTTP Basic auth', - 'scheme': 'basic', - 'type': 'http', - }, - 'test_oidc': { - 'description': 'OIDC Auth', - 'openIdConnectUrl': 'https://example.com/.well-known/openid-configuration', - 'type': 'openIdConnect', - }, - 'test_mtls': {'description': 'mTLS Auth', 'type': 'mutualTLS'}, - }, - 'skills': [ - { - 'description': 'The first complex skill', - 'id': 'skill-1', - 'inputModes': ['application/json'], - 'name': 'Complex Skill 1', - 'outputModes': ['application/json'], - 'security': [{'test_api_key': []}], - 'tags': ['example', 'complex'], - }, - { - 'description': 'The second complex skill', - 'id': 'skill-2', - 'name': 'Complex Skill 2', - 'security': [{'test_oidc': ['openid']}], - 'tags': ['example2'], - }, - ], - 'supportsAuthenticatedExtendedCard': True, - 'url': 'http://complex.agent.example.com/api', - 'version': '1.5.2', - } - original_data = copy.deepcopy(data) - card = parse_agent_card(data) - - expected_card = AgentCard( - name='Complex Agent 0.3', - description='A very complex agent from 0.3.0', - version='1.5.2', - capabilities=AgentCapabilities( - extended_agent_card=True, streaming=True, push_notifications=True - ), - default_input_modes=['text/plain', 'application/json'], - default_output_modes=['application/json', 'image/png'], - supported_interfaces=[ - AgentInterface( - url='http://complex.agent.example.com/api', - protocol_binding='HTTP+JSON', - protocol_version='0.3.0', - ), - AgentInterface( - url='http://complex.agent.example.com/grpc', - protocol_binding='GRPC', - protocol_version='0.3.0', - ), 
- AgentInterface( - url='http://complex.agent.example.com/jsonrpc', - protocol_binding='JSONRPC', - protocol_version='0.3.0', - ), - ], - security_requirements=[ - SecurityRequirement( - schemes={ - 'test_oauth': StringList(list=['read', 'write']), - 'test_api_key': StringList(), - } - ), - SecurityRequirement(schemes={'test_http': StringList()}), - SecurityRequirement( - schemes={'test_oidc': StringList(list=['openid', 'profile'])} - ), - SecurityRequirement(schemes={'test_mtls': StringList()}), - ], - security_schemes={ - 'test_oauth': SecurityScheme( - oauth2_security_scheme=OAuth2SecurityScheme( - description='OAuth2 authentication', - flows=OAuthFlows( - authorization_code=AuthorizationCodeOAuthFlow( - authorization_url='http://auth.example.com', - token_url='http://token.example.com', - scopes={ - 'read': 'Read access', - 'write': 'Write access', - }, - ) - ), - ) - ), - 'test_api_key': SecurityScheme( - api_key_security_scheme=APIKeySecurityScheme( - description='API Key auth', - location='header', - name='X-API-KEY', - ) - ), - 'test_http': SecurityScheme( - http_auth_security_scheme=HTTPAuthSecurityScheme( - description='HTTP Basic auth', - scheme='basic', - bearer_format='JWT', - ) - ), - 'test_oidc': SecurityScheme( - open_id_connect_security_scheme=OpenIdConnectSecurityScheme( - description='OIDC Auth', - open_id_connect_url='https://example.com/.well-known/openid-configuration', - ) - ), - 'test_mtls': SecurityScheme( - mtls_security_scheme=MutualTlsSecurityScheme( - description='mTLS Auth' - ) - ), - }, - skills=[ - AgentSkill( - id='skill-1', - name='Complex Skill 1', - description='The first complex skill', - tags=['example', 'complex'], - input_modes=['application/json'], - output_modes=['application/json'], - security_requirements=[ - SecurityRequirement(schemes={'test_api_key': StringList()}) - ], - ), - AgentSkill( - id='skill-2', - name='Complex Skill 2', - description='The second complex skill', - tags=['example2'], - security_requirements=[ - 
SecurityRequirement( - schemes={'test_oidc': StringList(list=['openid'])} - ) - ], - ), - ], - ) - - assert card == expected_card - - # Serialize back to JSON and compare - serialized_data = agent_card_to_dict(card) - _assert_agent_card_diff(original_data, serialized_data) - - # Re-parse from the serialized payload and verify identical to original parsing - re_parsed_card = parse_agent_card(copy.deepcopy(serialized_data)) - assert re_parsed_card == card diff --git a/tests/client/test_service_parameters.py b/tests/client/test_service_parameters.py new file mode 100644 index 000000000..fbabd9719 --- /dev/null +++ b/tests/client/test_service_parameters.py @@ -0,0 +1,53 @@ +"""Tests for a2a.client.service_parameters module.""" + +from a2a.client.service_parameters import ( + ServiceParametersFactory, + with_a2a_extensions, +) +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +def test_with_a2a_extensions_merges_dedupes_and_sorts(): + """Repeated calls accumulate; duplicates collapse; output is sorted.""" + parameters = ServiceParametersFactory.create( + [ + with_a2a_extensions(['ext-c', 'ext-a']), + with_a2a_extensions(['ext-b', 'ext-a']), + ] + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c' + + +def test_with_a2a_extensions_merges_existing_header_value(): + """Pre-existing comma-separated header values are parsed and merged.""" + parameters = ServiceParametersFactory.create_from( + {HTTP_EXTENSION_HEADER: 'ext-a, ext-b'}, + [with_a2a_extensions(['ext-c'])], + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c' + + +def test_with_a2a_extensions_empty_is_noop(): + """An empty extensions list leaves the header untouched / absent.""" + parameters = ServiceParametersFactory.create( + [ + with_a2a_extensions(['ext-a']), + with_a2a_extensions([]), + ] + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a' + assert HTTP_EXTENSION_HEADER not in ServiceParametersFactory.create( + [with_a2a_extensions([])] + ) + + +def 
test_with_a2a_extensions_normalizes_input_strings(): + """Input strings are split on commas and stripped, like header values.""" + parameters = ServiceParametersFactory.create( + [with_a2a_extensions(['ext-a, ext-b', ' ext-c '])] + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c' diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index 9e81bd71e..95cca9189 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -35,7 +35,7 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import get_text_parts +from a2a.helpers.proto_helpers import get_text_parts @pytest.fixture diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index 1339bb8af..b005c2e05 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -545,7 +545,7 @@ async def test_extensions_added_to_request( from a2a.client.client import ClientCallContext context = ClientCallContext( - service_parameters={'X-A2A-Extensions': 'https://example.com/ext1'} + service_parameters={'A2A-Extensions': 'https://example.com/ext1'} ) await transport.send_message(request, context=context) @@ -555,7 +555,7 @@ async def test_extensions_added_to_request( call_args = mock_httpx_client.build_request.call_args # Extensions should be in the kwargs assert ( - call_args[1].get('headers', {}).get('X-A2A-Extensions') + call_args[1].get('headers', {}).get('A2A-Extensions') == 'https://example.com/ext1' ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index e7912566e..1e9398181 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -8,7 +8,7 @@ from google.protobuf.timestamp_pb2 import Timestamp from httpx_sse import EventSource, ServerSentEvent -from a2a.client import 
create_text_message_object +from a2a.helpers.proto_helpers import new_text_message from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError from a2a.client.transports.rest import RestTransport @@ -83,7 +83,7 @@ async def test_send_message_streaming_timeout( url='http://agent.example.com/api', ) params = SendMessageRequest( - message=create_text_message_object(content='Hello stream') + message=new_text_message(text='Hello stream') ) mock_event_source = AsyncMock(spec=EventSource) mock_event_source.response = MagicMock(spec=httpx.Response) @@ -120,9 +120,7 @@ async def test_rest_mapped_errors( agent_card=mock_agent_card, url='http://agent.example.com/api', ) - params = SendMessageRequest( - message=create_text_message_object(content='Hello') - ) + params = SendMessageRequest(message=new_text_message(text='Hello')) mock_build_request = MagicMock( return_value=AsyncMock(spec=httpx.Request) @@ -172,9 +170,7 @@ async def test_send_message_with_timeout_context( agent_card=mock_agent_card, url='http://agent.example.com/api', ) - params = SendMessageRequest( - message=create_text_message_object(content='Hello') - ) + params = SendMessageRequest(message=new_text_message(text='Hello')) context = ClientCallContext(timeout=10.0) mock_build_request = MagicMock( @@ -246,9 +242,7 @@ async def test_send_message_with_default_extensions( agent_card=mock_agent_card, url='http://agent.example.com/api', ) - params = SendMessageRequest( - message=create_text_message_object(content='Hello') - ) + params = SendMessageRequest(message=new_text_message(text='Hello')) # Mock the build_request method to capture its inputs mock_build_request = MagicMock( @@ -263,7 +257,7 @@ async def test_send_message_with_default_extensions( context = ClientCallContext( service_parameters={ - 'X-A2A-Extensions': 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' + 'A2A-Extensions': 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' } ) await 
client.send_message(request=params, context=context) @@ -287,14 +281,14 @@ async def test_send_message_streaming_with_new_extensions( mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): - """Test X-A2A-Extensions header in send_message_streaming.""" + """Test A2A-Extensions header in send_message_streaming.""" client = RestTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, url='http://agent.example.com/api', ) params = SendMessageRequest( - message=create_text_message_object(content='Hello stream') + message=new_text_message(text='Hello stream') ) mock_event_source = AsyncMock(spec=EventSource) @@ -309,7 +303,7 @@ async def test_send_message_streaming_with_new_extensions( context = ClientCallContext( service_parameters={ - 'X-A2A-Extensions': 'https://example.com/test-ext/v2' + 'A2A-Extensions': 'https://example.com/test-ext/v2' } ) @@ -343,7 +337,7 @@ async def test_send_message_streaming_server_error_propagates( url='http://agent.example.com/api', ) request = SendMessageRequest( - message=create_text_message_object(content='Error stream') + message=new_text_message(text='Error stream') ) mock_event_source = AsyncMock(spec=EventSource) @@ -524,7 +518,7 @@ class TestRestTransportTenant: 'send_message', SendMessageRequest( tenant='my-tenant', - message=create_text_message_object(content='hi'), + message=new_text_message(text='hi'), ), '/my-tenant/message:send', ), @@ -686,7 +680,7 @@ async def test_rest_get_task_prepend_empty_tenant( 'send_message_streaming', SendMessageRequest( tenant='my-tenant', - message=create_text_message_object(content='hi'), + message=new_text_message(text='hi'), ), '/my-tenant/message:stream', ), diff --git a/tests/compat/v0_3/test_context_builders.py b/tests/compat/v0_3/test_context_builders.py new file mode 100644 index 000000000..1b711f52f --- /dev/null +++ b/tests/compat/v0_3/test_context_builders.py @@ -0,0 +1,159 @@ +from unittest.mock import AsyncMock, MagicMock + +import grpc + +from 
starlette.datastructures import Headers + +from a2a.compat.v0_3.context_builders import ( + V03GrpcServerCallContextBuilder, + V03ServerCallContextBuilder, +) +from a2a.compat.v0_3.extension_headers import LEGACY_HTTP_EXTENSION_HEADER +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.grpc_handler import ( + DefaultGrpcServerCallContextBuilder, +) +from a2a.server.routes.common import DefaultServerCallContextBuilder + + +def _make_mock_request(headers=None): + request = MagicMock() + request.scope = {} + request.headers = Headers(headers or {}) + return request + + +def _make_mock_grpc_context(metadata: list[tuple[str, str]]) -> AsyncMock: + context = AsyncMock(spec=grpc.aio.ServicerContext) + context.invocation_metadata.return_value = grpc.aio.Metadata(*metadata) + return context + + +class TestV03ServerCallContextBuilder: + def test_legacy_header_only(self): + request = _make_mock_request( + headers={LEGACY_HTTP_EXTENSION_HEADER: 'legacy-ext'} + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert isinstance(ctx, ServerCallContext) + assert ctx.requested_extensions == {'legacy-ext'} + + def test_spec_header_only(self): + request = _make_mock_request( + headers={HTTP_EXTENSION_HEADER: 'spec-ext'} + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == {'spec-ext'} + + def test_both_headers_merged(self): + request = _make_mock_request( + headers={ + HTTP_EXTENSION_HEADER: 'spec-ext', + LEGACY_HTTP_EXTENSION_HEADER: 'legacy-ext', + } + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == {'spec-ext', 'legacy-ext'} + + def test_legacy_header_comma_separated(self): + request = _make_mock_request( + 
headers={LEGACY_HTTP_EXTENSION_HEADER: 'foo, bar'} + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == {'foo', 'bar'} + + def test_no_extensions(self): + request = _make_mock_request() + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == set() + + +class TestV03GrpcServerCallContextBuilder: + def test_legacy_metadata_only(self): + context = _make_mock_grpc_context( + [(LEGACY_HTTP_EXTENSION_HEADER.lower(), 'legacy-ext')] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert isinstance(ctx, ServerCallContext) + assert ctx.requested_extensions == {'legacy-ext'} + + def test_spec_metadata_only(self): + context = _make_mock_grpc_context( + [(HTTP_EXTENSION_HEADER.lower(), 'spec-ext')] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == {'spec-ext'} + + def test_both_metadata_merged(self): + context = _make_mock_grpc_context( + [ + (HTTP_EXTENSION_HEADER.lower(), 'spec-ext'), + (LEGACY_HTTP_EXTENSION_HEADER.lower(), 'legacy-ext'), + ] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == {'spec-ext', 'legacy-ext'} + + def test_legacy_metadata_comma_separated(self): + context = _make_mock_grpc_context( + [(LEGACY_HTTP_EXTENSION_HEADER.lower(), 'foo, bar')] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == {'foo', 'bar'} + + def test_no_extensions(self): + context = _make_mock_grpc_context([]) + builder = V03GrpcServerCallContextBuilder( + 
DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == set() + + def test_no_metadata(self): + context = AsyncMock(spec=grpc.aio.ServicerContext) + context.invocation_metadata.return_value = None + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == set() diff --git a/tests/compat/v0_3/test_extension_headers.py b/tests/compat/v0_3/test_extension_headers.py new file mode 100644 index 000000000..d5abbdfcc --- /dev/null +++ b/tests/compat/v0_3/test_extension_headers.py @@ -0,0 +1,39 @@ +from a2a.compat.v0_3.extension_headers import ( + LEGACY_HTTP_EXTENSION_HEADER, + add_legacy_extension_header, +) +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +def test_legacy_header_constant_value(): + assert LEGACY_HTTP_EXTENSION_HEADER == 'X-A2A-Extensions' + + +def test_mirrors_spec_header_under_legacy_name(): + params = {HTTP_EXTENSION_HEADER: 'foo,bar'} + + add_legacy_extension_header(params) + + assert params == { + HTTP_EXTENSION_HEADER: 'foo,bar', + LEGACY_HTTP_EXTENSION_HEADER: 'foo,bar', + } + + +def test_no_op_when_spec_header_absent(): + params = {'Other': 'value'} + + add_legacy_extension_header(params) + + assert params == {'Other': 'value'} + + +def test_does_not_overwrite_existing_legacy_header(): + params = { + HTTP_EXTENSION_HEADER: 'spec', + LEGACY_HTTP_EXTENSION_HEADER: 'legacy-original', + } + + add_legacy_extension_header(params) + + assert params[LEGACY_HTTP_EXTENSION_HEADER] == 'legacy-original' diff --git a/tests/compat/v0_3/test_grpc_handler.py b/tests/compat/v0_3/test_grpc_handler.py index 9040388e2..fbd74f29f 100644 --- a/tests/compat/v0_3/test_grpc_handler.py +++ b/tests/compat/v0_3/test_grpc_handler.py @@ -7,8 +7,6 @@ a2a_v0_3_pb2, grpc_handler as compat_grpc_handler, ) -from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.server.context import ServerCallContext 
from a2a.server.request_handlers import RequestHandler from a2a.types import a2a_pb2 from a2a.utils.errors import TaskNotFoundError, InvalidParamsError @@ -37,6 +35,7 @@ def sample_agent_card() -> a2a_pb2.AgentCard: capabilities=a2a_pb2.AgentCapabilities( streaming=True, push_notifications=True, + extended_agent_card=True, ), supported_interfaces=[ a2a_pb2.AgentInterface( @@ -53,7 +52,7 @@ def handler( mock_request_handler: AsyncMock, sample_agent_card: a2a_pb2.AgentCard ) -> compat_grpc_handler.CompatGrpcHandler: return compat_grpc_handler.CompatGrpcHandler( - agent_card=sample_agent_card, request_handler=mock_request_handler + request_handler=mock_request_handler, ) @@ -437,9 +436,15 @@ async def test_list_push_config_success( @pytest.mark.asyncio async def test_get_agent_card_success( handler: compat_grpc_handler.CompatGrpcHandler, + mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, + sample_agent_card: a2a_pb2.AgentCard, ) -> None: request = a2a_v0_3_pb2.GetAgentCardRequest() + mock_request_handler.on_get_extended_agent_card.return_value = ( + sample_agent_card + ) + response = await handler.GetAgentCard(request, mock_grpc_context) expected_res = a2a_v0_3_pb2.AgentCard( @@ -448,6 +453,7 @@ async def test_get_agent_card_success( url='http://jsonrpc.v03.com', version='1.0.0', protocol_version='0.3', + supports_authenticated_extended_card=True, preferred_transport='JSONRPC', capabilities=a2a_v0_3_pb2.AgentCapabilities( streaming=True, @@ -498,21 +504,3 @@ async def test_extract_task_and_config_id_invalid( ): with pytest.raises(InvalidParamsError): handler._extract_task_and_config_id('invalid-name') - - -@pytest.mark.asyncio -async def test_handle_unary_extension_metadata( - handler: compat_grpc_handler.CompatGrpcHandler, - mock_request_handler: AsyncMock, - mock_grpc_context: AsyncMock, -) -> None: - async def mock_func(server_context: ServerCallContext): - server_context.activated_extensions.add('ext-1') - return a2a_pb2.Task() - - await 
handler._handle_unary(mock_grpc_context, mock_func, a2a_pb2.Task()) - - expected_metadata = [(HTTP_EXTENSION_HEADER.lower(), 'ext-1')] - mock_grpc_context.set_trailing_metadata.assert_called_once_with( - expected_metadata - ) diff --git a/tests/compat/v0_3/test_grpc_transport.py b/tests/compat/v0_3/test_grpc_transport.py index ba1e6af3d..402a57000 100644 --- a/tests/compat/v0_3/test_grpc_transport.py +++ b/tests/compat/v0_3/test_grpc_transport.py @@ -2,6 +2,7 @@ import pytest +from a2a.client.client import ClientCallContext from a2a.client.optionals import Channel from a2a.compat.v0_3 import a2a_v0_3_pb2 from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport @@ -38,3 +39,30 @@ async def test_compat_grpc_transport_send_message_response_msg_parsing(): assert isinstance(response, SendMessageResponse) assert response.HasField('message') assert response.message.message_id == 'msg-123' + + +def test_compat_grpc_transport_mirrors_extension_metadata(): + """Compat gRPC client must also emit the legacy x-a2a-extensions metadata + so that v0.3 servers (which only know that name) understand the request.""" + transport = CompatGrpcTransport( + channel=AsyncMock(spec=Channel), agent_card=None + ) + context = ClientCallContext( + service_parameters={'A2A-Extensions': 'foo,bar'} + ) + + metadata = dict(transport._get_grpc_metadata(context)) + + assert metadata['a2a-extensions'] == 'foo,bar' + assert metadata['x-a2a-extensions'] == 'foo,bar' + + +def test_compat_grpc_transport_no_extension_metadata(): + transport = CompatGrpcTransport( + channel=AsyncMock(spec=Channel), agent_card=None + ) + + metadata = dict(transport._get_grpc_metadata(None)) + + assert 'a2a-extensions' not in metadata + assert 'x-a2a-extensions' not in metadata diff --git a/tests/compat/v0_3/test_jsonrpc_app_compat.py b/tests/compat/v0_3/test_jsonrpc_app_compat.py index 1417b5dac..6658097dc 100644 --- a/tests/compat/v0_3/test_jsonrpc_app_compat.py +++ b/tests/compat/v0_3/test_jsonrpc_app_compat.py @@ 
-46,8 +46,8 @@ def mock_handler(): @pytest.fixture -def test_app(mock_handler): - agent_card = AgentCard( +def agent_card(): + card = AgentCard( name='TestAgent', description='Test Description', version='1.0.0', @@ -55,13 +55,17 @@ def test_app(mock_handler): streaming=False, push_notifications=True, extended_agent_card=True ), ) - interface = agent_card.supported_interfaces.add() + interface = card.supported_interfaces.add() interface.url = 'http://mockurl.com' interface.protocol_binding = 'jsonrpc' interface.protocol_version = '0.3' + return card + +@pytest.fixture +def test_app(mock_handler, agent_card): + mock_handler._agent_card = agent_card jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=mock_handler, enable_v0_3_compat=True, rpc_url='/', @@ -123,9 +127,10 @@ def test_get_task_v03_compat( def test_get_extended_agent_card_v03_compat( - client: TestClient, + client: TestClient, mock_handler: AsyncMock, agent_card: AgentCard ) -> None: """Test that the v0.3 method name 'agent/getAuthenticatedExtendedCard' is correctly routed.""" + mock_handler.on_get_extended_agent_card.return_value = agent_card request_payload = { 'jsonrpc': '2.0', 'id': '3', diff --git a/tests/compat/v0_3/test_jsonrpc_transport.py b/tests/compat/v0_3/test_jsonrpc_transport.py index 50b33e162..70291f005 100644 --- a/tests/compat/v0_3/test_jsonrpc_transport.py +++ b/tests/compat/v0_3/test_jsonrpc_transport.py @@ -539,3 +539,29 @@ async def test_compat_jsonrpc_transport_send_request( mock_send_http_request.assert_called_once_with( transport.httpx_client, mock_request, transport._handle_http_error ) + + +@pytest.mark.asyncio +@patch('a2a.compat.v0_3.jsonrpc_transport.send_http_request') +async def test_compat_jsonrpc_transport_mirrors_extension_header( + mock_send_http_request, transport +): + """Compat client must also emit the legacy X-A2A-Extensions header so + that v0.3 servers (which only know that name) understand the request.""" + from a2a.client.client 
import ClientCallContext + + mock_send_http_request.return_value = {'result': {'ok': True}} + transport.httpx_client.build_request.return_value = httpx.Request( + 'POST', 'http://example.com' + ) + + context = ClientCallContext( + service_parameters={'A2A-Extensions': 'foo,bar'} + ) + + await transport._send_request({'some': 'data'}, context=context) + + _, kwargs = transport.httpx_client.build_request.call_args + headers = kwargs['headers'] + assert headers['A2A-Extensions'] == 'foo,bar' + assert headers['X-A2A-Extensions'] == 'foo,bar' diff --git a/tests/compat/v0_3/test_request_handler.py b/tests/compat/v0_3/test_request_handler.py index 55b0d2cab..26ad74264 100644 --- a/tests/compat/v0_3/test_request_handler.py +++ b/tests/compat/v0_3/test_request_handler.py @@ -7,24 +7,15 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + AgentInterface, ListTaskPushNotificationConfigsResponse as V10ListPushConfigsResp, -) -from a2a.types.a2a_pb2 import ( Message as V10Message, -) -from a2a.types.a2a_pb2 import ( Part as V10Part, -) -from a2a.types.a2a_pb2 import ( Task as V10Task, -) -from a2a.types.a2a_pb2 import ( TaskPushNotificationConfig as V10PushConfig, -) -from a2a.types.a2a_pb2 import ( TaskState as V10TaskState, -) -from a2a.types.a2a_pb2 import ( TaskStatus as V10TaskStatus, ) from a2a.utils.errors import TaskNotFoundError @@ -32,7 +23,16 @@ @pytest.fixture def mock_core_handler(): - return AsyncMock(spec=RequestHandler) + handler = AsyncMock(spec=RequestHandler) + + handler.agent_card = AgentCard( + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ) + ) + return handler @pytest.fixture @@ -355,3 +355,35 @@ async def test_on_delete_task_push_notification_config( assert result is None mock_core_handler.on_delete_task_push_notification_config.assert_called_once() + + 
+@pytest.mark.anyio +async def test_on_get_extended_agent_card_success( + v03_handler, mock_core_handler, mock_context +): + v03_req = types_v03.GetAuthenticatedExtendedCardRequest(id=0) + + mock_core_handler.on_get_extended_agent_card.return_value = AgentCard( + name='Extended Agent', + description='An extended test agent', + version='1.0.0', + supported_interfaces=[ + AgentInterface( + url='http://jsonrpc.v03.com', + protocol_version='0.3', + ) + ], + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ), + ) + + result = await v03_handler.on_get_extended_agent_card(v03_req, mock_context) + + assert isinstance(result, types_v03.AgentCard) + assert result.name == 'Extended Agent' + assert result.capabilities.streaming is True + assert result.capabilities.push_notifications is True + mock_core_handler.on_get_extended_agent_card.assert_called_once() diff --git a/tests/compat/v0_3/test_rest_handler.py b/tests/compat/v0_3/test_rest_handler.py index f864b7037..6ff44abb1 100644 --- a/tests/compat/v0_3/test_rest_handler.py +++ b/tests/compat/v0_3/test_rest_handler.py @@ -27,9 +27,7 @@ def agent_card(): @pytest.fixture def rest_handler(agent_card, mock_core_handler): - handler = REST03Handler( - agent_card=agent_card, request_handler=mock_core_handler - ) + handler = REST03Handler(request_handler=mock_core_handler) # Mock the internal handler03 for easier testing of translations handler.handler03 = AsyncMock() return handler @@ -363,3 +361,39 @@ async def test_list_push_notifications( async def test_list_tasks(rest_handler, mock_request, mock_context): with pytest.raises(NotImplementedError): await rest_handler.list_tasks(mock_request, mock_context) + + +# Add our new translation method test +@pytest.mark.anyio +async def test_on_get_extended_agent_card_success( + rest_handler, mock_request, mock_context +): + rest_handler.handler03.on_get_extended_agent_card.return_value = ( + types_v03.AgentCard( + name='Extended 
Agent', + description='An extended test agent', + version='1.0.0', + url='http://jsonrpc.v03.com', + preferred_transport='JSONRPC', + protocol_version='0.3', + default_input_modes=[], + default_output_modes=[], + skills=[], + capabilities=types_v03.AgentCapabilities( + streaming=True, + push_notifications=True, + ), + ) + ) + + result = await rest_handler.on_get_extended_agent_card( + mock_request, mock_context + ) + + # on_get_extended_agent_card returns a JSON-friendly dict via model_dump + assert isinstance(result, dict) + assert result['name'] == 'Extended Agent' + assert result['capabilities']['streaming'] is True + assert result['capabilities']['pushNotifications'] is True + + rest_handler.handler03.on_get_extended_agent_card.assert_called_once() diff --git a/tests/compat/v0_3/test_rest_routes_compat.py b/tests/compat/v0_3/test_rest_routes_compat.py index 5ee0f60ca..b3b9e70b3 100644 --- a/tests/compat/v0_3/test_rest_routes_compat.py +++ b/tests/compat/v0_3/test_rest_routes_compat.py @@ -53,8 +53,9 @@ async def app( request_handler: RequestHandler, ) -> Starlette: """Builds the Starlette application for testing.""" + request_handler._agent_card = agent_card rest_routes = create_rest_routes( - agent_card, request_handler, enable_v0_3_compat=True + request_handler=request_handler, enable_v0_3_compat=True ) agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/well-known/agent.json' diff --git a/tests/contrib/__init__.py b/tests/contrib/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/contrib/tasks/__init__.py b/tests/contrib/tasks/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/contrib/tasks/fake_vertex_client.py b/tests/contrib/tasks/fake_vertex_client.py deleted file mode 100644 index 86d14ede0..000000000 --- a/tests/contrib/tasks/fake_vertex_client.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Fake Vertex AI Client implementations for testing.""" - -import copy - -from 
google.genai import errors as genai_errors -from vertexai import types as vertexai_types - - -class FakeAgentEnginesA2aTasksEventsClient: - def __init__(self, parent_client): - self.parent_client = parent_client - - async def append( - self, name: str, task_events: list[vertexai_types.TaskEvent] - ) -> None: - task = self.parent_client.tasks.get(name) - if not task: - raise genai_errors.APIError( - code=404, - response_json={ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'Task not found', - } - }, - ) - - task = copy.deepcopy(task) - if ( - not hasattr(task, 'next_event_sequence_number') - or not task.next_event_sequence_number - ): - task.next_event_sequence_number = 0 - - for event in task_events: - data = event.event_data - if getattr(data, 'state_change', None): - task.state = getattr(data.state_change, 'new_state', task.state) - if getattr(data, 'metadata_change', None): - task.metadata = getattr( - data.metadata_change, 'new_metadata', task.metadata - ) - if getattr(data, 'output_change', None): - change = getattr( - data.output_change, 'task_artifact_change', None - ) - if not change: - continue - if not getattr(task, 'output', None): - task.output = vertexai_types.TaskOutput() - - current_artifacts = ( - list(task.output.artifacts) - if getattr(task.output, 'artifacts', None) - else [] - ) - - deleted_ids = getattr(change, 'deleted_artifact_ids', []) or [] - if deleted_ids: - current_artifacts = [ - a - for a in current_artifacts - if a.artifact_id not in deleted_ids - ] - - added = getattr(change, 'added_artifacts', []) or [] - if added: - current_artifacts.extend(added) - - updated = getattr(change, 'updated_artifacts', []) or [] - if updated: - updated_map = {a.artifact_id: a for a in updated} - current_artifacts = [ - updated_map.get(a.artifact_id, a) - for a in current_artifacts - ] - - try: - del task.output.artifacts[:] - task.output.artifacts.extend(current_artifacts) - except Exception: - task.output.artifacts = current_artifacts - 
task.next_event_sequence_number += 1 - - self.parent_client.tasks[name] = task - - -class FakeAgentEnginesA2aTasksClient: - def __init__(self): - self.tasks: dict[str, vertexai_types.A2aTask] = {} - self.events = FakeAgentEnginesA2aTasksEventsClient(self) - - async def create( - self, - name: str, - a2a_task_id: str, - config: vertexai_types.CreateAgentEngineTaskConfig, - ) -> vertexai_types.A2aTask: - full_name = f'{name}/a2aTasks/{a2a_task_id}' - task = vertexai_types.A2aTask( - name=full_name, - context_id=config.context_id, - metadata=config.metadata, - output=config.output, - state=vertexai_types.State.SUBMITTED, - ) - task.next_event_sequence_number = 1 - self.tasks[full_name] = task - return task - - async def get(self, name: str) -> vertexai_types.A2aTask: - if name not in self.tasks: - raise genai_errors.APIError( - code=404, - response_json={ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'Task not found', - } - }, - ) - return copy.deepcopy(self.tasks[name]) - - -class FakeAgentEnginesClient: - def __init__(self): - self.a2a_tasks = FakeAgentEnginesA2aTasksClient() - - -class FakeAioClient: - def __init__(self): - self.agent_engines = FakeAgentEnginesClient() - - -class FakeVertexClient: - def __init__(self): - self.aio = FakeAioClient() diff --git a/tests/contrib/tasks/run_vertex_tests.sh b/tests/contrib/tasks/run_vertex_tests.sh deleted file mode 100755 index 12c0395d2..000000000 --- a/tests/contrib/tasks/run_vertex_tests.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash -set -e - -for var in VERTEX_PROJECT VERTEX_LOCATION VERTEX_BASE_URL VERTEX_API_VERSION; do - if [ -z "${!var}" ]; then - echo "Error: Environment variable $var is undefined or empty." >&2 - exit 1 - fi -done - -PYTEST_ARGS=("$@") - -echo "Running Vertex tests..." 
- -cd $(git rev-parse --show-toplevel) - -uv run pytest -v "${PYTEST_ARGS[@]}" tests/contrib/tasks/test_vertex_task_store.py tests/contrib/tasks/test_vertex_task_converter.py diff --git a/tests/contrib/tasks/test_vertex_task_converter.py b/tests/contrib/tasks/test_vertex_task_converter.py deleted file mode 100644 index a060bc451..000000000 --- a/tests/contrib/tasks/test_vertex_task_converter.py +++ /dev/null @@ -1,405 +0,0 @@ -import base64 - -import pytest - - -pytest.importorskip( - 'vertexai', reason='Vertex Task Converter tests require vertexai' -) -from vertexai import types as vertexai_types -from google.genai import types as genai_types -from a2a.contrib.tasks.vertex_task_converter import ( - to_sdk_artifact, - to_sdk_part, - to_sdk_task, - to_sdk_task_state, - to_stored_artifact, - to_stored_part, - to_stored_task, - to_stored_task_state, -) -from a2a.compat.v0_3.types import ( - Artifact, - DataPart, - FilePart, - FileWithBytes, - FileWithUri, - Part, - Task, - TaskState, - TaskStatus, - TextPart, -) - - -def test_to_sdk_task_state() -> None: - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.STATE_UNSPECIFIED) - == TaskState.unknown - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.SUBMITTED) - == TaskState.submitted - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.WORKING) - == TaskState.working - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.COMPLETED) - == TaskState.completed - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.CANCELLED) - == TaskState.canceled - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.FAILED) - == TaskState.failed - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.REJECTED) - == TaskState.rejected - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.INPUT_REQUIRED) - == TaskState.input_required - ) - assert ( - to_sdk_task_state(vertexai_types.A2aTaskState.AUTH_REQUIRED) - == TaskState.auth_required - ) - assert 
to_sdk_task_state(999) == TaskState.unknown # type: ignore - - -def test_to_stored_task_state() -> None: - assert ( - to_stored_task_state(TaskState.unknown) - == vertexai_types.A2aTaskState.STATE_UNSPECIFIED - ) - assert ( - to_stored_task_state(TaskState.submitted) - == vertexai_types.A2aTaskState.SUBMITTED - ) - assert ( - to_stored_task_state(TaskState.working) - == vertexai_types.A2aTaskState.WORKING - ) - assert ( - to_stored_task_state(TaskState.completed) - == vertexai_types.A2aTaskState.COMPLETED - ) - assert ( - to_stored_task_state(TaskState.canceled) - == vertexai_types.A2aTaskState.CANCELLED - ) - assert ( - to_stored_task_state(TaskState.failed) - == vertexai_types.A2aTaskState.FAILED - ) - assert ( - to_stored_task_state(TaskState.rejected) - == vertexai_types.A2aTaskState.REJECTED - ) - assert ( - to_stored_task_state(TaskState.input_required) - == vertexai_types.A2aTaskState.INPUT_REQUIRED - ) - assert ( - to_stored_task_state(TaskState.auth_required) - == vertexai_types.A2aTaskState.AUTH_REQUIRED - ) - - -def test_to_stored_part_text() -> None: - sdk_part = Part(root=TextPart(text='hello world')) - stored_part = to_stored_part(sdk_part) - assert stored_part.text == 'hello world' - assert not stored_part.inline_data - assert not stored_part.file_data - - -def test_to_stored_part_data() -> None: - sdk_part = Part(root=DataPart(data={'key': 'value'})) - stored_part = to_stored_part(sdk_part) - assert stored_part.inline_data is not None - assert stored_part.inline_data.mime_type == 'application/json' - assert stored_part.inline_data.data == b'{"key": "value"}' - - -def test_to_stored_part_file_bytes() -> None: - encoded_b64 = base64.b64encode(b'test data').decode('utf-8') - sdk_part = Part( - root=FilePart( - file=FileWithBytes( - bytes=encoded_b64, - mime_type='text/plain', - ) - ) - ) - stored_part = to_stored_part(sdk_part) - assert stored_part.inline_data is not None - assert stored_part.inline_data.mime_type == 'text/plain' - assert 
stored_part.inline_data.data == b'test data' - - -def test_to_stored_part_file_uri() -> None: - sdk_part = Part( - root=FilePart( - file=FileWithUri( - uri='gs://test-bucket/file.txt', - mime_type='text/plain', - ) - ) - ) - stored_part = to_stored_part(sdk_part) - assert stored_part.file_data is not None - assert stored_part.file_data.mime_type == 'text/plain' - assert stored_part.file_data.file_uri == 'gs://test-bucket/file.txt' - - -def test_to_stored_part_unsupported() -> None: - class BadPart: - pass - - part = Part(root=TextPart(text='t')) - part.root = BadPart() # type: ignore - with pytest.raises(ValueError, match='Unsupported part type'): - to_stored_part(part) - - -def test_to_sdk_part_text() -> None: - stored_part = genai_types.Part(text='hello back') - sdk_part = to_sdk_part(stored_part) - assert isinstance(sdk_part.root, TextPart) - assert sdk_part.root.text == 'hello back' - - -def test_to_sdk_part_inline_data() -> None: - stored_part = genai_types.Part( - inline_data=genai_types.Blob( - mime_type='application/json', - data=b'{"key": "val"}', - ) - ) - sdk_part = to_sdk_part(stored_part) - assert isinstance(sdk_part.root, FilePart) - assert isinstance(sdk_part.root.file, FileWithBytes) - expected_b64 = base64.b64encode(b'{"key": "val"}').decode('utf-8') - assert sdk_part.root.file.mime_type == 'application/json' - assert sdk_part.root.file.bytes == expected_b64 - - -def test_to_sdk_part_file_data() -> None: - stored_part = genai_types.Part( - file_data=genai_types.FileData( - mime_type='image/jpeg', - file_uri='gs://bucket/image.jpg', - ) - ) - sdk_part = to_sdk_part(stored_part) - assert isinstance(sdk_part.root, FilePart) - assert isinstance(sdk_part.root.file, FileWithUri) - assert sdk_part.root.file.mime_type == 'image/jpeg' - assert sdk_part.root.file.uri == 'gs://bucket/image.jpg' - - -def test_to_sdk_part_unsupported() -> None: - stored_part = genai_types.Part() - with pytest.raises(ValueError, match='Unsupported part:'): - 
to_sdk_part(stored_part) - - -def test_to_stored_artifact() -> None: - sdk_artifact = Artifact( - artifact_id='art-123', - parts=[Part(root=TextPart(text='part_1'))], - ) - stored_artifact = to_stored_artifact(sdk_artifact) - assert stored_artifact.artifact_id == 'art-123' - assert len(stored_artifact.parts) == 1 - assert stored_artifact.parts[0].text == 'part_1' - - -def test_to_sdk_artifact() -> None: - stored_artifact = vertexai_types.TaskArtifact( - artifact_id='art-456', - parts=[genai_types.Part(text='part_2')], - ) - sdk_artifact = to_sdk_artifact(stored_artifact) - assert sdk_artifact.artifact_id == 'art-456' - assert len(sdk_artifact.parts) == 1 - assert isinstance(sdk_artifact.parts[0].root, TextPart) - assert sdk_artifact.parts[0].root.text == 'part_2' - - -def test_to_stored_task() -> None: - sdk_task = Task( - id='task-1', - context_id='ctx-1', - status=TaskStatus(state=TaskState.working), - metadata={'foo': 'bar'}, - artifacts=[ - Artifact( - artifact_id='art-1', - parts=[Part(root=TextPart(text='stuff'))], - ) - ], - history=[], - ) - stored_task = to_stored_task(sdk_task) - assert stored_task.context_id == 'ctx-1' - assert stored_task.metadata == {'foo': 'bar'} - assert stored_task.state == vertexai_types.A2aTaskState.WORKING - assert stored_task.output is not None - assert stored_task.output.artifacts is not None - assert len(stored_task.output.artifacts) == 1 - assert stored_task.output.artifacts[0].artifact_id == 'art-1' - - -def test_to_sdk_task() -> None: - stored_task = vertexai_types.A2aTask( - name='projects/123/locations/us-central1/agentEngines/456/tasks/task-2', - context_id='ctx-2', - state=vertexai_types.A2aTaskState.COMPLETED, - metadata={'a': 'b'}, - output=vertexai_types.TaskOutput( - artifacts=[ - vertexai_types.TaskArtifact( - artifact_id='art-2', - parts=[genai_types.Part(text='result')], - ) - ] - ), - ) - sdk_task = to_sdk_task(stored_task) - assert sdk_task.id == 'task-2' - assert sdk_task.context_id == 'ctx-2' - assert 
sdk_task.status.state == TaskState.completed - assert sdk_task.metadata == {'a': 'b'} - assert sdk_task.history == [] - assert sdk_task.artifacts is not None - assert len(sdk_task.artifacts) == 1 - assert sdk_task.artifacts[0].artifact_id == 'art-2' - assert isinstance(sdk_task.artifacts[0].parts[0].root, TextPart) - assert sdk_task.artifacts[0].parts[0].root.text == 'result' - - -def test_to_sdk_task_no_output() -> None: - stored_task = vertexai_types.A2aTask( - name='tasks/task-3', - context_id='ctx-3', - state=vertexai_types.A2aTaskState.SUBMITTED, - metadata=None, - ) - sdk_task = to_sdk_task(stored_task) - assert sdk_task.id == 'task-3' - assert sdk_task.metadata == {} - assert sdk_task.artifacts == [] - - -def test_sdk_task_state_conversion_round_trip() -> None: - for state in TaskState: - stored_state = to_stored_task_state(state) - round_trip_state = to_sdk_task_state(stored_state) - assert round_trip_state == state - - -def test_sdk_part_text_conversion_round_trip() -> None: - sdk_part = Part(root=TextPart(text='hello world')) - stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) - assert round_trip_sdk_part == sdk_part - - -def test_sdk_part_data_conversion_round_trip() -> None: - # A DataPart is converted to `inline_data` in Vertex AI, which lacks the original - # `DataPart` vs `FilePart` distinction. When reading it back from the stored - # protocol format, it becomes a `FilePart` with base64-encoded `FileWithBytes` - # and `mime_type="application/json"`. 
- sdk_part = Part(root=DataPart(data={'key': 'value'})) - stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) - - expected_b64 = base64.b64encode(b'{"key": "value"}').decode('utf-8') - assert round_trip_sdk_part == Part( - root=FilePart( - file=FileWithBytes( - bytes=expected_b64, - mime_type='application/json', - ) - ) - ) - - -def test_sdk_part_file_bytes_conversion_round_trip() -> None: - encoded_b64 = base64.b64encode(b'test data').decode('utf-8') - sdk_part = Part( - root=FilePart( - file=FileWithBytes( - bytes=encoded_b64, - mime_type='text/plain', - ) - ) - ) - stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) - assert round_trip_sdk_part == sdk_part - - -def test_sdk_part_file_uri_conversion_round_trip() -> None: - sdk_part = Part( - root=FilePart( - file=FileWithUri( - uri='gs://test-bucket/file.txt', - mime_type='text/plain', - ) - ) - ) - stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) - assert round_trip_sdk_part == sdk_part - - -def test_sdk_artifact_conversion_round_trip() -> None: - sdk_artifact = Artifact( - artifact_id='art-123', - parts=[Part(root=TextPart(text='part_1'))], - ) - stored_artifact = to_stored_artifact(sdk_artifact) - round_trip_sdk_artifact = to_sdk_artifact(stored_artifact) - assert round_trip_sdk_artifact == sdk_artifact - - -def test_sdk_task_conversion_round_trip() -> None: - sdk_task = Task( - id='task-1', - context_id='ctx-1', - status=TaskStatus(state=TaskState.working), - metadata={'foo': 'bar'}, - artifacts=[ - Artifact( - artifact_id='art-1', - parts=[Part(root=TextPart(text='stuff'))], - ) - ], - history=[ - # History is not yet implemented and later will be supported - # via events. - ], - ) - stored_task = to_stored_task(sdk_task) - # Simulate Vertex storing the ID in the fully qualified resource name. - # The task ID during creation gets appended to the parent name. 
- stored_task.name = ( - f'projects/p/locations/l/agentEngines/e/tasks/{sdk_task.id}' - ) - - round_trip_sdk_task = to_sdk_task(stored_task) - - assert round_trip_sdk_task.id == sdk_task.id - assert round_trip_sdk_task.context_id == sdk_task.context_id - assert round_trip_sdk_task.status == sdk_task.status - assert round_trip_sdk_task.metadata == sdk_task.metadata - assert round_trip_sdk_task.artifacts == sdk_task.artifacts - assert round_trip_sdk_task.history == [] diff --git a/tests/contrib/tasks/test_vertex_task_store.py b/tests/contrib/tasks/test_vertex_task_store.py deleted file mode 100644 index 75e3bdf08..000000000 --- a/tests/contrib/tasks/test_vertex_task_store.py +++ /dev/null @@ -1,532 +0,0 @@ -""" -Tests for the VertexTaskStore. - -These tests can be run with a real or fake Vertex AI Agent Engine as a backend. -The real ones are skipped by default unless the necessary environment variables\ -are set, which prevents them from failing in GitHub Actions. - -To run these tests locally, you can use the provided script: - ./run_vertex_tests.sh - -The following environment variables are required for the real backend: - VERTEX_PROJECT="your-project" \ - VERTEX_LOCATION="your-location" \ - VERTEX_BASE_URL="your-base-url" \ - VERTEX_API_VERSION="your-api-version" \ -""" - -import os - -from collections.abc import AsyncGenerator - -import pytest -import pytest_asyncio - -from .fake_vertex_client import FakeVertexClient - - -# Skip the entire test module if vertexai is not installed -pytest.importorskip( - 'vertexai', reason='Vertex Task Store tests require vertexai' -) -import vertexai - - -# Skip the real backend tests if required environment variables are not set -missing_env_vars = not all( - os.environ.get(var) - for var in [ - 'VERTEX_PROJECT', - 'VERTEX_LOCATION', - 'VERTEX_BASE_URL', - 'VERTEX_API_VERSION', - ] -) - - -@pytest.fixture( - scope='module', - params=[ - 'fake', - pytest.param( - 'real', - marks=pytest.mark.skipif( - missing_env_vars, - 
reason='Missing required environment variables for real Vertex Task Store.', - ), - ), - ], -) -def backend_type(request) -> str: - return request.param - - -from a2a.contrib.tasks.vertex_task_store import VertexTaskStore -from a2a.server.context import ServerCallContext -from a2a.types.a2a_pb2 import ( - Artifact, - Part, - Task, - TaskState, - TaskStatus, -) - - -# Minimal Task object for testing -MINIMAL_TASK_OBJ = Task( - id='task-abc', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), -) -MINIMAL_TASK_OBJ.metadata['test_key'] = 'test_value' - - -from collections.abc import Generator - - -@pytest.fixture(scope='module') -def agent_engine_resource_id(backend_type: str) -> Generator[str, None, None]: - """ - Module-scoped fixture that creates and deletes a single Agent Engine - for all the tests. For fake backend, it yields a mock resource. - """ - if backend_type == 'fake': - yield 'projects/mock-project/locations/mock-location/agentEngines/mock-engine' - return - - project = os.environ.get('VERTEX_PROJECT') - location = os.environ.get('VERTEX_LOCATION') - base_url = os.environ.get('VERTEX_BASE_URL') - - client = vertexai.Client(project=project, location=location) - client._api_client._http_options.base_url = base_url - - agent_engine = client.agent_engines.create() - yield agent_engine.api_resource.name - agent_engine.delete() - - -@pytest_asyncio.fixture -async def vertex_store( - backend_type: str, - agent_engine_resource_id: str, -) -> AsyncGenerator[VertexTaskStore, None]: - """ - Function-scoped fixture providing a fresh VertexTaskStore per test, - reusing the module-scoped engine. Uses fake client for 'fake' backend. 
- """ - if backend_type == 'fake': - client = FakeVertexClient() - else: - project = os.environ.get('VERTEX_PROJECT') - location = os.environ.get('VERTEX_LOCATION') - base_url = os.environ.get('VERTEX_BASE_URL') - api_version = os.environ.get('VERTEX_API_VERSION') - - client = vertexai.Client(project=project, location=location) - client._api_client._http_options.base_url = base_url - client._api_client._http_options.api_version = api_version - - store = VertexTaskStore( - client=client, # type: ignore - agent_engine_resource_id=agent_engine_resource_id, - ) - yield store - - -@pytest.mark.asyncio -async def test_save_task(vertex_store: VertexTaskStore) -> None: - """Test saving a task to the VertexTaskStore.""" - # Ensure unique ID for parameterized tests if needed, or rely on table isolation - task_to_save = Task() - task_to_save.CopyFrom(MINIMAL_TASK_OBJ) - task_to_save.id = 'save-test-task-2' - await vertex_store.save(task_to_save, ServerCallContext()) - - retrieved_task = await vertex_store.get( - task_to_save.id, ServerCallContext() - ) - assert retrieved_task is not None - assert retrieved_task.id == task_to_save.id - - assert retrieved_task == task_to_save - - -@pytest.mark.asyncio -async def test_get_task(vertex_store: VertexTaskStore) -> None: - """Test retrieving a task from the VertexTaskStore.""" - task_id = 'get-test-task-1' - task_to_save = Task() - task_to_save.CopyFrom(MINIMAL_TASK_OBJ) - task_to_save.id = task_id - await vertex_store.save(task_to_save, ServerCallContext()) - - retrieved_task = await vertex_store.get( - task_to_save.id, ServerCallContext() - ) - assert retrieved_task is not None - assert retrieved_task.id == task_to_save.id - assert retrieved_task.context_id == task_to_save.context_id - assert retrieved_task.status.state == TaskState.TASK_STATE_SUBMITTED - - -@pytest.mark.asyncio -async def test_get_nonexistent_task( - vertex_store: VertexTaskStore, -) -> None: - """Test retrieving a nonexistent task.""" - retrieved_task = await 
vertex_store.get( - 'nonexistent-task-id', ServerCallContext() - ) - assert retrieved_task is None - - -@pytest.mark.asyncio -async def test_save_and_get_detailed_task( - vertex_store: VertexTaskStore, -) -> None: - """Test saving and retrieving a task with more fields populated.""" - task_id = 'detailed-task-test-vertex' - test_task = Task( - id=task_id, - context_id='test-session-1', - status=TaskStatus( - state=TaskState.TASK_STATE_SUBMITTED, - ), - artifacts=[ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ) - ], - ) - test_task.metadata['key1'] = 'value1' - test_task.metadata['key2'] = 123 - - await vertex_store.save(test_task, ServerCallContext()) - retrieved_task = await vertex_store.get(test_task.id, ServerCallContext()) - - assert retrieved_task is not None - assert retrieved_task.id == test_task.id - assert retrieved_task.context_id == test_task.context_id - assert retrieved_task.status.state == TaskState.TASK_STATE_SUBMITTED - assert retrieved_task.metadata['key1'] == 'value1' - assert retrieved_task.metadata['key2'] == 123 - assert retrieved_task.artifacts == test_task.artifacts - - -@pytest.mark.asyncio -async def test_update_task_status_and_metadata( - vertex_store: VertexTaskStore, -) -> None: - """Test updating an existing task.""" - task_id = 'update-test-task-1' - original_task = Task( - id=task_id, - context_id='session-update', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - artifacts=[], - history=[], - ) - await vertex_store.save(original_task, ServerCallContext()) - - retrieved_before_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_before_update is not None - assert ( - retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED - ) - assert retrieved_before_update.metadata == {} - - updated_task = Task() - updated_task.CopyFrom(original_task) - updated_task.status.state = TaskState.TASK_STATE_COMPLETED - 
updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') - updated_task.metadata.update({'update_key': 'update_value'}) - - await vertex_store.save(updated_task, ServerCallContext()) - - retrieved_after_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_after_update is not None - assert retrieved_after_update.status.state == TaskState.TASK_STATE_COMPLETED - assert retrieved_after_update.metadata == {'update_key': 'update_value'} - - -@pytest.mark.asyncio -async def test_update_task_add_artifact(vertex_store: VertexTaskStore) -> None: - """Test updating an existing task by adding an artifact.""" - task_id = 'update-test-task-2' - original_task = Task( - id=task_id, - context_id='session-update', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - artifacts=[ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ) - ], - history=[], - ) - await vertex_store.save(original_task, ServerCallContext()) - - retrieved_before_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_before_update is not None - assert ( - retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED - ) - assert retrieved_before_update.metadata == {} - - updated_task = Task() - updated_task.CopyFrom(original_task) - updated_task.status.state = TaskState.TASK_STATE_WORKING - updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') - - updated_task.artifacts.append( - Artifact( - artifact_id='artifact-2', - parts=[Part(text='world')], - ) - ) - - await vertex_store.save(updated_task, ServerCallContext()) - - retrieved_after_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_after_update is not None - assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING - - assert retrieved_after_update.artifacts == [ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ), - Artifact( - artifact_id='artifact-2', - 
parts=[Part(text='world')], - ), - ] - - -@pytest.mark.asyncio -async def test_update_task_update_artifact( - vertex_store: VertexTaskStore, -) -> None: - """Test updating an existing task by changing an artifact.""" - task_id = 'update-test-task-3' - original_task = Task( - id=task_id, - context_id='session-update', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - artifacts=[ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ), - Artifact( - artifact_id='artifact-2', - parts=[Part(text='world')], - ), - ], - history=[], - ) - await vertex_store.save(original_task, ServerCallContext()) - - retrieved_before_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_before_update is not None - assert ( - retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED - ) - assert retrieved_before_update.metadata == {} - - updated_task = Task() - updated_task.CopyFrom(original_task) - updated_task.status.state = TaskState.TASK_STATE_WORKING - updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') - - updated_task.artifacts[0].parts[0].text = 'ahoy' - - await vertex_store.save(updated_task, ServerCallContext()) - - retrieved_after_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_after_update is not None - assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING - - assert retrieved_after_update.artifacts == [ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='ahoy')], - ), - Artifact( - artifact_id='artifact-2', - parts=[Part(text='world')], - ), - ] - - -@pytest.mark.asyncio -async def test_update_task_delete_artifact( - vertex_store: VertexTaskStore, -) -> None: - """Test updating an existing task by deleting an artifact.""" - task_id = 'update-test-task-4' - original_task = Task( - id=task_id, - context_id='session-update', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - artifacts=[ - Artifact( - 
artifact_id='artifact-1', - parts=[Part(text='hello')], - ), - Artifact( - artifact_id='artifact-2', - parts=[Part(text='world')], - ), - ], - history=[], - ) - await vertex_store.save(original_task, ServerCallContext()) - - retrieved_before_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_before_update is not None - assert ( - retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED - ) - assert retrieved_before_update.metadata == {} - - updated_task = Task() - updated_task.CopyFrom(original_task) - updated_task.status.state = TaskState.TASK_STATE_WORKING - updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') - - del updated_task.artifacts[1] - - await vertex_store.save(updated_task, ServerCallContext()) - - retrieved_after_update = await vertex_store.get( - task_id, ServerCallContext() - ) - assert retrieved_after_update is not None - assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING - - assert retrieved_after_update.artifacts == [ - Artifact( - artifact_id='artifact-1', - parts=[Part(text='hello')], - ) - ] - - -@pytest.mark.asyncio -async def test_metadata_field_mapping( - vertex_store: VertexTaskStore, -) -> None: - """Test that metadata field is correctly mapped between the core types and vertex. - - This test verifies: - 1. Metadata can be None - 2. Metadata can be a simple dict - 3. Metadata can contain nested structures - 4. Metadata is correctly saved and retrieved - 5. 
The mapping between task.metadata and task_metadata column works - """ - # Test 1: Task with no metadata (None) - task_no_metadata = Task( - id='task-metadata-test-1', - context_id='session-meta-1', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - await vertex_store.save(task_no_metadata, ServerCallContext()) - retrieved_no_metadata = await vertex_store.get( - 'task-metadata-test-1', ServerCallContext() - ) - assert retrieved_no_metadata is not None - assert retrieved_no_metadata.metadata == {} - - # Test 2: Task with simple metadata - simple_metadata = {'key': 'value', 'number': 42, 'boolean': True} - task_simple_metadata = Task( - id='task-metadata-test-2', - context_id='session-meta-2', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - metadata=simple_metadata, - ) - await vertex_store.save(task_simple_metadata, ServerCallContext()) - retrieved_simple = await vertex_store.get( - 'task-metadata-test-2', ServerCallContext() - ) - assert retrieved_simple is not None - assert retrieved_simple.metadata == simple_metadata - - # Test 3: Task with complex nested metadata - complex_metadata = { - 'level1': { - 'level2': { - 'level3': ['a', 'b', 'c'], - 'numeric': 3.14159, - }, - 'array': [1, 2, {'nested': 'value'}], - }, - 'special_chars': 'Hello\nWorld\t!', - 'unicode': '🚀 Unicode test 你好', - 'null_value': None, - } - task_complex_metadata = Task( - id='task-metadata-test-3', - context_id='session-meta-3', - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - metadata=complex_metadata, - ) - await vertex_store.save(task_complex_metadata, ServerCallContext()) - retrieved_complex = await vertex_store.get( - 'task-metadata-test-3', ServerCallContext() - ) - assert retrieved_complex is not None - assert retrieved_complex.metadata == complex_metadata - - # Test 4: Update metadata from None to dict - task_update_metadata = Task( - id='task-metadata-test-4', - context_id='session-meta-4', - 
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - await vertex_store.save(task_update_metadata, ServerCallContext()) - - # Update metadata - task_update_metadata.metadata.Clear() - task_update_metadata.metadata.update( - {'updated': True, 'timestamp': '2024-01-01'} - ) - await vertex_store.save(task_update_metadata, ServerCallContext()) - - retrieved_updated = await vertex_store.get( - 'task-metadata-test-4', ServerCallContext() - ) - assert retrieved_updated is not None - assert retrieved_updated.metadata == { - 'updated': True, - 'timestamp': '2024-01-01', - } - - # Test 5: Update metadata from dict to None - task_update_metadata.metadata.Clear() - await vertex_store.save(task_update_metadata, ServerCallContext()) - - retrieved_none = await vertex_store.get( - 'task-metadata-test-4', ServerCallContext() - ) - assert retrieved_none is not None - assert retrieved_none.metadata == {} diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index 94ccae03a..9bb3a02fa 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -1,14 +1,17 @@ import httpx from fastapi import FastAPI +from starlette.applications import Starlette +from starlette.requests import Request +from a2a.auth.user import UnauthenticatedUser, User from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue -from starlette.applications import Starlette -from a2a.server.routes.rest_routes import create_rest_routes -from a2a.server.routes import create_agent_card_routes from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.routes import create_agent_card_routes +from a2a.server.routes.common import DefaultServerCallContextBuilder +from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import ( BasePushNotificationSender, 
InMemoryPushNotificationConfigStore, @@ -24,12 +27,15 @@ Message, Task, ) -from a2a.utils import ( - new_agent_text_message, - new_task, +from a2a.helpers.proto_helpers import ( + new_text_message, + new_task_from_user_message, ) +_TEST_USER_HEADER = 'x-test-user' + + def test_agent_card(url: str) -> AgentCard: """Returns an agent card for the test agent.""" return AgentCard( @@ -74,7 +80,7 @@ async def invoke( or not msg.parts[0].HasField('text') ): await updater.failed( - new_agent_text_message( + new_text_message( 'Unsupported message.', task.context_id, task.id ) ) @@ -84,25 +90,23 @@ async def invoke( # Simple request-response flow. if text_message == 'Hello Agent!': await updater.complete( - new_agent_text_message('Hello User!', task.context_id, task.id) + new_text_message('Hello User!', task.context_id, task.id) ) # Flow with user input required: "How are you?" -> "Good! How are you?" -> "Good" -> "Amazing". elif text_message == 'How are you?': await updater.requires_input( - new_agent_text_message( - 'Good! How are you?', task.context_id, task.id - ) + new_text_message('Good! How are you?', task.context_id, task.id) ) elif text_message == 'Good': await updater.complete( - new_agent_text_message('Amazing', task.context_id, task.id) + new_text_message('Amazing', task.context_id, task.id) ) # Fail for unsupported messages. 
else: await updater.failed( - new_agent_text_message( + new_text_message( 'Unsupported message.', task.context_id, task.id ) ) @@ -124,7 +128,7 @@ async def execute( task = context.current_task if not task: - task = new_task(context.message) + task = new_task_from_user_message(context.message) await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) @@ -142,17 +146,95 @@ def create_agent_app( """Creates a new HTTP+REST Starlette application for the test agent.""" push_config_store = InMemoryPushNotificationConfigStore() card = test_agent_card(url) + extended_card = test_agent_card(url) + extended_card.name = 'Test Agent Extended' handler = DefaultRequestHandler( agent_executor=TestAgentExecutor(), task_store=InMemoryTaskStore(), + agent_card=card, + extended_agent_card=extended_card, push_config_store=push_config_store, push_sender=BasePushNotificationSender( httpx_client=notification_client, config_store=push_config_store, - context=ServerCallContext(), ), ) - rest_routes = create_rest_routes(agent_card=card, request_handler=handler) + rest_routes = create_rest_routes(request_handler=handler) + agent_card_routes = create_agent_card_routes( + agent_card=card, card_url='/.well-known/agent-card.json' + ) + return Starlette(routes=[*rest_routes, *agent_card_routes]) + + +class _NamedTestUser(User): + """Authenticated test user identified by ``user_name``.""" + + def __init__(self, user_name: str) -> None: + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +class _HeaderUserContextBuilder(DefaultServerCallContextBuilder): + """Builds a ServerCallContext whose user is read from a request header.""" + + def build_user(self, request: Request) -> User: + user_name = request.headers.get(_TEST_USER_HEADER) + if user_name: + return _NamedTestUser(user_name) + return UnauthenticatedUser() + + +def 
create_multi_user_agent_app( + url: str, notification_client: httpx.AsyncClient +) -> Starlette: + """Creates a multi-user variant of the test agent app. + + Differences from create_agent_app: + + - Identity is read from the x-test-user header on each request + via _HeaderUserContextBuilder. Multiple authenticated + users (e.g. alice, bob) can therefore call the same + server. + - The InMemoryTaskStore uses a constant owner resolver, so + every authenticated user has access to every task. + - The InMemoryPushNotificationConfigStore keeps the default + per-user owner resolver, so each registrar's configs live in their + own owner partition; this exercises cross-owner aggregation in + get_info_for_dispatch. + """ + # Shared task visibility: any authenticated user can see any task. + task_store = InMemoryTaskStore(owner_resolver=lambda _ctx: 'shared') + + # Per-user push-config partitioning (the default). + push_config_store = InMemoryPushNotificationConfigStore() + + card = test_agent_card(url) + extended_card = test_agent_card(url) + extended_card.name = 'Test Agent Extended' + + handler = DefaultRequestHandler( + agent_executor=TestAgentExecutor(), + task_store=task_store, + agent_card=card, + extended_agent_card=extended_card, + push_config_store=push_config_store, + push_sender=BasePushNotificationSender( + httpx_client=notification_client, + config_store=push_config_store, + ), + ) + + rest_routes = create_rest_routes( + request_handler=handler, + context_builder=_HeaderUserContextBuilder(), + ) agent_card_routes = create_agent_card_routes( agent_card=card, card_url='/.well-known/agent-card.json' ) diff --git a/tests/e2e/push_notifications/test_default_push_notification_support.py b/tests/e2e/push_notifications/test_default_push_notification_support.py index 053707d62..84fd14c9a 100644 --- a/tests/e2e/push_notifications/test_default_push_notification_support.py +++ b/tests/e2e/push_notifications/test_default_push_notification_support.py @@ -6,7 +6,7 @@ import 
pytest import pytest_asyncio -from .agent_app import create_agent_app +from .agent_app import create_agent_app, create_multi_user_agent_app from .notifications_app import Notification, create_notifications_app from .utils import ( create_app_process, @@ -21,9 +21,9 @@ ) from a2a.utils.constants import TransportProtocol from a2a.types.a2a_pb2 import ( + ListTaskPushNotificationConfigsRequest, Message, Part, - TaskPushNotificationConfig, Role, SendMessageConfiguration, SendMessageRequest, @@ -33,6 +33,9 @@ ) +_TEST_USER_HEADER = 'x-test-user' + + @pytest.fixture(scope='module') def notifications_server(): """ @@ -75,7 +78,43 @@ def agent_server(notifications_client: httpx.AsyncClient): ) process.start() try: - wait_for_server_ready(f'{url}/extendedAgentCard') + wait_for_server_ready( + f'{url}/extendedAgentCard', headers={'A2A-Version': '1.0'} + ) + except TimeoutError as e: + process.terminate() + raise e + + yield url + + process.terminate() + process.join() + + +@pytest.fixture(scope='module') +def multi_user_agent_server(notifications_client: httpx.AsyncClient): + """Starts the multi-user variant of the test agent server. + + This variant reads identity from an x-test-user request header + and uses a TaskStore whose owner resolver returns a constant, so + every authenticated user can see every task. It runs on its own + port alongside the single-user agent_server fixture; the + notifications_server is shared (notifications include the + task_id and per-config token, so collisions are avoided). 
+ """ + host = '127.0.0.1' + port = find_free_port() + url = f'http://{host}:{port}' + + process = create_app_process( + create_multi_user_agent_app(url, notifications_client), host, port + ) + process.start() + try: + wait_for_server_ready( + f'{url}/extendedAgentCard', + headers={'A2A-Version': '1.0', _TEST_USER_HEADER: 'health-check'}, + ) except TimeoutError as e: process.terminate() raise e @@ -107,13 +146,11 @@ async def test_notification_triggering_with_in_message_config_e2e( a2a_client = ClientFactory( ClientConfig( supported_protocol_bindings=[TransportProtocol.HTTP_JSON], - push_notification_configs=[ - TaskPushNotificationConfig( - id='in-message-config', - url=f'{notifications_server}/notifications', - token=token, - ) - ], + push_notification_config=TaskPushNotificationConfig( + id='in-message-config', + url=f'{notifications_server}/notifications', + token=token, + ), ) ).create(minimal_agent_card(agent_server, [TransportProtocol.HTTP_JSON])) @@ -238,6 +275,272 @@ async def test_notification_triggering_after_config_change_e2e( assert notifications[0].token == token +@pytest.mark.asyncio +async def test_multi_registrar_fan_out_e2e( + notifications_server: str, + agent_server: str, + http_client: httpx.AsyncClient, +): + """Two pushNotificationConfigs registered for the same task both fire end-to-end. + + Exercises the dispatch fan-out across multiple registered configs + over the real wire: each registered URL must receive a POST with + its own token in the X-A2A-Notification-Token header. + """ + # Configure an A2A client without a per-message push notification config + # (we'll register configs explicitly after the task is created). 
+ a2a_client = ClientFactory( + ClientConfig( + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ).create(minimal_agent_card(agent_server, [TransportProtocol.HTTP_JSON])) + + # Send an initial message that requires more input, so the task lingers + # long enough for us to register multiple push configs against it. + responses = [ + response + async for response in a2a_client.send_message( + SendMessageRequest( + message=Message( + message_id='multi-fanout-init', + parts=[Part(text='How are you?')], + role=Role.ROLE_USER, + ), + configuration=SendMessageConfiguration(), + ) + ) + ] + assert len(responses) == 1 + stream_response = responses[0] + assert stream_response.HasField('task') + task = stream_response.task + assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + # Register two distinct push configs for the same task. Both share the + # same registrar (this client), but use different config ids, URLs, and + # tokens. Both must fire when the next event is dispatched. + token_a = uuid.uuid4().hex + token_b = uuid.uuid4().hex + await a2a_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='registrar-a', + url=f'{notifications_server}/notifications', + token=token_a, + ) + ) + await a2a_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='registrar-b', + url=f'{notifications_server}/notifications', + token=token_b, + ) + ) + + # Sanity: no notifications have fired yet. + response = await http_client.get( + f'{notifications_server}/{task.id}/notifications' + ) + assert response.status_code == 200 + assert len(response.json().get('notifications', [])) == 0 + + # Send a follow-up message that completes the task and triggers + # dispatch. Both registered configs must receive a POST. 
+ responses = [ + response + async for response in a2a_client.send_message( + SendMessageRequest( + message=Message( + task_id=task.id, + message_id='multi-fanout-complete', + parts=[Part(text='Good')], + role=Role.ROLE_USER, + ), + configuration=SendMessageConfiguration(), + ) + ) + ] + assert len(responses) == 1 + + # Expect 2 notifications: one COMPLETED event, fanned out to 2 configs. + notifications = await wait_for_n_notifications( + http_client, + f'{notifications_server}/{task.id}/notifications', + n=2, + ) + + # Both tokens must appear exactly once. + received_tokens = sorted(n.token for n in notifications) + assert received_tokens == sorted([token_a, token_b]) + + # Both notifications must carry the same COMPLETED event payload. + for notification in notifications: + state = ( + notification.event.get('status_update', {}) + .get('status', {}) + .get('state') + ) + assert state == 'TASK_STATE_COMPLETED' + + +def _make_user_a2a_client(agent_server: str, user_name: str): + """Builds an A2A client that identifies as user_name on every request. + + Identity is conveyed via a default header on the underlying + httpx.AsyncClient; the multi-user agent app's context builder + reads that header to populate ServerCallContext.user. + """ + httpx_client = httpx.AsyncClient(headers={_TEST_USER_HEADER: user_name}) + return ClientFactory( + ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ).create( + minimal_agent_card(agent_server, [TransportProtocol.HTTP_JSON]) + ), httpx_client + + +@pytest.mark.asyncio +async def test_alice_and_bob_both_receive_notifications_on_shared_task_e2e( + notifications_server: str, + multi_user_agent_server: str, + http_client: httpx.AsyncClient, +): + """Alice registers a webhook; Bob registers a webhook; both fire end-to-end. + + 1. Alice creates a task (it lingers in INPUT_REQUIRED). + 2. Alice registers her own push config on the task. + 3. 
Bob (a different authenticated user, sharing access to the task) + registers his own push config on the same task. + 4. Bob (the dispatcher, *not* the registrar of Alice's webhook) + sends a follow-up message that completes the task. + 5. Both Alice's webhook and Bob's webhook receive a POST with their + own respective tokens. + + Regression guard for the design's central guarantee: subscriptions + fire on the registrar's behalf regardless of which user's action + triggered the event. A regression that re-introduced + dispatcher-context filtering on the dispatch path would drop one of + the two notifications. + """ + alice_client, alice_http = _make_user_a2a_client( + multi_user_agent_server, 'alice' + ) + bob_client, bob_http = _make_user_a2a_client(multi_user_agent_server, 'bob') + + try: + responses = [ + response + async for response in alice_client.send_message( + SendMessageRequest( + message=Message( + message_id='shared-task-init', + parts=[Part(text='How are you?')], + role=Role.ROLE_USER, + ), + ) + ) + ] + assert len(responses) == 1 + assert responses[0].HasField('task') + task = responses[0].task + assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + # 2. Alice registers her push config. + alice_token = uuid.uuid4().hex + await alice_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='alice-cfg', + url=f'{notifications_server}/notifications', + token=alice_token, + ) + ) + + # 3. Bob registers his push config on the same task. + bob_token = uuid.uuid4().hex + await bob_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='bob-cfg', + url=f'{notifications_server}/notifications', + token=bob_token, + ) + ) + + # Sanity: the per-user listing endpoints are owner-scoped -- + # Alice does not see Bob's config and vice-versa, even though + # both can see the underlying task. 
+ # + # The auto-registered empty config (see step 1 quirk note) lives + # in Alice's partition under ``id == task_id``, so Alice's + # listing contains ``{'alice-cfg', task.id}``; the key invariant + # is that neither listing contains the other user's id or + # token. + alice_configs = await alice_client.list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id=task.id) + ) + alice_ids = {c.id for c in alice_configs.configs} + assert 'alice-cfg' in alice_ids + assert 'bob-cfg' not in alice_ids + assert all(c.token != bob_token for c in alice_configs.configs) + + bob_configs = await bob_client.list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id=task.id) + ) + bob_ids = {c.id for c in bob_configs.configs} + assert 'bob-cfg' in bob_ids + assert 'alice-cfg' not in bob_ids + assert all(c.token != alice_token for c in bob_configs.configs) + + # Sanity: no notifications have fired yet. + response = await http_client.get( + f'{notifications_server}/{task.id}/notifications' + ) + assert response.status_code == 200 + assert len(response.json().get('notifications', [])) == 0 + + # 4. Bob sends the follow-up message that completes the task. + # Omit ``configuration`` for the same reason as step 1. + responses = [ + response + async for response in bob_client.send_message( + SendMessageRequest( + message=Message( + task_id=task.id, + message_id='shared-task-complete', + parts=[Part(text='Good')], + role=Role.ROLE_USER, + ), + ) + ) + ] + assert len(responses) == 1 + + # 5. Both Alice's and Bob's webhooks receive the COMPLETED event. 
+ notifications = await wait_for_n_notifications( + http_client, + f'{notifications_server}/{task.id}/notifications', + n=2, + ) + + received_tokens = sorted(n.token for n in notifications) + assert received_tokens == sorted([alice_token, bob_token]) + + for notification in notifications: + state = ( + notification.event.get('status_update', {}) + .get('status', {}) + .get('state') + ) + assert state == 'TASK_STATE_COMPLETED' + finally: + await alice_http.aclose() + await bob_http.aclose() + + async def wait_for_n_notifications( http_client: httpx.AsyncClient, url: str, diff --git a/tests/e2e/push_notifications/utils.py b/tests/e2e/push_notifications/utils.py index 2934ecc58..a7317f1b2 100644 --- a/tests/e2e/push_notifications/utils.py +++ b/tests/e2e/push_notifications/utils.py @@ -20,12 +20,14 @@ def run_server(app, host, port) -> None: uvicorn.run(app, host=host, port=port, log_level='warning') -def wait_for_server_ready(url: str, timeout: int = 10) -> None: +def wait_for_server_ready( + url: str, timeout: int = 10, headers: dict | None = None +) -> None: """Polls the provided URL endpoint until the server is up.""" start_time = time.time() while True: with contextlib.suppress(httpx.ConnectError): - with httpx.Client() as client: + with httpx.Client(headers=headers) as client: response = client.get(url) if response.status_code == 200: return diff --git a/tests/helpers/test_agent_card_display.py b/tests/helpers/test_agent_card_display.py new file mode 100644 index 000000000..e252a52fe --- /dev/null +++ b/tests/helpers/test_agent_card_display.py @@ -0,0 +1,194 @@ +"""Tests for display_agent_card utility.""" + +import pytest + +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + AgentInterface, + AgentProvider, + AgentSkill, +) +from a2a.helpers.agent_card import display_agent_card + + +@pytest.fixture +def full_agent_card() -> AgentCard: + return AgentCard( + name='Sample Agent', + description='A sample agent.', + version='1.0.0', + 
documentation_url='https://docs.example.com', + icon_url='https://example.com/icon.png', + provider=AgentProvider( + organization='Example Org', url='https://example.com' + ), + supported_interfaces=[ + AgentInterface( + url='http://localhost:9999/a2a/jsonrpc', + protocol_binding='JSONRPC', + protocol_version='1.0', + ), + AgentInterface( + url='http://localhost:9999/a2a/rest', + protocol_binding='HTTP+JSON', + protocol_version='1.0', + tenant='tenant-a', + ), + ], + capabilities=AgentCapabilities( + streaming=True, + push_notifications=False, + extended_agent_card=True, + ), + default_input_modes=['text'], + default_output_modes=['text', 'task-status'], + skills=[ + AgentSkill( + id='skill-1', + name='My Skill', + description='Does something useful.', + tags=['foo', 'bar'], + examples=['Do the thing', 'Another example'], + ), + AgentSkill( + id='skill-2', + name='Other Skill', + description='Does something else.', + tags=['baz'], + ), + ], + ) + + +class TestDisplayAgentCard: + def test_full_card_output( + self, full_agent_card: AgentCard, capsys: pytest.CaptureFixture[str] + ) -> None: + """Golden test: exact output for a fully-populated card.""" + display_agent_card(full_agent_card) + assert capsys.readouterr().out == ( + '====================================================\n' + ' AgentCard \n' + '====================================================\n' + '--- General ---\n' + 'Name : Sample Agent\n' + 'Description : A sample agent.\n' + 'Version : 1.0.0\n' + 'Docs URL : https://docs.example.com\n' + 'Icon URL : https://example.com/icon.png\n' + 'Provider : Example Org (https://example.com)\n' + '\n' + '--- Interfaces ---\n' + ' [0] http://localhost:9999/a2a/jsonrpc (JSONRPC 1.0)\n' + ' [1] http://localhost:9999/a2a/rest (HTTP+JSON 1.0, tenant=tenant-a)\n' + '\n' + '--- Capabilities ---\n' + 'Streaming : True\n' + 'Push notifications : False\n' + 'Extended agent card : True\n' + '\n' + '--- I/O Modes ---\n' + 'Input : text\n' + 'Output : text, task-status\n' + 
'\n' + '--- Skills ---\n' + '----------------------------------------------------\n' + ' ID : skill-1\n' + ' Name : My Skill\n' + ' Description : Does something useful.\n' + ' Tags : foo, bar\n' + ' Example : Do the thing\n' + ' Example : Another example\n' + '----------------------------------------------------\n' + ' ID : skill-2\n' + ' Name : Other Skill\n' + ' Description : Does something else.\n' + ' Tags : baz\n' + '====================================================\n' + ) + + def test_empty_card_output( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """Golden test: exact output for a card with only default/empty fields. + + An empty supported_interfaces section signals a malformed card — + the bare header with no entries is intentional and visible to the user. + """ + display_agent_card(AgentCard()) + assert capsys.readouterr().out == ( + '====================================================\n' + ' AgentCard \n' + '====================================================\n' + '--- General ---\n' + 'Name : \n' + 'Description : \n' + 'Version : \n' + '\n' + '--- Interfaces ---\n' + '\n' + '--- Capabilities ---\n' + 'Streaming : False\n' + 'Push notifications : False\n' + 'Extended agent card : False\n' + '\n' + '--- I/O Modes ---\n' + 'Input : (none)\n' + 'Output : (none)\n' + '\n' + '--- Skills ---\n' + ' (none)\n' + '====================================================\n' + ) + + def test_interface_without_protocol_version_has_no_trailing_space( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """No trailing space in the binding field when protocol_version is not set.""" + card = AgentCard( + supported_interfaces=[ + AgentInterface( + url='127.0.0.1:50051', + protocol_binding='GRPC', + ) + ] + ) + display_agent_card(card) + assert ' [0] 127.0.0.1:50051 (GRPC)' in capsys.readouterr().out + + def test_interface_without_binding_or_version_has_no_parentheses( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """No parentheses when 
neither protocol_binding nor protocol_version are set.""" + card = AgentCard( + supported_interfaces=[AgentInterface(url='127.0.0.1:50051')] + ) + display_agent_card(card) + assert ' [0] 127.0.0.1:50051\n' in capsys.readouterr().out + + def test_provider_with_url( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """Provider shows organization and URL in parentheses when both are set.""" + card = AgentCard( + provider=AgentProvider( + organization='Example Org', + url='https://example.com', + ), + ) + display_agent_card(card) + assert ( + 'Provider : Example Org (https://example.com)' + in capsys.readouterr().out + ) + + def test_provider_without_url_has_no_empty_parentheses( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """No empty parentheses when provider URL is not set.""" + card = AgentCard(provider=AgentProvider(organization='Example Org')) + display_agent_card(card) + out = capsys.readouterr().out + assert 'Provider : Example Org' in out + assert '()' not in out diff --git a/tests/helpers/test_proto_helpers.py b/tests/helpers/test_proto_helpers.py new file mode 100644 index 000000000..8fb68dbc2 --- /dev/null +++ b/tests/helpers/test_proto_helpers.py @@ -0,0 +1,445 @@ +"""Tests for proto helpers.""" + +import pytest + +from a2a.helpers.proto_helpers import ( + get_artifact_text, + get_message_text, + get_stream_response_text, + get_text_parts, + new_artifact, + new_data_artifact, + new_data_message, + new_data_part, + new_message, + new_raw_artifact, + new_raw_message, + new_raw_part, + new_task, + new_task_from_user_message, + new_text_artifact, + new_text_artifact_update_event, + new_text_message, + new_text_part, + new_text_status_update_event, + new_url_artifact, + new_url_message, + new_url_part, +) +from a2a.types.a2a_pb2 import ( + Artifact, + Message, + Part, + Role, + StreamResponse, + Task, + TaskState, +) + + +# --- Message Helpers Tests --- + + +def test_new_message() -> None: + parts = [Part(text='hello')] + msg = new_message( 
+ parts, context_id='ctx1', task_id='task1', role=Role.ROLE_USER + ) + assert msg.role == Role.ROLE_USER + assert msg.parts == parts + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_text_message() -> None: + msg = new_text_message( + 'hello', + media_type='text/plain', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].text == 'hello' + assert msg.parts[0].media_type == 'text/plain' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_data_message() -> None: + msg = new_data_message( + data={'key': 'value'}, + media_type='application/json', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].HasField('data') + assert msg.parts[0].data.struct_value.fields['key'].string_value == 'value' + assert msg.parts[0].media_type == 'application/json' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_raw_message() -> None: + msg = new_raw_message( + b'\x89PNG', + media_type='image/png', + filename='img.png', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].HasField('raw') + assert msg.parts[0].raw == b'\x89PNG' + assert msg.parts[0].media_type == 'image/png' + assert msg.parts[0].filename == 'img.png' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_url_message() -> None: + msg = new_url_message( + 'https://example.com/file.pdf', + media_type='application/pdf', + filename='file.pdf', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert 
msg.parts[0].HasField('url') + assert msg.parts[0].url == 'https://example.com/file.pdf' + assert msg.parts[0].media_type == 'application/pdf' + assert msg.parts[0].filename == 'file.pdf' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_get_message_text() -> None: + msg = Message(parts=[Part(text='hello'), Part(text='world')]) + assert get_message_text(msg) == 'hello\nworld' + assert get_message_text(msg, delimiter=' ') == 'hello world' + + +# --- Artifact Helpers Tests --- + + +def test_new_artifact() -> None: + parts = [Part(text='content')] + art = new_artifact(parts=parts, name='test', description='desc') + assert art.name == 'test' + assert art.description == 'desc' + assert art.parts == parts + assert art.artifact_id != '' + + +def test_new_text_artifact() -> None: + art = new_text_artifact(name='test', text='content', description='desc') + assert art.name == 'test' + assert art.description == 'desc' + assert len(art.parts) == 1 + assert art.parts[0].text == 'content' + assert art.artifact_id != '' + + +def test_new_text_artifact_with_id() -> None: + art = new_text_artifact( + name='test', text='content', description='desc', artifact_id='art1' + ) + assert art.name == 'test' + assert art.description == 'desc' + assert len(art.parts) == 1 + assert art.parts[0].text == 'content' + assert art.artifact_id == 'art1' + + +def test_new_data_artifact() -> None: + art = new_data_artifact( + name='result', data={'score': 1.0}, description='desc' + ) + assert art.name == 'result' + assert art.description == 'desc' + assert len(art.parts) == 1 + assert art.parts[0].HasField('data') + assert art.parts[0].data.struct_value.fields['score'].number_value == 1.0 + assert art.artifact_id != '' + + +def test_new_data_artifact_with_id() -> None: + art = new_data_artifact(name='result', data={'x': 'y'}, artifact_id='art1') + assert art.artifact_id == 'art1' + assert art.parts[0].data.struct_value.fields['x'].string_value 
== 'y' + + +def test_new_raw_artifact() -> None: + art = new_raw_artifact( + name='screenshot', + raw=b'\x89PNG', + media_type='image/png', + filename='screen.png', + description='desc', + artifact_id='art1', + ) + assert art.name == 'screenshot' + assert art.description == 'desc' + assert art.artifact_id == 'art1' + assert len(art.parts) == 1 + assert art.parts[0].HasField('raw') + assert art.parts[0].raw == b'\x89PNG' + assert art.parts[0].media_type == 'image/png' + assert art.parts[0].filename == 'screen.png' + + +def test_new_raw_artifact_minimal() -> None: + art = new_raw_artifact(name='file', raw=b'data') + assert art.parts[0].raw == b'data' + assert art.artifact_id != '' + + +def test_new_url_artifact() -> None: + art = new_url_artifact( + name='report', + url='https://example.com/report.pdf', + media_type='application/pdf', + filename='report.pdf', + description='desc', + artifact_id='art1', + ) + assert art.name == 'report' + assert art.description == 'desc' + assert art.artifact_id == 'art1' + assert len(art.parts) == 1 + assert art.parts[0].HasField('url') + assert art.parts[0].url == 'https://example.com/report.pdf' + assert art.parts[0].media_type == 'application/pdf' + assert art.parts[0].filename == 'report.pdf' + + +def test_new_url_artifact_minimal() -> None: + art = new_url_artifact(name='img', url='https://example.com/img.png') + assert art.parts[0].url == 'https://example.com/img.png' + assert art.artifact_id != '' + + +def test_get_artifact_text() -> None: + art = Artifact(parts=[Part(text='hello'), Part(text='world')]) + assert get_artifact_text(art) == 'hello\nworld' + assert get_artifact_text(art, delimiter=' ') == 'hello world' + + +# --- Task Helpers Tests --- + + +def test_new_task_from_user_message() -> None: + msg = Message( + role=Role.ROLE_USER, + parts=[Part(text='hello')], + task_id='task1', + context_id='ctx1', + ) + task = new_task_from_user_message(msg) + assert task.id == 'task1' + assert task.context_id == 'ctx1' + assert 
task.status.state == TaskState.TASK_STATE_SUBMITTED + assert len(task.history) == 1 + assert task.history[0] == msg + + +def test_new_task_from_user_message_empty_parts() -> None: + msg = Message(role=Role.ROLE_USER, parts=[]) + with pytest.raises(ValueError, match='Message parts cannot be empty'): + new_task_from_user_message(msg) + + +def test_new_task_from_user_message_empty_text() -> None: + msg = Message(role=Role.ROLE_USER, parts=[Part(text='')]) + with pytest.raises(ValueError, match='Message.text cannot be empty'): + new_task_from_user_message(msg) + + +def test_new_task() -> None: + task = new_task( + task_id='task1', context_id='ctx1', state=TaskState.TASK_STATE_WORKING + ) + assert task.id == 'task1' + assert task.context_id == 'ctx1' + assert task.status.state == TaskState.TASK_STATE_WORKING + assert len(task.history) == 0 + assert len(task.artifacts) == 0 + + +# --- Part Helpers Tests --- + + +def test_get_text_parts() -> None: + parts = [ + Part(text='hello'), + Part(url='http://example.com'), + Part(text='world'), + ] + assert get_text_parts(parts) == ['hello', 'world'] + + +def test_new_text_part() -> None: + part = new_text_part('hello') + assert part.HasField('text') + assert part.text == 'hello' + assert part.media_type == '' + + +def test_new_text_part_with_media_type() -> None: + part = new_text_part('# Hello', media_type='text/markdown') + assert part.HasField('text') + assert part.text == '# Hello' + assert part.media_type == 'text/markdown' + + +def test_new_data_part_from_dict() -> None: + part = new_data_part({'key': 'value', 'count': 42}) + assert part.HasField('data') + assert part.data.struct_value.fields['key'].string_value == 'value' + assert part.data.struct_value.fields['count'].number_value == 42 + assert part.media_type == '' + + +def test_new_data_part_with_media_type() -> None: + part = new_data_part({'key': 'value'}, media_type='application/json') + assert part.HasField('data') + assert part.media_type == 'application/json' + + 
+def test_new_data_part_from_list() -> None: + part = new_data_part([1, 2, 3]) + assert part.HasField('data') + assert part.data.list_value.values[0].number_value == 1 + assert part.data.list_value.values[1].number_value == 2 + assert part.data.list_value.values[2].number_value == 3 + + +def test_new_raw_part() -> None: + part = new_raw_part(b'\x89PNG', media_type='image/png', filename='img.png') + assert part.HasField('raw') + assert part.raw == b'\x89PNG' + assert part.media_type == 'image/png' + assert part.filename == 'img.png' + + +def test_new_raw_part_minimal() -> None: + part = new_raw_part(b'data') + assert part.HasField('raw') + assert part.raw == b'data' + assert part.media_type == '' + assert part.filename == '' + + +def test_new_url_part() -> None: + part = new_url_part( + 'https://example.com/file.pdf', + media_type='application/pdf', + filename='file.pdf', + ) + assert part.HasField('url') + assert part.url == 'https://example.com/file.pdf' + assert part.media_type == 'application/pdf' + assert part.filename == 'file.pdf' + + +def test_new_url_part_minimal() -> None: + part = new_url_part('https://example.com/img.png') + assert part.HasField('url') + assert part.url == 'https://example.com/img.png' + assert part.media_type == '' + assert part.filename == '' + + +# --- Event & Stream Helpers Tests --- + + +def test_new_text_status_update_event() -> None: + event = new_text_status_update_event( + task_id='task1', + context_id='ctx1', + state=TaskState.TASK_STATE_WORKING, + text='progress', + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.status.state == TaskState.TASK_STATE_WORKING + assert event.status.message.parts[0].text == 'progress' + + +def test_new_text_artifact_update_event() -> None: + event = new_text_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='test', + text='content', + append=True, + last_chunk=True, + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' 
+ assert event.artifact.name == 'test' + assert event.artifact.parts[0].text == 'content' + assert event.append is True + assert event.last_chunk is True + + +def test_new_text_artifact_update_event_with_id() -> None: + event = new_text_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='test', + text='content', + artifact_id='art1', + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.artifact.name == 'test' + assert event.artifact.parts[0].text == 'content' + assert event.artifact.artifact_id == 'art1' + + +def test_get_stream_response_text_message() -> None: + resp = StreamResponse(message=Message(parts=[Part(text='hello')])) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_task() -> None: + resp = StreamResponse( + task=Task(artifacts=[Artifact(parts=[Part(text='hello')])]) + ) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_status_update() -> None: + resp = StreamResponse( + status_update=new_text_status_update_event( + 't', 'c', TaskState.TASK_STATE_WORKING, 'hello' + ) + ) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_artifact_update() -> None: + resp = StreamResponse( + artifact_update=new_text_artifact_update_event('t', 'c', 'n', 'hello') + ) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_empty() -> None: + resp = StreamResponse() + assert get_stream_response_text(resp) == '' diff --git a/tests/integration/cross_version/client_server/client_1_0.py b/tests/integration/cross_version/client_server/client_1_0.py index 5a5e192cf..6630bddad 100644 --- a/tests/integration/cross_version/client_server/client_1_0.py +++ b/tests/integration/cross_version/client_server/client_1_0.py @@ -5,7 +5,7 @@ import sys from uuid import uuid4 -from a2a.client import ClientFactory, ClientConfig +from a2a.client import ClientConfig, create_client from 
a2a.utils import TransportProtocol from a2a.types import ( Message, @@ -80,7 +80,7 @@ async def test_send_message_sync(url, protocol_enum): config.supported_protocol_bindings = [protocol_enum] config.streaming = False - client = await ClientFactory.connect(url, client_config=config) + client = await create_client(url, client_config=config) msg = Message( role=Role.ROLE_USER, message_id=f'sync-{uuid4()}', @@ -296,7 +296,7 @@ async def run_client(url: str, protocol: str): config.supported_protocol_bindings = [protocol_enum] config.streaming = True - client = await ClientFactory.connect(url, client_config=config) + client = await create_client(url, client_config=config) # 1. Get Extended Agent Card server_name = await test_get_extended_agent_card(client) diff --git a/tests/integration/cross_version/client_server/server_0_3.py b/tests/integration/cross_version/client_server/server_0_3.py index 7bd5f7e75..875cbb1ca 100644 --- a/tests/integration/cross_version/client_server/server_0_3.py +++ b/tests/integration/cross_version/client_server/server_0_3.py @@ -38,7 +38,7 @@ from starlette.requests import Request from starlette.concurrency import iterate_in_threadpool import time - +from a2a.utils.task import new_task from server_common import CustomLoggingMiddleware @@ -48,12 +48,18 @@ def __init__(self): async def execute(self, context: RequestContext, event_queue: EventQueue): print(f'SERVER: execute called for task {context.task_id}') + + task = new_task(context.message) + task.id = context.task_id + task.context_id = context.context_id + task.status.state = TaskState.working + await event_queue.enqueue_event(task) + task_updater = TaskUpdater( event_queue, context.task_id, context.context_id, ) - await task_updater.update_status(TaskState.submitted) await task_updater.update_status(TaskState.working) text = '' diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index 74e0bc23b..06f7e5e97 
100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -28,6 +28,7 @@ from a2a.utils import TransportProtocol from server_common import CustomLoggingMiddleware from google.protobuf.struct_pb2 import Struct, Value +from a2a.helpers.proto_helpers import new_task_from_user_message class MockAgentExecutor(AgentExecutor): @@ -36,12 +37,17 @@ def __init__(self): async def execute(self, context: RequestContext, event_queue: EventQueue): print(f'SERVER: execute called for task {context.task_id}') + task = new_task_from_user_message(context.message) + task.id = context.task_id + task.context_id = context.context_id + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + task_updater = TaskUpdater( event_queue, context.task_id, context.context_id, ) - await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED) await task_updater.update_status(TaskState.TASK_STATE_WORKING) text = '' @@ -158,10 +164,12 @@ async def main_async(http_port: int, grpc_port: int): task_store = InMemoryTaskStore() handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), - task_store=task_store, + MockAgentExecutor(), + task_store, + agent_card, queue_manager=InMemoryQueueManager(), push_config_store=InMemoryPushNotificationConfigStore(), + extended_agent_card=agent_card, ) app = FastAPI() @@ -171,9 +179,7 @@ async def main_async(http_port: int, grpc_port: int): agent_card=agent_card, card_url='/.well-known/agent-card.json' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, rpc_url='/', enable_v0_3_compat=True, ) @@ -183,7 +189,6 @@ async def main_async(http_port: int, grpc_port: int): ) rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=handler, enable_v0_3_compat=True, ) @@ -194,10 +199,10 @@ async def main_async(http_port: int, grpc_port: int): # Start gRPC 
Server server = grpc.aio.server() - servicer = GrpcHandler(agent_card, handler) + servicer = GrpcHandler(handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) - compat_servicer = CompatGrpcHandler(agent_card, handler) + compat_servicer = CompatGrpcHandler(handler) a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(compat_servicer, server) server.add_insecure_port(f'127.0.0.1:{grpc_port}') diff --git a/tests/integration/cross_version/test_cross_version_card_validation.py b/tests/integration/cross_version/test_cross_version_card_validation.py index 85379c3a3..25972b075 100644 --- a/tests/integration/cross_version/test_cross_version_card_validation.py +++ b/tests/integration/cross_version/test_cross_version_card_validation.py @@ -18,7 +18,7 @@ SecurityScheme, StringList, ) -from a2a.client.helpers import parse_agent_card +from a2a.client.card_resolver import parse_agent_card from google.protobuf.json_format import MessageToDict, ParseDict diff --git a/tests/integration/test_agent_card.py b/tests/integration/test_agent_card.py index 494fd151c..afa1078f0 100644 --- a/tests/integration/test_agent_card.py +++ b/tests/integration/test_agent_card.py @@ -66,6 +66,7 @@ async def test_agent_card_integration(header_val: str | None) -> None: handler = DefaultRequestHandler( agent_executor=DummyAgentExecutor(), task_store=task_store, + agent_card=agent_card, queue_manager=InMemoryQueueManager(), push_config_store=InMemoryPushNotificationConfigStore(), ) @@ -76,9 +77,7 @@ async def test_agent_card_integration(header_val: str | None) -> None: *create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ), - *create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, rpc_url='/' - ), + *create_jsonrpc_routes(request_handler=handler, rpc_url='/'), ] jsonrpc_app = Starlette(routes=jsonrpc_routes) app.mount('/jsonrpc', jsonrpc_app) @@ -87,7 +86,7 @@ async def test_agent_card_integration(header_val: str | None) -> None: 
*create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ), - *create_rest_routes(agent_card=agent_card, request_handler=handler), + *create_rest_routes(request_handler=handler), ] rest_app = Starlette(routes=rest_routes) app.mount('/rest', rest_app) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index e00b53c02..1711ac810 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,4 +1,5 @@ import asyncio + from collections.abc import AsyncGenerator from typing import Any, NamedTuple from unittest.mock import ANY, AsyncMock, patch @@ -7,9 +8,11 @@ import httpx import pytest import pytest_asyncio + from cryptography.hazmat.primitives.asymmetric import ec from google.protobuf.json_format import MessageToDict from google.protobuf.timestamp_pb2 import Timestamp +from starlette.applications import Starlette from a2a.client import Client, ClientConfig from a2a.client.base_client import BaseClient @@ -21,17 +24,19 @@ with_a2a_extensions, ) from a2a.client.transports import JsonRpcTransport, RestTransport -from starlette.applications import Starlette # Compat v0.3 imports for dedicated tests -from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc +from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler +from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.server.routes import ( create_agent_card_routes, create_jsonrpc_routes, create_rest_routes, ) -from a2a.server.request_handlers import GrpcHandler, RequestHandler +from a2a.server.request_handlers.default_request_handler import ( + LegacyRequestHandler, +) from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -66,11 +71,7 @@ ContentTypeNotSupportedError, ExtendedAgentCardNotConfiguredError, ExtensionSupportRequiredError, - InternalError, 
InvalidAgentResponseError, - InvalidParamsError, - InvalidRequestError, - MethodNotFoundError, PushNotificationNotSupportedError, TaskNotCancelableError, TaskNotFoundError, @@ -82,6 +83,7 @@ create_signature_verifier, ) + # --- Test Constants --- TASK_FROM_STREAM = Task( @@ -142,11 +144,12 @@ def key_provider(kid: str | None, jku: str | None): @pytest.fixture -def mock_request_handler() -> AsyncMock: +def mock_request_handler(agent_card) -> AsyncMock: """Provides a mock RequestHandler for the server-side handlers.""" handler = AsyncMock(spec=RequestHandler) # Configure on_message_send for non-streaming calls + handler._agent_card = agent_card handler.on_message_send.return_value = TASK_FROM_BLOCKING # Configure on_message_send_stream for streaming calls @@ -168,6 +171,14 @@ async def stream_side_effect(*args, **kwargs): ) handler.on_delete_task_push_notification_config.return_value = None + # Use async def to ensure it returns an awaitable + async def get_extended_agent_card_mock(*args, **kwargs): + return agent_card + + handler.on_get_extended_agent_card.side_effect = ( + get_extended_agent_card_mock # type: ignore[union-attr] + ) + async def resubscribe_side_effect(*args, **kwargs): yield RESUBSCRIBE_EVENT @@ -220,7 +231,7 @@ def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): """A base fixture to patch the sse-starlette event loop issue.""" from sse_starlette import sse - sse.AppStatus.should_exit_event = asyncio.Event() # type: ignore[attr-defined] + sse.AppStatus.should_exit_event = asyncio.Event() yield mock_request_handler, agent_card @@ -232,10 +243,7 @@ def jsonrpc_setup(http_base_setup) -> TransportSetup: agent_card=agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_request_handler, - extended_agent_card=agent_card, - rpc_url='/', + request_handler=mock_request_handler, rpc_url='/' ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = 
httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) @@ -253,9 +261,7 @@ def jsonrpc_setup(http_base_setup) -> TransportSetup: def rest_setup(http_base_setup) -> TransportSetup: """Sets up the RestTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - rest_routes = create_rest_routes( - agent_card, mock_request_handler, extended_agent_card=agent_card - ) + rest_routes = create_rest_routes(mock_request_handler) agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/' ) @@ -344,10 +350,13 @@ async def grpc_server_and_handler( server = grpc.aio.server() port = server.add_insecure_port('[::]:0') server_address = f'localhost:{port}' - servicer = GrpcHandler(agent_card, mock_request_handler) + servicer = GrpcHandler(request_handler=mock_request_handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) await server.start() - yield server_address, mock_request_handler + try: + yield server_address, mock_request_handler + finally: + await server.stop(None) @pytest_asyncio.fixture @@ -358,7 +367,9 @@ async def grpc_03_server_and_handler( server = grpc.aio.server() port = server.add_insecure_port('[::]:0') server_address = f'localhost:{port}' - servicer = CompatGrpcHandler(agent_card, mock_request_handler) + servicer = CompatGrpcHandler( + request_handler=mock_request_handler, + ) a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) await server.start() try: @@ -664,9 +675,9 @@ async def test_json_transport_base_client_send_message_with_extensions( call_args[1] if len(call_args) > 1 else call_kwargs.get('context') ) service_params = getattr(called_context, 'service_parameters', {}) - assert 'X-A2A-Extensions' in service_params + assert 'A2A-Extensions' in service_params assert ( - service_params['X-A2A-Extensions'] + service_params['A2A-Extensions'] == 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' ) @@ -698,14 +709,14 @@ async def 
test_json_transport_get_signed_base_card( }, ) + async def async_signer(card: AgentCard) -> AgentCard: + return signer(card) + agent_card_routes = create_agent_card_routes( - agent_card=agent_card, card_url='/', card_modifier=signer + agent_card=agent_card, card_url='/', card_modifier=async_signer ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_request_handler, - extended_agent_card=agent_card, - rpc_url='/', + request_handler=mock_request_handler, rpc_url='/' ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient( @@ -762,7 +773,7 @@ async def test_client_get_signed_extended_card( private_key = ec.generate_private_key(ec.SECP256R1()) public_key = private_key.public_key() signer = create_agent_card_signer( - signing_key=private_key, # type: ignore[arg-type] + signing_key=private_key, protected_header={ 'alg': 'ES256', 'kid': 'testkey', @@ -771,15 +782,18 @@ async def test_client_get_signed_extended_card( }, ) + async def get_extended_agent_card_mock_2(*args, **kwargs) -> AgentCard: + return signer(extended_agent_card) + + mock_request_handler.on_get_extended_agent_card.side_effect = ( + get_extended_agent_card_mock_2 # type: ignore[union-attr] + ) + agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_request_handler, - extended_agent_card=extended_agent_card, - extended_card_modifier=lambda card, ctx: signer(card), - rpc_url='/', + request_handler=mock_request_handler, rpc_url='/' ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient( @@ -835,7 +849,7 @@ async def test_client_get_signed_base_and_extended_cards( private_key = ec.generate_private_key(ec.SECP256R1()) public_key = private_key.public_key() signer = create_agent_card_signer( - signing_key=private_key, # type: ignore[arg-type] + signing_key=private_key, 
protected_header={ 'alg': 'ES256', 'kid': 'testkey', @@ -843,16 +857,24 @@ async def test_client_get_signed_base_and_extended_cards( 'typ': 'JOSE', }, ) + signer(extended_agent_card) + + # Use async def to ensure it returns an awaitable + async def get_extended_agent_card_mock_3(*args, **kwargs): + return extended_agent_card + + mock_request_handler.on_get_extended_agent_card.side_effect = ( + get_extended_agent_card_mock_3 # type: ignore[union-attr] + ) + + async def async_signer(card: AgentCard) -> AgentCard: + return signer(card) agent_card_routes = create_agent_card_routes( - agent_card=agent_card, card_url='/', card_modifier=signer + agent_card=agent_card, card_url='/', card_modifier=async_signer ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_request_handler, - extended_agent_card=extended_agent_card, - extended_card_modifier=lambda card, ctx: signer(card), - rpc_url='/', + request_handler=mock_request_handler, rpc_url='/' ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient( @@ -1004,6 +1026,71 @@ async def mock_generator(*args, **kwargs): await client.close() +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'error_cls,handler_attr,client_method,request_params', + [ + pytest.param( + UnsupportedOperationError, + 'on_subscribe_to_task', + 'subscribe', + SubscribeToTaskRequest(id='some-id'), + id='subscribe', + ), + ], +) +async def test_server_rejects_stream_on_validation_error( + transport_setups, error_cls, handler_attr, client_method, request_params +) -> None: + """Verify that the server returns an error directly and doesn't open a stream on validation error.""" + client = transport_setups.client + handler = transport_setups.handler + + async def mock_generator(*args, **kwargs): + raise error_cls('Validation failed') + yield + + getattr(handler, handler_attr).side_effect = mock_generator + + transport = client._transport + + if isinstance(transport, (RestTransport, 
JsonRpcTransport)): + # Spy on httpx client to check response headers + original_send = transport.httpx_client.send + response_headers = {} + + async def mock_send(*args, **kwargs): + resp = await original_send(*args, **kwargs) + response_headers['Content-Type'] = resp.headers.get('Content-Type') + return resp + + transport.httpx_client.send = mock_send + + try: + with pytest.raises(error_cls): + async for _ in getattr(client, client_method)( + request=request_params + ): + pass + finally: + transport.httpx_client.send = original_send + + # Verify that the response content type was NOT text/event-stream + assert not response_headers.get('Content-Type', '').startswith( + 'text/event-stream' + ) + else: + # For gRPC, we just verify it raises the error + with pytest.raises(error_cls): + async for _ in getattr(client, client_method)( + request=request_params + ): + pass + + getattr(handler, handler_attr).side_effect = None + await client.close() + + @pytest.mark.asyncio @pytest.mark.parametrize( 'request_kwargs, expected_error_code', @@ -1101,7 +1188,7 @@ async def test_validate_version_unsupported(http_transport_setups) -> None: params = GetTaskRequest(id=GET_TASK_RESPONSE.id) - with pytest.raises(VersionNotSupportedError) as exc_info: + with pytest.raises(VersionNotSupportedError): await client.get_task(request=params, context=context) await client.close() @@ -1114,11 +1201,21 @@ async def test_validate_decorator_push_notifications_disabled( """Integration test for @validate decorator with push notifications disabled.""" client = error_handling_setups.client - agent_card.capabilities.push_notifications = False + real_handler = LegacyRequestHandler( + agent_executor=AsyncMock(), + task_store=AsyncMock(), + agent_card=agent_card, + ) + + error_handling_setups.handler.on_create_task_push_notification_config.side_effect = real_handler.on_create_task_push_notification_config - params = TaskPushNotificationConfig(task_id='123') + params = TaskPushNotificationConfig( + 
task_id='123', + id='pnc-123', + url='http://example.com', + ) - with pytest.raises(UnsupportedOperationError) as exc_info: + with pytest.raises(PushNotificationNotSupportedError): await client.create_task_push_notification_config(request=params) await client.close() @@ -1134,14 +1231,107 @@ async def test_validate_streaming_disabled( agent_card.capabilities.streaming = False + real_handler = LegacyRequestHandler( + agent_executor=AsyncMock(), + task_store=AsyncMock(), + agent_card=agent_card, + ) + + error_handling_setups.handler.on_message_send_stream.side_effect = ( + real_handler.on_message_send_stream + ) + error_handling_setups.handler.on_subscribe_to_task.side_effect = ( + real_handler.on_subscribe_to_task + ) + params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, parts=[Part(text='hi')]) + message=Message( + role=Role.ROLE_USER, + parts=[Part(text='hi')], + message_id='msg-123', + ) ) stream = transport.send_message_streaming(request=params) - with pytest.raises(UnsupportedOperationError) as exc_info: + with pytest.raises(UnsupportedOperationError): async for _ in stream: pass await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'error_cls', + [ + TaskNotFoundError, + TaskNotCancelableError, + PushNotificationNotSupportedError, + UnsupportedOperationError, + ContentTypeNotSupportedError, + InvalidAgentResponseError, + ExtendedAgentCardNotConfiguredError, + ExtensionSupportRequiredError, + VersionNotSupportedError, + ], +) +@pytest.mark.parametrize( + 'handler_attr, client_method, request_params', + [ + pytest.param( + 'on_message_send_stream', + 'send_message', + SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg-midstream-test', + parts=[Part(text='Hello, mid-stream test!')], + ) + ), + id='stream', + ), + pytest.param( + 'on_subscribe_to_task', + 'subscribe', + SubscribeToTaskRequest(id='some-id'), + id='subscribe', + ), + ], +) +async def test_client_handles_mid_stream_a2a_errors( + 
transport_setups, + error_cls, + handler_attr, + client_method, + request_params, +) -> None: + """Integration test for mid-stream errors sent as SSE error events. + + The handler yields one event successfully, then raises an A2AError. + The client must receive the first event and then get the error as the + exact error_cls exception. This mirrors test_client_handles_a2a_errors_streaming + but verifies the error occurs *after* the stream has started producing events. + """ + client = transport_setups.client + handler = transport_setups.handler + + async def mock_generator(*args, **kwargs): + yield TASK_FROM_STREAM + raise error_cls('Mid-stream error') + + getattr(handler, handler_attr).side_effect = mock_generator + + received_events = [] + with pytest.raises(error_cls) as exc_info: + async for event in getattr(client, client_method)( + request=request_params + ): + received_events.append(event) # noqa: PERF401 + + assert 'Mid-stream error' in str(exc_info.value) + assert len(received_events) == 1 + + getattr(handler, handler_attr).side_effect = None + + await client.close() diff --git a/tests/integration/test_copying_observability.py b/tests/integration/test_copying_observability.py index a207c9b24..bc23b4696 100644 --- a/tests/integration/test_copying_observability.py +++ b/tests/integration/test_copying_observability.py @@ -25,6 +25,7 @@ SendMessageRequest, TaskState, ) +from a2a.helpers.proto_helpers import new_task_from_user_message from a2a.utils import TransportProtocol @@ -42,6 +43,12 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): if user_input == 'Init task': # Explicitly save status change to ensure task exists with some state + task = new_task_from_user_message(context.message) + task.id = context.task_id + task.context_id = context.context_id + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + await task_updater.update_status( TaskState.TASK_STATE_WORKING, 
message=task_updater.new_agent_message( @@ -94,15 +101,15 @@ def setup_client(agent_card: AgentCard, use_copying: bool) -> ClientSetup: handler = DefaultRequestHandler( agent_executor=MockMutatingAgentExecutor(), task_store=task_store, + agent_card=agent_card, queue_manager=InMemoryQueueManager(), + extended_agent_card=agent_card, ) agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, rpc_url='/', ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) @@ -153,6 +160,7 @@ async def test_mutation_observability(agent_card: AgentCard, use_copying: bool): ] event = events[-1] + assert event.HasField('status_update') task_id = event.status_update.task_id # 2. Second message to mutate it @@ -162,7 +170,6 @@ async def test_mutation_observability(agent_card: AgentCard, use_copying: bool): task_id=task_id, parts=[Part(text='Update task without saving it')], ) - _ = [ event async for event in client.send_message( diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 4987acdb5..dcd016b48 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -6,21 +6,27 @@ import pytest import pytest_asyncio +from starlette.applications import Starlette + from a2a.client.base_client import BaseClient -from a2a.client.client import ClientConfig +from a2a.client.client import ClientCallContext, ClientConfig from a2a.client.client_factory import ClientFactory +from a2a.client.service_parameters import ( + ServiceParametersFactory, + with_a2a_extensions, +) from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.routes.rest_routes import create_rest_routes -from starlette.applications import Starlette -from a2a.server.routes import create_jsonrpc_routes, create_agent_card_routes from a2a.server.events import EventQueue from 
a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes +from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import TaskUpdater from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import ( AgentCapabilities, AgentCard, + AgentExtension, AgentInterface, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, @@ -38,9 +44,16 @@ a2a_pb2_grpc, ) from a2a.utils import TransportProtocol +from a2a.helpers.proto_helpers import new_task_from_user_message from a2a.utils.errors import InvalidParamsError +SUPPORTED_EXTENSION_URIS = [ + 'https://example.com/ext/v1', + 'https://example.com/ext/v2', +] + + def assert_message_matches(message, expected_role, expected_text): assert message.role == expected_role assert message.parts[0].text == expected_text @@ -69,7 +82,9 @@ def assert_events_match(events, expected_events): events, expected_events, strict=True ): assert event.HasField(expected_type) - if expected_type == 'status_update': + if expected_type == 'task': + assert event.task.status.state == expected_val + elif expected_type == 'status_update': assert event.status_update.status.state == expected_val elif expected_type == 'artifact_update': if expected_val is not None: @@ -83,26 +98,43 @@ def assert_events_match(events, expected_events): class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): - task_updater = TaskUpdater( - event_queue, - context.task_id, - context.context_id, - ) user_input = context.get_user_input() - is_input_required_resumption = ( - context.current_task is not None - and context.current_task.status.state - == TaskState.TASK_STATE_INPUT_REQUIRED - ) - - if not is_input_required_resumption: - await task_updater.update_status( - TaskState.TASK_STATE_SUBMITTED, - 
message=task_updater.new_agent_message( - [Part(text='task submitted')] - ), + # Extensions echo: report the requested extensions back to the client + # via the Message.extensions field. + if user_input.startswith('Extensions:'): + await event_queue.enqueue_event( + Message( + role=Role.ROLE_AGENT, + message_id='ext-reply-1', + parts=[Part(text='extensions echoed')], + extensions=sorted(context.requested_extensions), + ) ) + return + + # Direct message response (no task created). + if user_input.startswith('Message:'): + await event_queue.enqueue_event( + Message( + role=Role.ROLE_AGENT, + message_id='direct-reply-1', + parts=[Part(text=f'Direct reply to: {user_input}')], + ) + ) + return + + # Task-based response. + task = context.current_task + if not task: + task = new_task_from_user_message(context.message) + await event_queue.enqueue_event(task) + + task_updater = TaskUpdater( + event_queue, + task.id, + task.context_id, + ) await task_updater.update_status( TaskState.TASK_STATE_WORKING, @@ -136,7 +168,15 @@ def agent_card() -> AgentCard: description='Real in-memory integration testing.', version='1.0.0', capabilities=AgentCapabilities( - streaming=True, push_notifications=False + streaming=True, + push_notifications=False, + extensions=[ + AgentExtension( + uri=uri, + description=f'Test extension {uri}', + ) + for uri in SUPPORTED_EXTENSION_URIS + ], ), skills=[], default_input_modes=['text/plain'], @@ -166,11 +206,12 @@ class ClientSetup(NamedTuple): @pytest.fixture -def base_e2e_setup(): +def base_e2e_setup(agent_card): task_store = InMemoryTaskStore() handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=task_store, + agent_card=agent_card, queue_manager=InMemoryQueueManager(), ) return task_store, handler @@ -179,9 +220,7 @@ def base_e2e_setup(): @pytest.fixture def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - rest_routes = create_rest_routes( - agent_card=agent_card, 
request_handler=handler - ) + rest_routes = create_rest_routes(request_handler=handler) agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/' ) @@ -209,9 +248,7 @@ def jsonrpc_setup(agent_card, base_e2e_setup) -> ClientSetup: agent_card=agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, rpc_url='/', ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) @@ -250,8 +287,8 @@ async def grpc_setup( break else: raise ValueError('No gRPC interface found in agent card') - - servicer = GrpcHandler(grpc_agent_card, handler) + handler._agent_card = grpc_agent_card + servicer = GrpcHandler(handler) a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) await server.start() @@ -331,7 +368,6 @@ async def test_end_to_end_send_message_blocking(transport_setups): response.task.history, [ (Role.ROLE_USER, 'Run dummy agent!'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -389,20 +425,19 @@ async def test_end_to_end_send_message_streaming(transport_setups): assert_events_match( events, [ - ('status_update', TaskState.TASK_STATE_SUBMITTED), + ('task', TaskState.TASK_STATE_SUBMITTED), ('status_update', TaskState.TASK_STATE_WORKING), ('artifact_update', [('test-artifact', 'artifact content')]), ('status_update', TaskState.TASK_STATE_COMPLETED), ], ) - task_id = events[0].status_update.task_id + task_id = events[0].task.id task = await client.get_task(request=GetTaskRequest(id=task_id)) assert_history_matches( task.history, [ (Role.ROLE_USER, 'Run dummy agent!'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -426,7 +461,7 @@ async def test_end_to_end_get_task(transport_setups): ) ] response = events[0] - task_id = response.status_update.task_id + task_id = response.task.id get_request = GetTaskRequest(id=task_id) retrieved_task = await client.get_task(request=get_request) 
@@ -441,7 +476,6 @@ async def test_end_to_end_get_task(transport_setups): retrieved_task.history, [ (Role.ROLE_USER, 'Test Get Task'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -468,7 +502,7 @@ async def test_end_to_end_list_tasks(transport_setups): ) ) ) - expected_task_ids.append(response.status_update.task_id) + expected_task_ids.append(response.task.id) list_request = ListTasksRequest(page_size=page_size) @@ -517,13 +551,13 @@ async def test_end_to_end_input_required(transport_setups): assert_events_match( events, [ - ('status_update', TaskState.TASK_STATE_SUBMITTED), + ('task', TaskState.TASK_STATE_SUBMITTED), ('status_update', TaskState.TASK_STATE_WORKING), ('status_update', TaskState.TASK_STATE_INPUT_REQUIRED), ], ) - task_id = events[0].status_update.task_id + task_id = events[0].task.id task = await client.get_task(request=GetTaskRequest(id=task_id)) assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED @@ -531,7 +565,6 @@ async def test_end_to_end_input_required(transport_setups): task.history, [ (Role.ROLE_USER, 'Need input'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -575,7 +608,6 @@ async def test_end_to_end_input_required(transport_setups): task.history, [ (Role.ROLE_USER, 'Need input'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), (Role.ROLE_AGENT, 'Please provide input'), (Role.ROLE_USER, 'Here is the input'), @@ -684,3 +716,119 @@ async def test_end_to_end_subscribe_validation_error( assert {e['field'] for e in errors} == {'id'} await client.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'streaming', + [ + pytest.param(False, id='blocking'), + pytest.param(True, id='streaming'), + ], +) +async def test_end_to_end_direct_message(transport_setups, streaming): + """Test that an executor can return a direct Message without creating a Task.""" + client = transport_setups.client + client._config.streaming = streaming + + 
message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-direct', + parts=[Part(text='Message: Hello agent')], + ) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest(message=message_to_send) + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + assert not response.HasField('task') + assert_message_matches( + response.message, + Role.ROLE_AGENT, + 'Direct reply to: Message: Hello agent', + ) + + +@pytest.mark.asyncio +async def test_end_to_end_direct_message_return_immediately(transport_setups): + """Test that return_immediately still returns the Message for direct replies. + + When the executor responds with a direct Message, the response is + inherently immediate -- there is no async task to defer to. The client + should receive the Message regardless of the return_immediately flag. + """ + client = transport_setups.client + client._config.streaming = False + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-direct-return-immediately', + parts=[Part(text='Message: Quick question')], + ) + configuration = SendMessageConfiguration(return_immediately=True) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest( + message=message_to_send, configuration=configuration + ) + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + assert not response.HasField('task') + assert_message_matches( + response.message, + Role.ROLE_AGENT, + 'Direct reply to: Message: Quick question', + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'streaming', + [ + pytest.param(False, id='blocking'), + pytest.param(True, id='streaming'), + ], +) +async def test_end_to_end_extensions_propagation(transport_setups, streaming): + """Test that extensions sent by the client reach the agent executor.""" + client = transport_setups.client + client._config.streaming = streaming + + service_params = 
ServiceParametersFactory.create( + [with_a2a_extensions(SUPPORTED_EXTENSION_URIS)] + ) + context = ClientCallContext(service_parameters=service_params) + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-ext-propagation', + parts=[Part(text='Extensions: echo')], + ) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest(message=message_to_send), + context=context, + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + assert_message_matches( + response.message, Role.ROLE_AGENT, 'extensions echoed' + ) + assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS) diff --git a/tests/integration/test_samples_smoke.py b/tests/integration/test_samples_smoke.py new file mode 100644 index 000000000..fcb49a003 --- /dev/null +++ b/tests/integration/test_samples_smoke.py @@ -0,0 +1,134 @@ +"""End-to-end smoke test for `samples/hello_world_agent.py` and `samples/cli.py`. + +Boots the sample agent as a subprocess on free ports, then runs the sample CLI +against it once per supported transport, asserting the expected greeting reply +flows through. +""" + +from __future__ import annotations + +import asyncio +import socket +import sys + +from pathlib import Path +from typing import TYPE_CHECKING + +import httpx +import pytest +import pytest_asyncio + + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SAMPLES_DIR = REPO_ROOT / 'samples' +AGENT_SCRIPT = SAMPLES_DIR / 'hello_world_agent.py' +CLI_SCRIPT = SAMPLES_DIR / 'cli.py' + +STARTUP_TIMEOUT_S = 30.0 +CLI_TIMEOUT_S = 30.0 +EXPECTED_REPLY = 'Hello World! Nice to meet you!' 
+ + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(('127.0.0.1', 0)) + return sock.getsockname()[1] + + +async def _wait_for_agent_card(url: str) -> None: + deadline = asyncio.get_running_loop().time() + STARTUP_TIMEOUT_S + async with httpx.AsyncClient(timeout=2.0) as client: + while asyncio.get_running_loop().time() < deadline: + try: + response = await client.get(url) + if response.status_code == 200: + return + except httpx.RequestError: + pass + await asyncio.sleep(0.2) + raise TimeoutError(f'Agent did not become ready at {url}') + + +@pytest_asyncio.fixture +async def running_sample_agent() -> AsyncGenerator[str, None]: + """Start `hello_world_agent.py` as a subprocess on free ports.""" + host = '127.0.0.1' + http_port = _free_port() + grpc_port = _free_port() + compat_grpc_port = _free_port() + base_url = f'http://{host}:{http_port}' + + proc = await asyncio.create_subprocess_exec( + sys.executable, + str(AGENT_SCRIPT), + '--host', + host, + '--port', + str(http_port), + '--grpc-port', + str(grpc_port), + '--compat-grpc-port', + str(compat_grpc_port), + cwd=str(REPO_ROOT), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + + try: + await _wait_for_agent_card(f'{base_url}/.well-known/agent-card.json') + yield base_url + finally: + if proc.returncode is None: + proc.terminate() + try: + await asyncio.wait_for(proc.wait(), timeout=10.0) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + + +async def _run_cli(base_url: str, transport: str) -> str: + """Run `cli.py --transport `, send `hello`, return combined output.""" + proc = await asyncio.create_subprocess_exec( + sys.executable, + str(CLI_SCRIPT), + '--url', + base_url, + '--transport', + transport, + cwd=str(REPO_ROOT), + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + try: + stdout, _ = await asyncio.wait_for( + proc.communicate(b'hello\n/quit\n'), + 
timeout=CLI_TIMEOUT_S, + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + raise + output = stdout.decode('utf-8', errors='replace') + assert proc.returncode == 0, ( + f'CLI exited with {proc.returncode} for transport {transport!r}.\n' + f'Output:\n{output}' + ) + return output + + +@pytest.mark.asyncio +@pytest.mark.parametrize('transport', ['JSONRPC', 'HTTP+JSON', 'GRPC']) +async def test_cli_against_sample_agent( + running_sample_agent: str, transport: str +) -> None: + """The CLI should successfully exchange a greeting over each transport.""" + output = await _run_cli(running_sample_agent, transport) + + assert 'TASK_STATE_COMPLETED' in output, output + assert EXPECTED_REPLY in output, output diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py new file mode 100644 index 000000000..6070a672f --- /dev/null +++ b/tests/integration/test_scenarios.py @@ -0,0 +1,2128 @@ +import asyncio +import collections +import contextlib +import logging + +from typing import Any + +import grpc +import pytest +import pytest_asyncio + +from a2a.auth.user import User +from a2a.client.client import ClientConfig +from a2a.client.client_factory import ClientFactory +from a2a.client.errors import A2AClientError +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.context import ServerCallContext +from a2a.server.events import EventQueue +from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager +from a2a.server.request_handlers import ( + DefaultRequestHandlerV2, + GrpcHandler, + GrpcServerCallContextBuilder, +) +from a2a.server.request_handlers.default_request_handler import ( + LegacyRequestHandler, +) +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import a2a_pb2_grpc +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + AgentInterface, + Artifact, + CancelTaskRequest, + GetTaskRequest, + ListTasksRequest, + Message, + Part, + 
Role, + SendMessageConfiguration, + SendMessageRequest, + SubscribeToTaskRequest, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.helpers.proto_helpers import new_task_from_user_message +from a2a.utils import TransportProtocol +from a2a.utils.errors import ( + InvalidParamsError, + TaskNotCancelableError, + TaskNotFoundError, + InvalidAgentResponseError, +) + + +logger = logging.getLogger(__name__) + + +async def wait_for_state( + client: Any, + task_id: str, + expected_states: set[TaskState.ValueType], + timeout: float = 1.0, +) -> None: + """Wait for the task to reach one of the expected states.""" + start_time = asyncio.get_event_loop().time() + while True: + task = await client.get_task(GetTaskRequest(id=task_id)) + if task.status.state in expected_states: + return + + if asyncio.get_event_loop().time() - start_time > timeout: + raise TimeoutError( + f'Task {task_id} did not reach expected states {expected_states} within {timeout}s. 
' + f'Current state: {task.status.state}' + ) + await asyncio.sleep(0.01) + + +async def get_all_events(stream): + return [event async for event in stream] + + +class MockUser(User): + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return 'test-user' + + +class MockCallContextBuilder(GrpcServerCallContextBuilder): + def build(self, request: Any) -> ServerCallContext: + return ServerCallContext( + user=MockUser(), state={'headers': {'a2a-version': '1.0'}} + ) + + +def agent_card(): + return AgentCard( + name='Test Agent', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + supported_interfaces=[ + AgentInterface( + protocol_binding=TransportProtocol.GRPC, + url='http://testserver', + ) + ], + ) + + +def get_task_id(event): + if event.HasField('task'): + return event.task.id + if event.HasField('status_update'): + return event.status_update.task_id + assert False, f'Event {event} has no task_id' + + +def get_task_context_id(event): + if event.HasField('task'): + return event.task.context_id + if event.HasField('status_update'): + return event.status_update.context_id + assert False, f'Event {event} has no context_id' + + +def get_state(event): + if event.HasField('task'): + return event.task.status.state + return event.status_update.status.state + + +def validate_state(event, expected_state): + assert get_state(event) == expected_state + + +_test_servers = [] + + +@pytest_asyncio.fixture(autouse=True) +async def cleanup_test_servers(): + yield + for server in _test_servers: + await server.stop(None) + _test_servers.clear() + + +# TODO: Test different transport (e.g. HTTP_JSON hangs for some tests). 
+async def create_client(handler, agent_card, streaming=False): + server = grpc.aio.server() + port = server.add_insecure_port('[::]:0') + server_address = f'localhost:{port}' + + agent_card.supported_interfaces[0].url = server_address + agent_card.supported_interfaces[0].protocol_binding = TransportProtocol.GRPC + + servicer = GrpcHandler( + request_handler=handler, context_builder=MockCallContextBuilder() + ) + a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) + await server.start() + _test_servers.append(server) + + factory = ClientFactory( + config=ClientConfig( + grpc_channel_factory=grpc.aio.insecure_channel, + supported_protocol_bindings=[TransportProtocol.GRPC], + streaming=streaming, + ) + ) + client = factory.create(agent_card) + client._server = server # Keep reference to prevent garbage collection + return client + + +def create_handler( + agent_executor, use_legacy, task_store=None, queue_manager=None +): + task_store = task_store or InMemoryTaskStore() + queue_manager = queue_manager or InMemoryQueueManager() + return ( + LegacyRequestHandler( + agent_executor, + task_store, + agent_card(), + queue_manager, + ) + if use_legacy + else DefaultRequestHandlerV2( + agent_executor, + task_store, + agent_card(), + queue_manager, + ) + ) + + +# Scenario 1: Cancellation of already terminal task +# This also covers test_scenario_7_cancel_terminal_task from test_handler_comparison +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_1_cancel_terminal_task(use_legacy, streaming): + class DummyAgentExecutor(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + task_store = InMemoryTaskStore() + handler = create_handler( + 
DummyAgentExecutor(), use_legacy, task_store=task_store + ) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + task_id = 'terminal-task' + await task_store.save( + Task( + id=task_id, status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED) + ), + ServerCallContext(user=MockUser()), + ) + with pytest.raises(TaskNotCancelableError): + await client.cancel_task(CancelTaskRequest(id=task_id)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_4_simple_streaming(use_legacy): + class DummyAgentExecutor(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(DummyAgentExecutor(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=True + ) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + events = [ + event + async for event in client.send_message(SendMessageRequest(message=msg)) + ] + task, status_update = events + assert task.HasField('task') + assert status_update.HasField('status_update') + + assert [get_state(event) for event in events] == [ + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] + + +# Scenario 5: Re-subscribing to a finished task +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_5_resubscribe_to_finished(use_legacy): + class DummyAgentExecutor(AgentExecutor): + async def 
execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(DummyAgentExecutor(), use_legacy) + client = await create_client(handler, agent_card=agent_card()) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + (event,) = [event async for event in it] + task_id = event.task.id + + await wait_for_state( + client, task_id, expected_states={TaskState.TASK_STATE_COMPLETED} + ) + # TODO: Use different transport. 
+ with pytest.raises( + NotImplementedError, + match='client and/or server do not support resubscription', + ): + async for _ in client.subscribe(SubscribeToTaskRequest(id=task_id)): + pass + + +# Scenario 6-8: Parity for Error cases +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenarios_simple_errors(use_legacy, streaming): + class DummyAgentExecutor(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(DummyAgentExecutor(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + with pytest.raises(TaskNotFoundError): + await client.get_task(GetTaskRequest(id='missing')) + + msg1 = Message( + task_id='missing', + message_id='missing-task', + role=Role.ROLE_USER, + parts=[Part(text='h')], + ) + with pytest.raises(TaskNotFoundError): + async for _ in client.send_message(SendMessageRequest(message=msg1)): + pass + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + (event,) = [event async for event in it] + + if streaming: + assert event.HasField('task') + task_id = event.task.id + validate_state(event, TaskState.TASK_STATE_COMPLETED) + else: + assert event.HasField('task') + task_id = event.task.id + assert event.task.status.state == TaskState.TASK_STATE_COMPLETED + + logger.info('Sending message to completed task %s', task_id) + msg2 = Message( 
+ message_id='test-msg-2', + task_id=task_id, + role=Role.ROLE_USER, + parts=[Part(text='message to completed task')], + ) + # TODO: Is it correct error code ? + with pytest.raises(InvalidParamsError): + async for _ in client.send_message(SendMessageRequest(message=msg2)): + pass + + (task,) = (await client.list_tasks(ListTasksRequest())).tasks + assert task.status.state == TaskState.TASK_STATE_COMPLETED + (message,) = task.history + assert message.role == Role.ROLE_USER + (message_part,) = message.parts + assert message_part.text == 'hello' + + +# Scenario 9: Exception before any event. +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_9_error_before_blocking(use_legacy, streaming): + class ErrorBeforeAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + raise ValueError('TEST_ERROR_IN_EXECUTE') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ErrorBeforeAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + # TODO: Is it correct error code ? + with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'): + async for _ in client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration( + return_immediately=False + ), + ) + ): + pass + + if use_legacy: + # Legacy is not creating tasks for agent failures. 
+ assert len((await client.list_tasks(ListTasksRequest())).tasks) == 0 + else: + (task,) = (await client.list_tasks(ListTasksRequest())).tasks + assert task.status.state == TaskState.TASK_STATE_FAILED + + +# Scenario 12/13: Exception after initial event +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_12_13_error_after_initial_event(use_legacy, streaming): + started_event = asyncio.Event() + continue_event = asyncio.Event() + + class ErrorAfterAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + started_event.set() + await continue_event.wait() + raise ValueError('TEST_ERROR_IN_EXECUTE') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ErrorAfterAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + it = client.send_message(SendMessageRequest(message=msg)) + + tasks = [] + + if streaming: + res = await it.__anext__() + validate_state(res, TaskState.TASK_STATE_WORKING) + continue_event.set() + else: + + async def release_agent(): + await started_event.wait() + continue_event.set() + + tasks.append(asyncio.create_task(release_agent())) + + with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'): + async for _ in it: + pass + + await asyncio.gather(*tasks) + + (task,) = (await client.list_tasks(ListTasksRequest())).tasks + if use_legacy: + # Legacy does not update task state on exception. 
+ assert task.status.state == TaskState.TASK_STATE_WORKING + else: + assert task.status.state == TaskState.TASK_STATE_FAILED + + +# Scenario 14: Exception in Cancel +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_14_error_in_cancel(use_legacy, streaming): + started_event = asyncio.Event() + hang_event = asyncio.Event() + + class ErrorCancelAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + started_event.set() + await hang_event.wait() + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + raise ValueError('TEST_ERROR_IN_CANCEL') + + handler = create_handler(ErrorCancelAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', + role=Role.ROLE_USER, + parts=[Part(text='hello')], + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + res = await it.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + + with pytest.raises(A2AClientError, match='TEST_ERROR_IN_CANCEL'): + await client.cancel_task(CancelTaskRequest(id=task_id)) + + (task,) = (await client.list_tasks(ListTasksRequest())).tasks + if use_legacy: + # Legacy does not update task state on exception. 
+ assert task.status.state == TaskState.TASK_STATE_WORKING + else: + assert task.status.state == TaskState.TASK_STATE_FAILED + + +# Scenario 15: Subscribe to task that errors out +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_15_subscribe_error(use_legacy): + started_event = asyncio.Event() + continue_event = asyncio.Event() + + class ErrorAfterAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + started_event.set() + await continue_event.wait() + raise ValueError('TEST_ERROR_IN_EXECUTE') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ErrorAfterAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=True + ) + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + it_start = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + res = await it_start.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + async def consume_events(): + async for _ in client.subscribe(SubscribeToTaskRequest(id=task_id)): + pass + + consume_task = asyncio.create_task(consume_events()) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(asyncio.shield(consume_task), timeout=0.1) + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + continue_event.set() + + if use_legacy: + # Legacy client hangs forever. 
+ with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(consume_task, timeout=0.1) + else: + with pytest.raises(A2AClientError, match='TEST_ERROR_IN_EXECUTE'): + await consume_task + + (task,) = (await client.list_tasks(ListTasksRequest())).tasks + if use_legacy: + # Legacy does not update task state on exception. + assert task.status.state == TaskState.TASK_STATE_WORKING + else: + assert task.status.state == TaskState.TASK_STATE_FAILED + + +# Scenario 16: Slow execution and return_immediately=True +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_16_slow_execution(use_legacy, streaming): + started_event = asyncio.Event() + hang_event = asyncio.Event() + + class SlowAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + started_event.set() + await hang_event.wait() + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + queue_manager = InMemoryQueueManager() + handler = create_handler( + SlowAgent(), use_legacy, queue_manager=queue_manager + ) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', + role=Role.ROLE_USER, + parts=[Part(text='hello')], + ) + + async def send_message_and_get_first_response(): + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + return await asyncio.wait_for(it.__anext__(), timeout=0.1) + + # First response should not be there yet. + with pytest.raises(asyncio.TimeoutError): + await send_message_and_get_first_response() + + tasks = (await client.list_tasks(ListTasksRequest())).tasks + assert len(tasks) == 0 + + +# Scenario 17: Cancellation of a working task. 
+# @pytest.mark.skip +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_cancel_working_task_empty_cancel(use_legacy, streaming): + started_event = asyncio.Event() + hang_event = asyncio.Event() + + class DummyCancelAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + started_event.set() + await hang_event.wait() + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + # TODO: this should be done automatically by the framework ? + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_CANCELED), + ) + ) + + handler = create_handler(DummyCancelAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + res = await it.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + + task_before = await client.get_task(GetTaskRequest(id=task_id)) + assert task_before.status.state == TaskState.TASK_STATE_WORKING + + cancel_res = await client.cancel_task(CancelTaskRequest(id=task_id)) + assert cancel_res.status.state == TaskState.TASK_STATE_CANCELED + + task_after = await client.get_task(GetTaskRequest(id=task_id)) + assert task_after.status.state == 
TaskState.TASK_STATE_CANCELED + + (task_from_list,) = (await client.list_tasks(ListTasksRequest())).tasks + assert task_from_list.status.state == TaskState.TASK_STATE_CANCELED + + +# Scenario 18: Complex streaming with multiple subscribers +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_18_streaming_subscribers(use_legacy): + started_event = asyncio.Event() + working_event = asyncio.Event() + completed_event = asyncio.Event() + + class ComplexAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + started_event.set() + await working_event.wait() + + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact(artifact_id='test-art'), + ) + ) + await completed_event.wait() + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ComplexAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=True + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + res = await it.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + + # create first subscriber + sub1 = client.subscribe(SubscribeToTaskRequest(id=task_id)) + + # first 
subscriber receives current task state (WORKING) + validate_state(await sub1.__anext__(), TaskState.TASK_STATE_WORKING) + + # create second subscriber + sub2 = client.subscribe(SubscribeToTaskRequest(id=task_id)) + + # second subscriber receives current task state (WORKING) + validate_state(await sub2.__anext__(), TaskState.TASK_STATE_WORKING) + + working_event.set() + + # validate what both subscribers observed (artifact) + res1_art = await sub1.__anext__() + assert res1_art.artifact_update.artifact.artifact_id == 'test-art' + + res2_art = await sub2.__anext__() + assert res2_art.artifact_update.artifact.artifact_id == 'test-art' + + completed_event.set() + + # validate what both subscribers observed (completed) + validate_state(await sub1.__anext__(), TaskState.TASK_STATE_COMPLETED) + validate_state(await sub2.__anext__(), TaskState.TASK_STATE_COMPLETED) + + # validate final task state with getTask + final_task = await client.get_task(GetTaskRequest(id=task_id)) + assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + + (artifact,) = final_task.artifacts + assert artifact.artifact_id == 'test-art' + + (message,) = final_task.history + assert message.parts[0].text == 'hello' + + +# Scenario 19: Parallel executions for the same task should not happen simultaneously. 
+@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_19_no_parallel_executions(use_legacy, streaming): + started_event = asyncio.Event() + continue_event = asyncio.Event() + executions_count = 0 + + class CountingAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + nonlocal executions_count + executions_count += 1 + + if executions_count > 1: + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact(artifact_id='SECOND_EXECUTION'), + ) + ) + return + + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + started_event.set() + await continue_event.wait() + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(CountingAgent(), use_legacy) + client1 = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + client2 = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg1 = Message( + message_id='test-msg-1', + role=Role.ROLE_USER, + parts=[Part(text='hello 1')], + ) + + # First client sends initial message + it1 = client1.send_message( + SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + task1 = asyncio.create_task(it1.__anext__()) + + # Wait for the first execution to reach the WORKING state + await asyncio.wait_for(started_event.wait(), timeout=1.0) + assert executions_count == 1 + + # 
Extract task_id from the first call using list_tasks + (task,) = (await client1.list_tasks(ListTasksRequest())).tasks + task_id = task.id + + msg2 = Message( + message_id='test-msg-2', + task_id=task_id, + role=Role.ROLE_USER, + parts=[Part(text='hello 2')], + ) + + # Second client sends a message to the same task + it2 = client2.send_message( + SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + task2 = asyncio.create_task(it2.__anext__()) + + if use_legacy: + # Legacy handler executes the second request in parallel. + await task2 + assert executions_count == 2 + else: + # V2 handler queues the second request. + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(asyncio.shield(task2), timeout=0.1) + assert executions_count == 1 + + # Unblock AgentExecutor + continue_event.set() + + # Verify that both calls for clients finished. + if use_legacy and not streaming: + # Legacy handler fails on first execution. + with pytest.raises(A2AClientError, match='NoTaskQueue'): + await task1 + else: + await task1 + + try: + await task2 + except StopAsyncIteration: + # TODO: Test is flaky. Debug it. + return + + # Consume remaining events if any + async def consume(it): + async for _ in it: + pass + + await asyncio.gather(consume(it1), consume(it2)) + assert executions_count == 2 + + # Validate final task state. + final_task = await client1.get_task(GetTaskRequest(id=task_id)) + + if use_legacy: + # Legacy handler fails to complete the task. + assert final_task.status.state == TaskState.TASK_STATE_WORKING + else: + assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + + # TODO: What is expected state of messages and artifacts? + + +# Scenario: Validate return_immediately flag behaviour. 
+@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_return_immediately(use_legacy, streaming): + class ImmediateAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ImmediateAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='hello')] + ) + + # Test non-blocking return. + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + states = [get_state(event) async for event in it] + + if streaming: + assert states == [ + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] + else: + assert states == [TaskState.TASK_STATE_WORKING] + + +# Scenario: Test TASK_STATE_INPUT_REQUIRED. 
+@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_resumption_from_interrupted(use_legacy, streaming): + class ResumingAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + message = context.message + if message and message.parts and message.parts[0].text == 'start': + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED + await event_queue.enqueue_event(task) + elif ( + message + and message.parts + and message.parts[0].text == 'here is input' + ): + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) + else: + raise ValueError('Unexpected message') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ResumingAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + # First send message to get it into input required state + msg1 = Message( + message_id='msg-start', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + events1 = [event async for event in it] + assert [get_state(event) for event in events1] == [ + TaskState.TASK_STATE_INPUT_REQUIRED, + ] + task_id = events1[0].status_update.task_id + context_id = events1[0].status_update.context_id + + # Now send another message to resume + msg2 = Message( + task_id=task_id, + context_id=context_id, + message_id='msg-resume', + role=Role.ROLE_USER, + parts=[Part(text='here is input')], + ) + + it2 = client.send_message( + SendMessageRequest( + message=msg2, + 
configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + assert [get_state(event) async for event in it2] == [ + TaskState.TASK_STATE_COMPLETED, + ] + + +# Scenario: Auth required and side channel unblocking +# Migrated from: test_workflow_auth_required_side_channel in test_handler_comparison +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_auth_required_side_channel(use_legacy, streaming): + side_channel_event = asyncio.Event() + + class AuthAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + ) + + await side_channel_event.wait() + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(AuthAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + if streaming: + event1 = await asyncio.wait_for(it.__anext__(), timeout=1.0) + assert get_state(event1) == TaskState.TASK_STATE_WORKING + + event2 = await asyncio.wait_for(it.__anext__(), 
timeout=1.0) + assert get_state(event2) == TaskState.TASK_STATE_AUTH_REQUIRED + + task_id = event2.status_update.task_id + + side_channel_event.set() + + # Remaining event. + (event3,) = [event async for event in it] + assert get_state(event3) == TaskState.TASK_STATE_COMPLETED + else: + (event,) = [event async for event in it] + assert get_state(event) == TaskState.TASK_STATE_AUTH_REQUIRED + task_id = event.task.id + + side_channel_event.set() + + await wait_for_state( + client, task_id, expected_states={TaskState.TASK_STATE_COMPLETED} + ) + + +# Scenario: Auth required and in channel unblocking +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_auth_required_in_channel(use_legacy, streaming): + class AuthAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + message = context.message + if message and message.parts and message.parts[0].text == 'start': + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_AUTH_REQUIRED + await event_queue.enqueue_event(task) + elif ( + message + and message.parts + and message.parts[0].text == 'credentials' + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + else: + raise ValueError(f'Unexpected message {message}') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(AuthAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg1 = Message( + message_id='msg-start', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg1, + 
configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + events1 = [event async for event in it] + assert [get_state(event) for event in events1] == [ + TaskState.TASK_STATE_AUTH_REQUIRED, + ] + task_id = get_task_id(events1[0]) + context_id = get_task_context_id(events1[0]) + + # Now send another message with credentials + msg2 = Message( + task_id=task_id, + context_id=context_id, + message_id='msg-creds', + role=Role.ROLE_USER, + parts=[Part(text='credentials')], + ) + + it2 = client.send_message( + SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + assert [get_state(event) async for event in it2] == [ + TaskState.TASK_STATE_COMPLETED, + ] + + +# Scenario: Parallel subscribe attach detach +# Migrated from: test_parallel_subscribe_attach_detach in test_handler_comparison +@pytest.mark.timeout(5.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +async def test_scenario_parallel_subscribe_attach_detach(use_legacy): # noqa: PLR0915 + events = collections.defaultdict(asyncio.Event) + + class EmitAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + + phases = [ + ('trigger_phase_1', 'artifact_1'), + ('trigger_phase_2', 'artifact_2'), + ('trigger_phase_3', 'artifact_3'), + ('trigger_phase_4', 'artifact_4'), + ] + + for trigger_name, artifact_id in phases: + await events[trigger_name].wait() + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact( + artifact_id=artifact_id, + parts=[Part(text=artifact_id)], + ), + ) + ) + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + 
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(EmitAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=True + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=True), + ) + ) + + res = await it.__anext__() + task_id = res.task.id if res.HasField('task') else res.status_update.task_id + + async def monitor_artifacts(): + try: + async for event in client.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + if event.HasField('artifact_update'): + artifact_id = event.artifact_update.artifact.artifact_id + if artifact_id.startswith('artifact_'): + phase_num = artifact_id.split('_')[1] + events[f'emitted_phase_{phase_num}'].set() + except asyncio.CancelledError: + pass + + monitor_task = asyncio.create_task(monitor_artifacts()) + + async def subscribe_and_collect(artifacts_to_collect: int | None = None): + ready_event = asyncio.Event() + + async def collect(): + collected = [] + artifacts_seen = 0 + try: + async for event in client.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + collected.append(event) + ready_event.set() + if event.HasField('artifact_update'): + artifacts_seen += 1 + if ( + artifacts_to_collect is not None + and artifacts_seen >= artifacts_to_collect + ): + break + except asyncio.CancelledError: + pass + return collected + + task = asyncio.create_task(collect()) + await ready_event.wait() + return task + + sub1_task = await subscribe_and_collect() + + events['trigger_phase_1'].set() + await events['emitted_phase_1'].wait() + + sub2_task = await subscribe_and_collect(artifacts_to_collect=1) + sub3_task = await subscribe_and_collect(artifacts_to_collect=2) + + events['trigger_phase_2'].set() + 
await events['emitted_phase_2'].wait() + + events['trigger_phase_3'].set() + await events['emitted_phase_3'].wait() + + sub4_task = await subscribe_and_collect() + + events['trigger_phase_4'].set() + await events['emitted_phase_4'].wait() + + def get_artifact_updates(evs): + return [ + [p.text for p in sr.artifact_update.artifact.parts] + for sr in evs + if sr.HasField('artifact_update') + ] + + assert get_artifact_updates(await sub1_task) == [ + ['artifact_1'], + ['artifact_2'], + ['artifact_3'], + ['artifact_4'], + ] + + assert get_artifact_updates(await sub2_task) == [ + ['artifact_2'], + ] + assert get_artifact_updates(await sub3_task) == [ + ['artifact_2'], + ['artifact_3'], + ] + assert get_artifact_updates(await sub4_task) == [ + ['artifact_4'], + ] + + monitor_task.cancel() + + +# Return message directly. +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize( + 'return_immediately', + [False, True], + ids=['no_return_immediately', 'return_immediately'], +) +async def test_scenario_publish_message( + use_legacy, streaming, return_immediately +): + class MessageAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + Message( + task_id=context.task_id, + context_id=context.context_id, + message_id='msg-1', + role=Role.ROLE_AGENT, + parts=[Part(text='response text')], + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(MessageAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + 
configuration=SendMessageConfiguration( + return_immediately=return_immediately + ), + ) + ) + events = [event async for event in it] + + (event,) = events + assert event.HasField('message') + assert event.message.parts[0].text == 'response text' + + tasks = (await client.list_tasks(ListTasksRequest())).tasks + assert len(tasks) == 0 + + +# Scenario: Publish ArtifactUpdateEvent +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_publish_artifact(use_legacy, streaming): + class ArtifactAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact( + artifact_id='art-1', parts=[Part(text='artifact data')] + ), + ) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ArtifactAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events = [event async for event in it] + + if streaming: + last_event = events[-1] + assert get_state(last_event) == TaskState.TASK_STATE_COMPLETED + + artifact_events = [e for e in events 
if e.HasField('artifact_update')] + assert len(artifact_events) > 0, ( + 'Bug: Streaming should return the artifact update event' + ) + assert ( + artifact_events[0].artifact_update.artifact.artifact_id == 'art-1' + ) + else: + last_event = events[-1] + assert last_event.HasField('task') + assert last_event.task.status.state == TaskState.TASK_STATE_COMPLETED + + assert len(last_event.task.artifacts) > 0, ( + 'Bug: Task should include the published artifact' + ) + assert last_event.task.artifacts[0].artifact_id == 'art-1' + + +# Scenario: Enqueue Task twice +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_enqueue_task_twice(caplog, use_legacy, streaming): + class DoubleTaskAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task1 = Task( + id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_WORKING, + message=Message(parts=[Part(text='First task')]), + ), + ) + await event_queue.enqueue_event(task1) + + # This is undefined behavior, but it should not crash or hang. 
+ task2 = Task( + id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_WORKING, + message=Message(parts=[Part(text='Second task')]), + ), + ) + await event_queue.enqueue_event(task2) + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(DoubleTaskAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + _ = [event async for event in it] + + (final_task,) = (await client.list_tasks(ListTasksRequest())).tasks + + if use_legacy: + assert [part.text for part in final_task.history[0].parts] == [ + 'Second task' + ] + else: + assert [part.text for part in final_task.history[0].parts] == [ + 'First task' + ] + + # Validate that new version logs with error exactly once 'Ignoring task replacement' + error_logs = [ + record.message + for record in caplog.records + if record.levelname == 'ERROR' + and 'Ignoring task replacement' in record.message + ] + + assert len(error_logs) == 1 + + +# Scenario: Task restoration - terminal state +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize( + 'subscribe_first', + [False, True], + ids=['no_subscribe_first', 'subscribe_first'], +) +async def test_restore_task_terminal_state( + use_legacy, streaming, subscribe_first +): + class TerminalAgent(AgentExecutor): + 
async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + task_store = InMemoryTaskStore() + handler1 = create_handler( + TerminalAgent(), use_legacy, task_store=task_store + ) + client1 = await create_client( + handler1, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg-1', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + it1 = client1.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events1 = [event async for event in it1] + task_id = get_task_id(events1[-1]) + + await wait_for_state( + client1, task_id, expected_states={TaskState.TASK_STATE_COMPLETED} + ) + + # Restore task in a new handler (simulating server restart) + handler2 = create_handler( + TerminalAgent(), use_legacy, task_store=task_store + ) + client2 = await create_client( + handler2, agent_card=agent_card(), streaming=streaming + ) + + restored_task = await client2.get_task(GetTaskRequest(id=task_id)) + assert restored_task.status.state == TaskState.TASK_STATE_COMPLETED + + if subscribe_first and streaming: + with pytest.raises( + Exception, + match=r'terminal state', + ): + async for _ in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + pass + + msg2 = Message( + task_id=task_id, + message_id='test-msg-2', + role=Role.ROLE_USER, + parts=[Part(text='message to completed task')], + ) + + with pytest.raises(Exception, match=r'terminal state'): + async for _ in client2.send_message(SendMessageRequest(message=msg2)): + pass + + if streaming: + with pytest.raises( + Exception, + match=r'terminal state', + ): + async for _ in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + pass + + +# Scenario: 
Task restoration - user input required state +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize( + 'subscribe_mode', + ['none', 'drop', 'listen'], + ids=['no_sub', 'sub_drop', 'sub_listen'], +) +async def test_restore_task_input_required_state( + use_legacy, streaming, subscribe_mode +): + class InputAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + message = context.message + if message and message.parts and message.parts[0].text == 'start': + task = new_task_from_user_message(message) + task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED + await event_queue.enqueue_event(task) + elif message and message.parts and message.parts[0].text == 'input': + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + task_store = InMemoryTaskStore() + handler1 = create_handler(InputAgent(), use_legacy, task_store=task_store) + client1 = await create_client( + handler1, agent_card=agent_card(), streaming=streaming + ) + + msg1 = Message( + message_id='test-msg-1', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + it1 = client1.send_message( + SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events1 = [event async for event in it1] + + task_id = get_task_id(events1[-1]) + context_id = get_task_context_id(events1[-1]) + + await wait_for_state( + client1, task_id, expected_states={TaskState.TASK_STATE_INPUT_REQUIRED} + ) + + # Restore task in a new handler (simulating server restart) + handler2 = create_handler(InputAgent(), use_legacy, 
task_store=task_store) + client2 = await create_client( + handler2, agent_card=agent_card(), streaming=streaming + ) + + restored_task = await client2.get_task(GetTaskRequest(id=task_id)) + assert restored_task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + # Subscription logic based on mode + listen_task = None + if streaming: + if subscribe_mode == 'drop': + # Subscribing and dropping immediately (cancelling the generator) + async for _ in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + break + elif subscribe_mode == 'listen': + sub_started_event = asyncio.Event() + + async def listen_to_end(): + res = [] + async for ev in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + res.append(ev) + sub_started_event.set() + return res + + listen_task = asyncio.create_task(listen_to_end()) + # Wait for subscription to establish and yield the initial task event + await asyncio.wait_for(sub_started_event.wait(), timeout=1.0) + + msg2 = Message( + task_id=task_id, + context_id=context_id, + message_id='test-msg-2', + role=Role.ROLE_USER, + parts=[Part(text='input')], + ) + + it2 = client2.send_message( + SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events2 = [event async for event in it2] + + if streaming: + assert ( + events2[-1].status_update.status.state + == TaskState.TASK_STATE_COMPLETED + ) + else: + assert events2[-1].task.status.state == TaskState.TASK_STATE_COMPLETED + + if listen_task: + if use_legacy and streaming: + # Error: Legacy handler does not properly manage subscriptions for restored tasks + with pytest.raises(TaskNotFoundError): + await listen_task + else: + listen_events = await listen_task + # The first event is the initial task state (INPUT_REQUIRED), the last should be COMPLETED + assert ( + get_state(listen_events[-1]) == TaskState.TASK_STATE_COMPLETED + ) + + final_task = await client2.get_task(GetTaskRequest(id=task_id)) + assert 
final_task.status.state == TaskState.TASK_STATE_COMPLETED + + +# Scenario 20: Create initial task with new_task +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize('initial_task_type', ['new_task', 'status_update']) +async def test_scenario_initial_task_types( + use_legacy, streaming, initial_task_type +): + started_event = asyncio.Event() + continue_event = asyncio.Event() + + class InitialTaskAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + if initial_task_type == 'new_task': + # Create with new_task + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + else: + # Create with status update (illegal in v2) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + + started_event.set() + await continue_event.wait() + + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact( + artifact_id='art-1', parts=[Part(text='artifact data')] + ), + ) + ) + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(InitialTaskAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + 
message=msg, + configuration=SendMessageConfiguration( + return_immediately=streaming + ), + ) + ) + + if streaming: + if initial_task_type == 'status_update' and not use_legacy: + with pytest.raises( + InvalidAgentResponseError, + match='Agent should enqueue Task before TaskStatusUpdateEvent event', + ): + await it.__anext__() + + # End of the test. + return + + res = await it.__anext__() + if initial_task_type == 'status_update' and use_legacy: + # First message has to be a Task. + assert res.HasField('status_update') + + # End of the test. + return + + assert res.HasField('task') + task_id = get_task_id(res) + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + + # Start subscription + sub = client.subscribe(SubscribeToTaskRequest(id=task_id)) + + # first subscriber receives current task state (WORKING) + first_event = await sub.__anext__() + assert first_event.HasField('task') + + continue_event.set() + + events = [first_event] + [event async for event in sub] + else: + # blocking + async def release_agent(): + await started_event.wait() + continue_event.set() + + release_task = asyncio.create_task(release_agent()) + if initial_task_type == 'status_update' and not use_legacy: + with pytest.raises( + InvalidAgentResponseError, + match='Agent should enqueue Task before TaskStatusUpdateEvent event', + ): + events = [event async for event in it] + # End of the test + return + else: + events = [event async for event in it] + await release_task + + if streaming: + task, artifact_update, status_update = events + assert task.HasField('task') + validate_state(task, TaskState.TASK_STATE_WORKING) + assert artifact_update.artifact_update.artifact.artifact_id == 'art-1' + assert status_update.HasField('status_update') + validate_state(status_update, TaskState.TASK_STATE_COMPLETED) + else: + (task,) = events + assert task.HasField('task') + validate_state(task, TaskState.TASK_STATE_COMPLETED) + (artifact,) = task.task.artifacts + assert artifact.artifact_id == 
'art-1' + task_id = task.task.id + + (final_task_from_list,) = ( + await client.list_tasks(ListTasksRequest(include_artifacts=True)) + ).tasks + assert len(final_task_from_list.artifacts) > 0 + assert final_task_from_list.artifacts[0].artifact_id == 'art-1' + + final_task = await client.get_task(GetTaskRequest(id=task_id)) + assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + assert len(final_task.artifacts) > 0 + assert final_task.artifacts[0].artifact_id == 'art-1' + + +# Scenario 23: Invalid Agent Response - Task followed by Message +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_23_invalid_response_task_message(use_legacy, streaming): + class TaskMessageAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event( + new_task_from_user_message(context.message) + ) + await event_queue.enqueue_event( + Message(message_id='m1', parts=[Part(text='m1')]) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(TaskMessageAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message(SendMessageRequest(message=msg)) + + if use_legacy: + # Legacy: no error. 
+ async for _ in it: + pass + else: + with pytest.raises( + InvalidAgentResponseError, + match='Received Message object in task mode', + ): + async for _ in it: + pass diff --git a/tests/integration/test_stream_generator_cleanup.py b/tests/integration/test_stream_generator_cleanup.py index 47ab5212f..f26f62c6f 100644 --- a/tests/integration/test_stream_generator_cleanup.py +++ b/tests/integration/test_stream_generator_cleanup.py @@ -75,15 +75,14 @@ def client(): handler = DefaultRequestHandler( agent_executor=_MessageExecutor(), task_store=InMemoryTaskStore(), + agent_card=card, queue_manager=InMemoryQueueManager(), ) app = Starlette( routes=[ *create_agent_card_routes(agent_card=card, card_url='/card'), *create_jsonrpc_routes( - agent_card=card, request_handler=handler, - extended_agent_card=card, rpc_url='/', ), ] diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index 6ceb1e070..6b489270b 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -202,9 +202,7 @@ def server_app(self, jsonrpc_agent_card, mock_handler): agent_card=jsonrpc_agent_card, card_url='/' ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=jsonrpc_agent_card, request_handler=mock_handler, - extended_agent_card=jsonrpc_agent_card, rpc_url='/jsonrpc', ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 683c56833..046f4d4cc 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -39,6 +39,7 @@ def test_app(): handler = DefaultRequestHandler( agent_executor=DummyAgentExecutor(), task_store=InMemoryTaskStore(), + agent_card=agent_card, queue_manager=InMemoryQueueManager(), push_config_store=InMemoryPushNotificationConfigStore(), ) @@ -61,19 +62,13 @@ async def mock_on_message_send_stream(*args, **kwargs): agent_card=agent_card, card_url='/' ) jsonrpc_routes = 
create_jsonrpc_routes( - agent_card=agent_card, - request_handler=handler, - rpc_url='/jsonrpc', - enable_v0_3_compat=True, + request_handler=handler, rpc_url='/jsonrpc', enable_v0_3_compat=True ) app.routes.extend(agent_card_routes) app.routes.extend(jsonrpc_routes) rest_routes = create_rest_routes( - agent_card=agent_card, - request_handler=handler, - path_prefix='/rest', - enable_v0_3_compat=True, + request_handler=handler, path_prefix='/rest', enable_v0_3_compat=True ) app.routes.extend(rest_routes) return app @@ -98,7 +93,7 @@ def client(test_app): ('INVALID', 'none'), ], ) -def test_version_header_integration( # noqa: PLR0912, PLR0913, PLR0915 +def test_version_header_integration( client, transport, endpoint_ver, is_streaming, header_val, should_succeed ): headers = {} diff --git a/src/a2a/contrib/__init__.py b/tests/server/agent_execution/__init__.py similarity index 100% rename from src/a2a/contrib/__init__.py rename to tests/server/agent_execution/__init__.py diff --git a/tests/server/agent_execution/test_active_task.py b/tests/server/agent_execution/test_active_task.py new file mode 100644 index 000000000..6e477186b --- /dev/null +++ b/tests/server/agent_execution/test_active_task.py @@ -0,0 +1,893 @@ +import asyncio +import logging + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import pytest_asyncio + +from a2a.server.agent_execution.active_task import ActiveTask +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.context import ServerCallContext +from a2a.server.events.event_queue_v2 import EventQueueSource as EventQueue +from a2a.server.tasks.push_notification_sender import PushNotificationSender +from a2a.server.tasks.task_manager import TaskManager +from a2a.types.a2a_pb2 import ( + Message, + Task, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + Role, + Part, +) +from a2a.utils.errors import InvalidParamsError + + +logger = 
logging.getLogger(__name__) + + +class TestActiveTask: + """Tests for the ActiveTask class.""" + + @pytest.fixture + def agent_executor(self) -> Mock: + return Mock(spec=AgentExecutor) + + @pytest.fixture + def task_manager(self) -> Mock: + tm = Mock(spec=TaskManager) + tm.process = AsyncMock(side_effect=lambda x: x) + tm.get_task = AsyncMock(return_value=None) + tm.context_id = 'test-context-id' + tm._init_task_obj = Mock(return_value=Task(id='test-task-id')) + tm.save_task_event = AsyncMock() + return tm + + @pytest_asyncio.fixture + async def event_queue(self) -> EventQueue: + return EventQueue() + + @pytest.fixture + def push_sender(self) -> Mock: + ps = Mock(spec=PushNotificationSender) + ps.send_notification = AsyncMock() + return ps + + @pytest.fixture + def request_context(self) -> Mock: + return Mock(spec=RequestContext) + + @pytest_asyncio.fixture + async def active_task( + self, + agent_executor: Mock, + task_manager: Mock, + push_sender: Mock, + ) -> ActiveTask: + return ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=push_sender, + ) + + @pytest.mark.asyncio + async def test_active_task_already_started( + self, active_task: ActiveTask, request_context: Mock + ) -> None: + """Test starting a task that is already started.""" + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + # Enqueuing and starting again should not raise errors + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + assert active_task._producer_task is not None + + @pytest.mark.asyncio + async def test_active_task_cancel( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test canceling an ActiveTask.""" + stop_event = asyncio.Event() + + async def 
execute_mock(req, q): + await stop_event.wait() + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + agent_executor.cancel = AsyncMock() + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ] * 10 + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Give it a moment to start + await asyncio.sleep(0.1) + + await active_task.cancel(request_context) + + agent_executor.cancel.assert_called_once() + stop_event.set() + + @pytest.mark.asyncio + async def test_active_task_interrupted_auth( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test task interruption due to AUTH_REQUIRED.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + + async def execute_mock(req, q): + await q.enqueue_event( + TaskStatusUpdateEvent( + task_id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + ) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [ + e async for e in active_task.subscribe(request=request_context) + ] + + result = events[0] if events else None + assert ( + getattr(result, 'id', getattr(result, 'task_id', None)) + == 'test-task-id' + ) + assert result.status.state == TaskState.TASK_STATE_AUTH_REQUIRED + + @pytest.mark.asyncio + async def test_active_task_interrupted_input( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: 
Mock, + task_manager: Mock, + ) -> None: + """Test task interruption due to INPUT_REQUIRED.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_INPUT_REQUIRED), + ) + + async def execute_mock(req, q): + await q.enqueue_event( + Task( + id='test-task-id', + status=TaskStatus( + state=TaskState.TASK_STATE_INPUT_REQUIRED + ), + ) + ) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [ + e async for e in active_task.subscribe(request=request_context) + ] + + result = events[-1] if events else None + assert result.id == 'test-task-id' + assert result.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + @pytest.mark.asyncio + async def test_active_task_producer_failure( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test ActiveTask behavior when the producer fails.""" + agent_executor.execute = AsyncMock( + side_effect=ValueError('Producer crashed') + ) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # We need to wait a bit for the producer to fail and set the exception + for _ in range(10): + try: + async for _ in active_task.subscribe(): + pass + except ValueError: + return + await asyncio.sleep(0.05) + + pytest.fail('Producer failure was not raised') + + @pytest.mark.asyncio + async def test_active_task_push_notification( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + push_sender: Mock, + task_manager: Mock, + ) -> None: + """Test push notification sending.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + 
) + + async def execute_mock(req, q): + await q.enqueue_event(task_obj) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + async for _ in active_task.subscribe(request=request_context): + pass + + push_sender.send_notification.assert_called() + + @pytest.mark.asyncio + async def test_active_task_consumer_failure( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test behavior when the consumer task fails.""" + # Mock dequeue_event to raise exception + active_task._event_queue_agent.dequeue_event = AsyncMock( + side_effect=RuntimeError('Consumer crash') + ) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # We need to wait for the consumer to fail + for _ in range(10): + try: + async for _ in active_task.subscribe(): + pass + except RuntimeError as e: + if str(e) == 'Consumer crash': + return + await asyncio.sleep(0.05) + + pytest.fail('Consumer failure was not raised') + + @pytest.mark.asyncio + async def test_active_task_subscribe_exception_handling( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test exception handling in subscribe.""" + agent_executor.execute = AsyncMock( + side_effect=ValueError('Producer failure') + ) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Give it a moment to fail + for _ in range(10): + if active_task._exception: + break + await asyncio.sleep(0.05) + + with pytest.raises(ValueError, match='Producer failure'): + async for _ in active_task.subscribe(): + pass 
+ + @pytest.mark.asyncio + async def test_active_task_cancel_not_started( + self, active_task: ActiveTask, request_context: Mock + ) -> None: + """Test canceling a task that was never started.""" + # TODO: Implement this test + + @pytest.mark.asyncio + async def test_active_task_cancel_already_finished( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test canceling a task that is already finished.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + + async def execute_mock(req, q): + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + async for _ in active_task.subscribe(request=request_context): + pass + + await active_task._is_finished.wait() + + # Now it is finished + await active_task.cancel(request_context) + + # agent_executor.cancel should NOT be called + agent_executor.cancel.assert_not_called() + + @pytest.mark.asyncio + async def test_active_task_subscribe_cancelled_during_wait( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when it is cancelled while waiting for events.""" + + async def slow_execute(req, q): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + it = active_task.subscribe() + it_obj = it.__aiter__() + + # This task will be waiting inside the loop in subscribe() + task = asyncio.create_task(it_obj.__anext__()) + await asyncio.sleep(0.2) + + 
task.cancel() + + # In python 3.10+ cancelling an async generator next() might raise StopAsyncIteration + # if the generator handles the cancellation by closing. + with pytest.raises((asyncio.CancelledError, StopAsyncIteration)): + await task + + await it.aclose() + + @pytest.mark.asyncio + async def test_active_task_subscribe_queue_shutdown( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when the queue is shut down.""" + + async def long_execute(*args, **kwargs): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=long_execute) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + tapped = await active_task._event_queue_subscribers.tap() + + with patch.object( + active_task._event_queue_subscribers, 'tap', return_value=tapped + ): + # Close the queue while subscribe is waiting + async def close_later(): + await asyncio.sleep(0.2) + await tapped.close() + + _ = asyncio.create_task(close_later()) + + async for _ in active_task.subscribe(): + pass + + # Should finish normally after QueueShutDown + + @pytest.mark.asyncio + async def test_active_task_subscribe_yield_then_shutdown( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when an event is yielded and then the queue is shut down.""" + msg = Message(message_id='m1') + + async def execute_mock(req, q): + await q.enqueue_event(msg) + await asyncio.sleep(0.5) + # Finish producer + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [event async for event in active_task.subscribe()] + assert len(events) == 1 + assert events[0] == msg + + 
@pytest.mark.asyncio + async def test_active_task_task_sets_result_first( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test that enqueuing a Task sets result_available when no result yet.""" + task_obj = Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + + async def execute_mock(req, q): + # No result available yet + await q.enqueue_event(task_obj) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [ + e async for e in active_task.subscribe(request=request_context) + ] + + result = events[-1] if events else None + assert result == task_obj + + @pytest.mark.asyncio + async def test_active_task_subscribe_cancelled_during_yield( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe cancellation while yielding (GeneratorExit).""" + msg = Message(message_id='m1') + + async def execute_mock(req, q): + await q.enqueue_event(msg) + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + it = active_task.subscribe() + async for event in it: + assert event == msg + # Cancel while we have the event (inside the loop) + await it.aclose() + break + + @pytest.mark.asyncio + async def test_active_task_cancel_when_already_closed( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test cancel when the event queue is already closed.""" + + async def execute_mock(req, q): + 
active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.return_value = Task(id='test') + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Forced queue close. + await active_task._event_queue_agent.close() + await active_task._event_queue_subscribers.close() + + # Now cancel the task itself. + await active_task.cancel(request_context) + # wait() was removed, no need to wait here. + + # Cancel again should not do anything. + await active_task.cancel(request_context) + # wait() was removed, no need to wait here. + + @pytest.mark.asyncio + async def test_active_task_subscribe_dequeue_failure( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when dequeue_event fails on the tapped queue.""" + + async def slow_execute(req, q): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + mock_tapped_queue = Mock(spec=EventQueue) + mock_tapped_queue.dequeue_event = AsyncMock( + side_effect=RuntimeError('Tapped queue crash') + ) + mock_tapped_queue.close = AsyncMock() + + with ( + patch.object( + active_task._event_queue_subscribers, + 'tap', + return_value=mock_tapped_queue, + ), + pytest.raises(RuntimeError, match='Tapped queue crash'), + ): + async for _ in active_task.subscribe(): + pass + + mock_tapped_queue.close.assert_called_once() + + @pytest.mark.asyncio + async def test_active_task_consumer_interrupted_multiple_times( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + task_manager: Mock, + ) -> None: + """Test consumer receiving multiple interrupting events.""" + task_obj = Task( + id='test-task-id', 
+ status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + + async def execute_mock(req, q): + await q.enqueue_event( + TaskStatusUpdateEvent( + task_id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_AUTH_REQUIRED), + ) + ) + await q.enqueue_event( + TaskStatusUpdateEvent( + task_id='test-task-id', + status=TaskStatus( + state=TaskState.TASK_STATE_INPUT_REQUIRED + ), + ) + ) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task.side_effect = [ + Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ] + [task_obj] * 10 + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [ + e async for e in active_task.subscribe(request=request_context) + ] + + result = events[0] if events else None + assert result.status.state == TaskState.TASK_STATE_AUTH_REQUIRED + + @pytest.mark.asyncio + async def test_active_task_subscribe_immediate_finish( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when the task finishes immediately.""" + + async def execute_mock(req, q): + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Wait for it to finish + await active_task._is_finished.wait() + + with pytest.raises( + InvalidParamsError, match=r'Task .* is already completed' + ): + async for _ in active_task.subscribe(): + pass + + @pytest.mark.asyncio + async def test_active_task_start_producer_immediate_error( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test start when producer fails immediately.""" + agent_executor.execute = AsyncMock( + side_effect=ValueError('Quick failure') + ) + + await 
active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Consumer should also finish + with pytest.raises(ValueError, match='Quick failure'): + async for _ in active_task.subscribe(): + pass + + @pytest.mark.asyncio + async def test_active_task_subscribe_finished_during_wait( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test subscribe when the task finishes while waiting for an event.""" + + async def slow_execute(req, q): + # Do nothing and just finish + await asyncio.sleep(0.5) + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + async def consume(): + async for _ in active_task.subscribe(): + pass + + task = asyncio.create_task(consume()) + await asyncio.sleep(0.2) + + # Task is still running, subscribe is waiting. + # Now it finishes. 
+ await asyncio.sleep(0.5) + await task # Should finish normally + + @pytest.mark.asyncio + async def test_active_task_maybe_cleanup_not_finished( + self, + agent_executor: Mock, + task_manager: Mock, + push_sender: Mock, + ) -> None: + """Test that cleanup is not called if task is not finished.""" + on_cleanup = Mock() + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=push_sender, + on_cleanup=on_cleanup, + ) + + # Explicitly call private _maybe_cleanup to verify it respects finished state + await active_task._maybe_cleanup() + on_cleanup.assert_not_called() + + @pytest.mark.asyncio + async def test_active_task_subscribe_exception_already_set( + self, active_task: ActiveTask + ) -> None: + """Test subscribe when exception is already set.""" + active_task._exception = ValueError('Pre-existing error') + with pytest.raises(ValueError, match='Pre-existing error'): + async for _ in active_task.subscribe(): + pass + + @pytest.mark.asyncio + async def test_active_task_subscribe_inner_exception( + self, + active_task: ActiveTask, + agent_executor: Mock, + request_context: Mock, + ) -> None: + """Test the generic exception block in subscribe.""" + + async def slow_execute(req, q): + await asyncio.sleep(10) + + agent_executor.execute = AsyncMock(side_effect=slow_execute) + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + mock_tapped_queue = Mock(spec=EventQueue) + # dequeue_event returns a task that fails + mock_tapped_queue.dequeue_event = AsyncMock( + side_effect=Exception('Inner error') + ) + mock_tapped_queue.close = AsyncMock() + + with ( + patch.object( + active_task._event_queue_subscribers, + 'tap', + return_value=mock_tapped_queue, + ), + pytest.raises(Exception, match='Inner error'), + ): + async for _ in active_task.subscribe(): + pass + + +@pytest.mark.asyncio +async def 
test_active_task_subscribe_include_initial_task(): + agent_executor = Mock() + task_manager = Mock() + request_context = Mock(spec=RequestContext) + + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=Mock(), + ) + + initial_task = Task( + id='test-task-id', status=TaskStatus(state=TaskState.TASK_STATE_WORKING) + ) + + async def execute_mock(req, q): + active_task._request_queue.shutdown(immediate=True) + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + task_manager.get_task = AsyncMock(return_value=initial_task) + task_manager.save_task_event = AsyncMock() + + await active_task.enqueue_request(request_context) + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + events = [e async for e in active_task.subscribe(include_initial_task=True)] + + # Verify that the first yielded event is the initial task + assert len(events) >= 1 + assert events[0] == initial_task + + +@pytest.mark.timeout(1) +@pytest.mark.asyncio +async def test_active_task_subscribe_request_parameter(): + agent_executor = Mock() + task_manager = Mock() + request_context = Mock(spec=RequestContext) + + active_task = ActiveTask( + agent_executor=agent_executor, + task_id='test-task-id', + task_manager=task_manager, + push_sender=Mock(), + ) + + async def execute_mock(req, q): + # We simulate the task finishing successfully, so it will emit _RequestCompleted + pass + + agent_executor.execute = AsyncMock(side_effect=execute_mock) + agent_executor.cancel = AsyncMock() + task_manager.get_task = AsyncMock( + return_value=Task( + id='test-task-id', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + task_manager.save_task_event = AsyncMock() + task_manager.process = AsyncMock(side_effect=lambda x: x) + + await active_task.start( + call_context=ServerCallContext(), create_task_if_missing=True + ) + + # Pass request_context directly to subscribe without 
enqueuing manually + events = [e async for e in active_task.subscribe(request=request_context)] + + # Should complete without error, and yield no events (just _RequestCompleted which is hidden) + assert len(events) == 0 + + await active_task.cancel(request_context) diff --git a/tests/server/agent_execution/test_context.py b/tests/server/agent_execution/test_context.py index 7ec612986..dce780f58 100644 --- a/tests/server/agent_execution/test_context.py +++ b/tests/server/agent_execution/test_context.py @@ -322,14 +322,8 @@ def test_init_with_context_id_and_existing_context_id_match( assert context.current_task == mock_task def test_extension_handling(self) -> None: - """Test extension handling in RequestContext.""" + """Test that requested_extensions is exposed via RequestContext.""" call_context = ServerCallContext(requested_extensions={'foo', 'bar'}) context = RequestContext(call_context=call_context) assert context.requested_extensions == {'foo', 'bar'} - - context.add_activated_extension('foo') - assert call_context.activated_extensions == {'foo'} - - context.add_activated_extension('baz') - assert call_context.activated_extensions == {'foo', 'baz'} diff --git a/tests/server/events/test_event_consumer.py b/tests/server/events/test_event_consumer.py index cfd315265..d7d20768b 100644 --- a/tests/server/events/test_event_consumer.py +++ b/tests/server/events/test_event_consumer.py @@ -49,11 +49,11 @@ def create_sample_task( @pytest.fixture def mock_event_queue(): - return AsyncMock(spec=EventQueue) + return AsyncMock(spec=EventQueueLegacy) @pytest.fixture -def event_consumer(mock_event_queue: EventQueue): +def event_consumer(mock_event_queue: EventQueueLegacy): return EventConsumer(queue=mock_event_queue) diff --git a/tests/server/events/test_inmemory_queue_manager.py b/tests/server/events/test_inmemory_queue_manager.py index b51334a95..9716b13bf 100644 --- a/tests/server/events/test_inmemory_queue_manager.py +++ b/tests/server/events/test_inmemory_queue_manager.py 
@@ -5,7 +5,7 @@ import pytest from a2a.server.events import InMemoryQueueManager -from a2a.server.events.event_queue import EventQueue +from a2a.server.events.event_queue import EventQueueLegacy from a2a.server.events.queue_manager import ( NoTaskQueue, TaskQueueExists, @@ -21,7 +21,7 @@ def queue_manager(self) -> InMemoryQueueManager: @pytest.fixture def event_queue(self) -> MagicMock: """Fixture to create a mock EventQueue.""" - queue = MagicMock(spec=EventQueue) + queue = MagicMock(spec=EventQueueLegacy) # Mock the tap method to return itself queue.tap.return_value = queue @@ -119,7 +119,7 @@ async def test_create_or_tap_new_queue( task_id = 'test_task_id' result = await queue_manager.create_or_tap(task_id) - assert isinstance(result, EventQueue) + assert isinstance(result, EventQueueLegacy) assert queue_manager._task_queue[task_id] == result @pytest.mark.asyncio @@ -142,7 +142,7 @@ async def test_concurrency( """Test concurrent access to the queue manager.""" async def add_task(task_id): - queue = EventQueue() + queue = EventQueueLegacy() await queue_manager.add(task_id, queue) return task_id diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index f4ba04996..0138045ae 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -14,7 +14,7 @@ import pytest -from a2a.auth.user import UnauthenticatedUser +from a2a.auth.user import UnauthenticatedUser, User from a2a.server.agent_execution import ( AgentExecutor, RequestContext, @@ -22,8 +22,15 @@ SimpleRequestContextBuilder, ) from a2a.server.context import ServerCallContext -from a2a.server.events import EventQueue, InMemoryQueueManager, QueueManager -from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.events import ( + EventQueue, + EventQueueLegacy, + InMemoryQueueManager, + QueueManager, +) +from 
a2a.server.request_handlers import ( + LegacyRequestHandler as DefaultRequestHandler, +) from a2a.server.tasks import ( InMemoryPushNotificationConfigStore, InMemoryTaskStore, @@ -34,6 +41,7 @@ TaskUpdater, ) from a2a.types import ( + ExtendedAgentCardNotConfiguredError, InternalError, InvalidParamsError, PushNotificationNotSupportedError, @@ -42,10 +50,13 @@ UnsupportedOperationError, ) from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, Artifact, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, GetTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, ListTasksRequest, @@ -62,7 +73,10 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import new_agent_text_message, new_task +from a2a.helpers.proto_helpers import ( + new_text_message, + new_task_from_user_message, +) class MockAgentExecutor(AgentExecutor): @@ -111,13 +125,25 @@ def create_server_call_context() -> ServerCallContext: return ServerCallContext(user=UnauthenticatedUser()) -def test_init_default_dependencies(): +@pytest.fixture +def agent_card(): + """Provides a standard AgentCard with streaming and push notifications enabled for tests.""" + return AgentCard( + name='test_agent', + version='1.0', + capabilities=AgentCapabilities(streaming=True, push_notifications=True), + ) + + +def test_init_default_dependencies(agent_card): """Test that default dependencies are created if not provided.""" agent_executor = MockAgentExecutor() task_store = InMemoryTaskStore() handler = DefaultRequestHandler( - agent_executor=agent_executor, task_store=task_store + agent_executor=agent_executor, + task_store=task_store, + agent_card=agent_card, ) assert isinstance(handler._queue_manager, InMemoryQueueManager) @@ -134,13 +160,15 @@ def test_init_default_dependencies(): @pytest.mark.asyncio -async def test_on_get_task_not_found(): +async def test_on_get_task_not_found(agent_card): """Test on_get_task when task_store.get returns 
None.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = GetTaskRequest(id='non_existent_task') @@ -153,7 +181,7 @@ async def test_on_get_task_not_found(): @pytest.mark.asyncio -async def test_on_list_tasks_success(): +async def test_on_list_tasks_success(agent_card): """Test on_list_tasks successfully returns a page of tasks .""" mock_task_store = AsyncMock(spec=TaskStore) task2 = create_sample_task(task_id='task2') @@ -175,7 +203,9 @@ async def test_on_list_tasks_success(): ) mock_task_store.list.return_value = mock_page request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(include_artifacts=True, page_size=10) context = create_server_call_context() @@ -188,7 +218,7 @@ async def test_on_list_tasks_success(): @pytest.mark.asyncio -async def test_on_list_tasks_excludes_artifacts(): +async def test_on_list_tasks_excludes_artifacts(agent_card): """Test on_list_tasks excludes artifacts from returned tasks.""" mock_task_store = AsyncMock(spec=TaskStore) task2 = create_sample_task(task_id='task2') @@ -210,7 +240,9 @@ async def test_on_list_tasks_excludes_artifacts(): ) mock_task_store.list.return_value = mock_page request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(include_artifacts=False, page_size=10) context = create_server_call_context() @@ -221,12 +253,12 @@ async def test_on_list_tasks_excludes_artifacts(): @pytest.mark.asyncio -async def 
test_on_list_tasks_applies_history_length(): +async def test_on_list_tasks_applies_history_length(agent_card): """Test on_list_tasks applies history length filter.""" mock_task_store = AsyncMock(spec=TaskStore) history = [ - new_agent_text_message('Hello 1!'), - new_agent_text_message('Hello 2!'), + new_text_message('Hello 1!'), + new_text_message('Hello 2!'), ] task2 = create_sample_task(task_id='task2') task2.history.extend(history) @@ -239,7 +271,9 @@ async def test_on_list_tasks_applies_history_length(): ) mock_task_store.list.return_value = mock_page request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(history_length=1, page_size=10) context = create_server_call_context() @@ -250,11 +284,13 @@ async def test_on_list_tasks_applies_history_length(): @pytest.mark.asyncio -async def test_on_list_tasks_negative_history_length_error(): +async def test_on_list_tasks_negative_history_length_error(agent_card): """Test on_list_tasks raises error for negative history length.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(history_length=-1, page_size=10) context = create_server_call_context() @@ -272,7 +308,9 @@ async def test_on_cancel_task_task_not_found(): mock_task_store.get.return_value = None request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = CancelTaskRequest(id='task_not_found_for_cancel') @@ -286,7 +324,7 @@ async def test_on_cancel_task_task_not_found(): 
@pytest.mark.asyncio -async def test_on_cancel_task_queue_tap_returns_none(): +async def test_on_cancel_task_queue_tap_returns_none(agent_card): """Test on_cancel_task when queue_manager.tap returns None.""" mock_task_store = AsyncMock(spec=TaskStore) sample_task = create_sample_task(task_id='tap_none_task') @@ -314,6 +352,7 @@ async def test_on_cancel_task_queue_tap_returns_none(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) context = create_server_call_context() @@ -341,7 +380,7 @@ async def test_on_cancel_task_queue_tap_returns_none(): @pytest.mark.asyncio -async def test_on_cancel_task_cancels_running_agent(): +async def test_on_cancel_task_cancels_running_agent(agent_card): """Test on_cancel_task cancels a running agent task.""" task_id = 'running_agent_task_to_cancel' sample_task = create_sample_task(task_id=task_id) @@ -349,7 +388,7 @@ async def test_on_cancel_task_cancels_running_agent(): mock_task_store.get.return_value = sample_task mock_queue_manager = AsyncMock(spec=QueueManager) - mock_event_queue = AsyncMock(spec=EventQueue) + mock_event_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.tap.return_value = mock_event_queue mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -366,6 +405,7 @@ async def test_on_cancel_task_cancels_running_agent(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) # Simulate a running agent task @@ -385,7 +425,7 @@ async def test_on_cancel_task_cancels_running_agent(): @pytest.mark.asyncio -async def test_on_cancel_task_completes_during_cancellation(): +async def test_on_cancel_task_completes_during_cancellation(agent_card): """Test on_cancel_task fails to cancel a task due to concurrent task completion.""" task_id = 'running_agent_task_to_cancel' sample_task = create_sample_task(task_id=task_id) @@ -393,7 +433,7 @@ async def 
test_on_cancel_task_completes_during_cancellation(): mock_task_store.get.return_value = sample_task mock_queue_manager = AsyncMock(spec=QueueManager) - mock_event_queue = AsyncMock(spec=EventQueue) + mock_event_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.tap.return_value = mock_event_queue mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -410,6 +450,7 @@ async def test_on_cancel_task_completes_during_cancellation(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) # Simulate a running agent task @@ -431,7 +472,7 @@ async def test_on_cancel_task_completes_during_cancellation(): @pytest.mark.asyncio -async def test_on_cancel_task_invalid_result_type(): +async def test_on_cancel_task_invalid_result_type(agent_card): """Test on_cancel_task when result_aggregator returns a Message instead of a Task.""" task_id = 'cancel_invalid_result_task' sample_task = create_sample_task(task_id=task_id) @@ -439,7 +480,7 @@ async def test_on_cancel_task_invalid_result_type(): mock_task_store.get.return_value = sample_task mock_queue_manager = AsyncMock(spec=QueueManager) - mock_event_queue = AsyncMock(spec=EventQueue) + mock_event_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.tap.return_value = mock_event_queue mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -456,6 +497,7 @@ async def test_on_cancel_task_invalid_result_type(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) with patch( @@ -475,7 +517,7 @@ async def test_on_cancel_task_invalid_result_type(): @pytest.mark.asyncio -async def test_on_message_send_with_push_notification(): +async def test_on_message_send_with_push_notification(agent_card): """Test on_message_send sets push notification info if provided.""" mock_task_store = AsyncMock(spec=TaskStore) mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore) @@ 
-511,6 +553,7 @@ async def test_on_message_send_with_push_notification(): task_store=mock_task_store, push_config_store=mock_push_notification_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) push_config = TaskPushNotificationConfig(url='http://callback.com/push') @@ -576,7 +619,9 @@ async def mock_current_result(): @pytest.mark.asyncio -async def test_on_message_send_with_push_notification_in_non_blocking_request(): +async def test_on_message_send_with_push_notification_in_non_blocking_request( + agent_card, +): """Test that push notification callback is called during background event processing for non-blocking requests.""" mock_task_store = AsyncMock(spec=TaskStore) mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore) @@ -615,6 +660,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request(): push_config_store=mock_push_notification_store, request_context_builder=mock_request_context_builder, push_sender=mock_push_sender, + agent_card=agent_card, ) # Configure push notification @@ -715,7 +761,9 @@ async def mock_consume_and_break_on_interrupt( @pytest.mark.asyncio -async def test_on_message_send_with_push_notification_no_existing_Task(): +async def test_on_message_send_with_push_notification_no_existing_Task( + agent_card, +): """Test on_message_send for new task sets push notification info if provided.""" mock_task_store = AsyncMock(spec=TaskStore) mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore) @@ -740,6 +788,7 @@ async def test_on_message_send_with_push_notification_no_existing_Task(): task_store=mock_task_store, push_config_store=mock_push_notification_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) push_config = TaskPushNotificationConfig(url='http://callback.com/push') @@ -799,8 +848,8 @@ async def mock_current_result(): @pytest.mark.asyncio -async def test_on_message_send_no_result_from_aggregator(): - 
"""Test on_message_send when aggregator returns (None, False).""" +async def test_on_message_send_no_result_from_aggregator(agent_card): + """Test on_message_send when aggregator returns (None, False). Completes unsuccessfully and raises InternalError.""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) @@ -815,6 +864,7 @@ async def test_on_message_send_no_result_from_aggregator(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -848,7 +898,8 @@ async def test_on_message_send_no_result_from_aggregator(): @pytest.mark.asyncio -async def test_on_message_send_task_id_mismatch(): +async def test_on_message_send_task_id_mismatch(agent_card): + """Test on_message_send returns InternalError if aggregator returns mismatched Task ID.""" """Test on_message_send when result task ID doesn't match request context task ID.""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -866,6 +917,7 @@ async def test_on_message_send_task_id_mismatch(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -908,7 +960,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): assert context.message is not None, ( 'A message is required to create a new task' ) - task = new_task(context.message) # type: ignore + task = new_task_from_user_message(context.message) # type: ignore await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) @@ -933,7 +985,7 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue): @pytest.mark.asyncio -async def 
test_on_message_send_non_blocking(): +async def test_on_message_send_non_blocking(agent_card): task_store = InMemoryTaskStore() push_store = InMemoryPushNotificationConfigStore() @@ -941,6 +993,7 @@ async def test_on_message_send_non_blocking(): agent_executor=HelloAgentExecutor(), task_store=task_store, push_config_store=push_store, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -979,7 +1032,7 @@ async def test_on_message_send_non_blocking(): @pytest.mark.asyncio -async def test_on_message_send_limit_history(): +async def test_on_message_send_limit_history(agent_card): task_store = InMemoryTaskStore() push_store = InMemoryPushNotificationConfigStore() @@ -987,6 +1040,7 @@ async def test_on_message_send_limit_history(): agent_executor=HelloAgentExecutor(), task_store=task_store, push_config_store=push_store, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -1016,7 +1070,7 @@ async def test_on_message_send_limit_history(): @pytest.mark.asyncio -async def test_on_get_task_limit_history(): +async def test_on_get_task_limit_history(agent_card): task_store = InMemoryTaskStore() push_store = InMemoryPushNotificationConfigStore() @@ -1024,6 +1078,7 @@ async def test_on_get_task_limit_history(): agent_executor=HelloAgentExecutor(), task_store=task_store, push_config_store=push_store, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -1056,7 +1111,7 @@ async def test_on_get_task_limit_history(): @pytest.mark.asyncio -async def test_on_message_send_interrupted_flow(): +async def test_on_message_send_interrupted_flow(agent_card): """Test on_message_send when flow is interrupted (e.g., auth_required).""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -1072,6 +1127,7 @@ async def test_on_message_send_interrupted_flow(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + 
agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -1137,7 +1193,7 @@ def capture_create_task(coro): @pytest.mark.asyncio -async def test_on_message_send_stream_with_push_notification(): +async def test_on_message_send_stream_with_push_notification(agent_card): """Test on_message_send_stream sets and uses push notification info.""" mock_task_store = AsyncMock(spec=TaskStore) mock_push_config_store = AsyncMock(spec=PushNotificationConfigStore) @@ -1175,6 +1231,7 @@ async def test_on_message_send_stream_with_push_notification(): push_config_store=mock_push_config_store, push_sender=mock_push_sender, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) push_config = TaskPushNotificationConfig( @@ -1284,7 +1341,9 @@ async def to_coro(val): @pytest.mark.asyncio -async def test_stream_disconnect_then_resubscribe_receives_future_events(): +async def test_stream_disconnect_then_resubscribe_receives_future_events( + agent_card, +): """Start streaming, disconnect, then resubscribe and ensure subsequent events are streamed.""" # Arrange mock_task_store = AsyncMock(spec=TaskStore) @@ -1308,6 +1367,7 @@ async def test_stream_disconnect_then_resubscribe_receives_future_events(): agent_executor=mock_agent_executor, task_store=mock_task_store, queue_manager=queue_manager, + agent_card=agent_card, ) params = SendMessageRequest( @@ -1375,7 +1435,9 @@ async def exec_side_effect(_request, queue: EventQueue): @pytest.mark.asyncio -async def test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues(): +async def test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues( + agent_card, +): """Simulate client disconnect: stream stops early, cleanup is scheduled in background, producer keeps running, and cleanup completes after producer finishes.""" # Arrange @@ -1398,7 +1460,7 @@ async def test_on_message_send_stream_client_disconnect_triggers_background_clea 
mock_request_context_builder.build.return_value = mock_request_context # Queue used by _run_event_stream; must support close() - mock_queue = AsyncMock(spec=EventQueue) + mock_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.create_or_tap.return_value = mock_queue request_handler = DefaultRequestHandler( @@ -1406,6 +1468,7 @@ async def test_on_message_send_stream_client_disconnect_triggers_background_clea task_store=mock_task_store, queue_manager=mock_queue_manager, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( @@ -1514,7 +1577,7 @@ def create_task_spy(coro): @pytest.mark.asyncio -async def test_disconnect_persists_final_task_to_store(): +async def test_disconnect_persists_final_task_to_store(agent_card): """After client disconnect, ensure background consumer persists final Task to store.""" task_store = InMemoryTaskStore() queue_manager = InMemoryQueueManager() @@ -1527,7 +1590,6 @@ def __init__(self): async def execute( self, context: RequestContext, event_queue: EventQueue ): - updater = TaskUpdater( event_queue, cast('str', context.task_id), @@ -1545,7 +1607,10 @@ async def cancel( agent = FinishingAgent() handler = DefaultRequestHandler( - agent_executor=agent, task_store=task_store, queue_manager=queue_manager + agent_executor=agent, + task_store=task_store, + queue_manager=queue_manager, + agent_card=agent_card, ) params = SendMessageRequest( @@ -1604,7 +1669,7 @@ async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0): @pytest.mark.asyncio -async def test_background_cleanup_task_is_tracked_and_cleared(): +async def test_background_cleanup_task_is_tracked_and_cleared(agent_card): """Ensure background cleanup task is tracked while pending and removed when done.""" # Arrange mock_task_store = AsyncMock(spec=TaskStore) @@ -1625,7 +1690,7 @@ async def test_background_cleanup_task_is_tracked_and_cleared(): mock_request_context.context_id = context_id 
mock_request_context_builder.build.return_value = mock_request_context - mock_queue = AsyncMock(spec=EventQueue) + mock_queue = AsyncMock(spec=EventQueueLegacy) mock_queue_manager.create_or_tap.return_value = mock_queue request_handler = DefaultRequestHandler( @@ -1633,6 +1698,7 @@ async def test_background_cleanup_task_is_tracked_and_cleared(): task_store=mock_task_store, queue_manager=mock_queue_manager, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( @@ -1722,7 +1788,7 @@ def create_task_spy(coro): @pytest.mark.asyncio -async def test_on_message_send_stream_task_id_mismatch(): +async def test_on_message_send_stream_task_id_mismatch(agent_card): """Test on_message_send_stream raises error if yielded task ID mismatches.""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock( @@ -1741,6 +1807,7 @@ async def test_on_message_send_stream_task_id_mismatch(): agent_executor=mock_agent_executor, task_store=mock_task_store, request_context_builder=mock_request_context_builder, + agent_card=agent_card, ) params = SendMessageRequest( message=Message( @@ -1782,7 +1849,7 @@ async def event_stream_gen_mismatch(): @pytest.mark.asyncio -async def test_cleanup_producer_task_id_not_in_running_agents(): +async def test_cleanup_producer_task_id_not_in_running_agents(agent_card): """Test _cleanup_producer when task_id is not in _running_agents (e.g., already cleaned up).""" mock_task_store = AsyncMock(spec=TaskStore) mock_queue_manager = AsyncMock(spec=QueueManager) @@ -1790,6 +1857,7 @@ async def test_cleanup_producer_task_id_not_in_running_agents(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) task_id = 'task_already_cleaned' @@ -1819,12 +1887,13 @@ async def noop_coro_for_task(): @pytest.mark.asyncio -async def test_set_task_push_notification_config_no_notifier(): +async def 
test_set_task_push_notification_config_no_notifier(agent_card): """Test on_create_task_push_notification_config when _push_config_store is None.""" request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), - push_config_store=None, # Explicitly None + push_config_store=None, # Explicitly None, + agent_card=agent_card, ) params = TaskPushNotificationConfig( task_id='task1', @@ -1838,7 +1907,7 @@ async def test_set_task_push_notification_config_no_notifier(): @pytest.mark.asyncio -async def test_set_task_push_notification_config_task_not_found(): +async def test_set_task_push_notification_config_task_not_found(agent_card): """Test on_create_task_push_notification_config when task is not found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found @@ -1850,6 +1919,7 @@ async def test_set_task_push_notification_config_task_not_found(): task_store=mock_task_store, push_config_store=mock_push_store, push_sender=mock_push_sender, + agent_card=agent_card, ) params = TaskPushNotificationConfig( task_id='non_existent_task', @@ -1866,12 +1936,13 @@ async def test_set_task_push_notification_config_task_not_found(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_no_store(): +async def test_get_task_push_notification_config_no_store(agent_card): """Test on_get_task_push_notification_config when _push_config_store is None.""" request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), - push_config_store=None, # Explicitly None + push_config_store=None, # Explicitly None, + agent_card=agent_card, ) params = GetTaskPushNotificationConfigRequest( task_id='task1', @@ -1885,7 +1956,7 @@ async def test_get_task_push_notification_config_no_store(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_task_not_found(): +async def 
test_get_task_push_notification_config_task_not_found(agent_card): """Test on_get_task_push_notification_config when task is not found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found @@ -1895,6 +1966,7 @@ async def test_get_task_push_notification_config_task_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=agent_card, ) params = GetTaskPushNotificationConfigRequest( task_id='non_existent_task', id='task_push_notification_config' @@ -1910,7 +1982,7 @@ async def test_get_task_push_notification_config_task_not_found(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_info_not_found(): +async def test_get_task_push_notification_config_info_not_found(agent_card): """Test on_get_task_push_notification_config when push_config_store.get_info returns None.""" mock_task_store = AsyncMock(spec=TaskStore) @@ -1924,13 +1996,14 @@ async def test_get_task_push_notification_config_info_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=agent_card, ) params = GetTaskPushNotificationConfigRequest( task_id='non_existent_task', id='task_push_notification_config' ) context = create_server_call_context() - with pytest.raises(InternalError): + with pytest.raises(TaskNotFoundError): await request_handler.on_get_task_push_notification_config( params, context ) @@ -1941,7 +2014,7 @@ async def test_get_task_push_notification_config_info_not_found(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_info_with_config(): +async def test_get_task_push_notification_config_info_with_config(agent_card): """Test on_get_task_push_notification_config with valid push config id""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') @@ -1952,6 +2025,7 @@ async def 
test_get_task_push_notification_config_info_with_config(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) set_config_params = TaskPushNotificationConfig( @@ -1979,7 +2053,9 @@ async def test_get_task_push_notification_config_info_with_config(): @pytest.mark.asyncio -async def test_get_task_push_notification_config_info_with_config_no_id(): +async def test_get_task_push_notification_config_info_with_config_no_id( + agent_card, +): """Test on_get_task_push_notification_config with no push config id""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') @@ -1990,6 +2066,7 @@ async def test_get_task_push_notification_config_info_with_config_no_id(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) set_config_params = TaskPushNotificationConfig( @@ -2015,13 +2092,15 @@ async def test_get_task_push_notification_config_info_with_config_no_id(): @pytest.mark.asyncio -async def test_on_subscribe_to_task_task_not_found(): +async def test_on_subscribe_to_task_task_not_found(agent_card): """Test on_subscribe_to_task when the task is not found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SubscribeToTaskRequest(id='resub_task_not_found') @@ -2036,7 +2115,7 @@ async def test_on_subscribe_to_task_task_not_found(): @pytest.mark.asyncio -async def test_on_subscribe_to_task_queue_not_found(): +async def test_on_subscribe_to_task_queue_not_found(agent_card): """Test on_subscribe_to_task when the queue is not found by queue_manager.tap.""" mock_task_store = AsyncMock(spec=TaskStore) sample_task = 
create_sample_task(task_id='resub_queue_not_found') @@ -2049,6 +2128,7 @@ async def test_on_subscribe_to_task_queue_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, queue_manager=mock_queue_manager, + agent_card=agent_card, ) params = SubscribeToTaskRequest(id='resub_queue_not_found') @@ -2063,9 +2143,11 @@ async def test_on_subscribe_to_task_queue_not_found(): @pytest.mark.asyncio -async def test_on_message_send_stream(): +async def test_on_message_send_stream(agent_card): request_handler = DefaultRequestHandler( - MockAgentExecutor(), InMemoryTaskStore() + MockAgentExecutor(), + InMemoryTaskStore(), + agent_card=agent_card, ) message_params = SendMessageRequest( message=Message( @@ -2100,12 +2182,13 @@ async def consume_stream(): @pytest.mark.asyncio -async def test_list_task_push_notification_config_no_store(): +async def test_list_task_push_notification_config_no_store(agent_card): """Test on_list_task_push_notification_configs when _push_config_store is None.""" request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), - push_config_store=None, # Explicitly None + push_config_store=None, # Explicitly None, + agent_card=agent_card, ) params = ListTaskPushNotificationConfigsRequest(task_id='task1') @@ -2116,7 +2199,7 @@ async def test_list_task_push_notification_config_no_store(): @pytest.mark.asyncio -async def test_list_task_push_notification_config_task_not_found(): +async def test_list_task_push_notification_config_task_not_found(agent_card): """Test on_list_task_push_notification_configs when task is not found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found @@ -2126,6 +2209,7 @@ async def test_list_task_push_notification_config_task_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=agent_card, ) params = 
ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') @@ -2139,7 +2223,7 @@ async def test_list_task_push_notification_config_task_not_found(): @pytest.mark.asyncio -async def test_list_no_task_push_notification_config_info(): +async def test_list_no_task_push_notification_config_info(agent_card): """Test on_get_task_push_notification_config when push_config_store.get_info returns []""" mock_task_store = AsyncMock(spec=TaskStore) @@ -2152,6 +2236,7 @@ async def test_list_no_task_push_notification_config_info(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) params = ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') @@ -2162,7 +2247,7 @@ async def test_list_no_task_push_notification_config_info(): @pytest.mark.asyncio -async def test_list_task_push_notification_config_info_with_config(): +async def test_list_task_push_notification_config_info_with_config(agent_card): """Test on_list_task_push_notification_configs with push config+id""" mock_task_store = AsyncMock(spec=TaskStore) @@ -2185,6 +2270,7 @@ async def test_list_task_push_notification_config_info_with_config(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) params = ListTaskPushNotificationConfigsRequest(task_id='task_1') @@ -2200,7 +2286,9 @@ async def test_list_task_push_notification_config_info_with_config(): @pytest.mark.asyncio -async def test_list_task_push_notification_config_info_with_config_and_no_id(): +async def test_list_task_push_notification_config_info_with_config_and_no_id( + agent_card, +): """Test on_list_task_push_notification_configs with no push config id""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') @@ -2211,6 +2299,7 @@ async def test_list_task_push_notification_config_info_with_config_and_no_id(): agent_executor=MockAgentExecutor(), 
task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) # multiple calls without config id should replace the existing @@ -2243,12 +2332,13 @@ async def test_list_task_push_notification_config_info_with_config_and_no_id(): @pytest.mark.asyncio -async def test_delete_task_push_notification_config_no_store(): +async def test_delete_task_push_notification_config_no_store(agent_card): """Test on_delete_task_push_notification_config when _push_config_store is None.""" request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), task_store=AsyncMock(spec=TaskStore), - push_config_store=None, # Explicitly None + push_config_store=None, # Explicitly None, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='task1', id='config1' @@ -2261,7 +2351,7 @@ async def test_delete_task_push_notification_config_no_store(): @pytest.mark.asyncio -async def test_delete_task_push_notification_config_task_not_found(): +async def test_delete_task_push_notification_config_task_not_found(agent_card): """Test on_delete_task_push_notification_config when task is not found.""" mock_task_store = AsyncMock(spec=TaskStore) mock_task_store.get.return_value = None # Task not found @@ -2271,6 +2361,7 @@ async def test_delete_task_push_notification_config_task_not_found(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=mock_push_store, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='non_existent_task', id='config1' @@ -2287,7 +2378,7 @@ async def test_delete_task_push_notification_config_task_not_found(): @pytest.mark.asyncio -async def test_delete_no_task_push_notification_config_info(): +async def test_delete_no_task_push_notification_config_info(agent_card): """Test on_delete_task_push_notification_config without config info""" mock_task_store = AsyncMock(spec=TaskStore) @@ -2305,6 +2396,7 @@ async def 
test_delete_no_task_push_notification_config_info(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='task1', id='config_non_existant' @@ -2326,7 +2418,9 @@ async def test_delete_no_task_push_notification_config_info(): @pytest.mark.asyncio -async def test_delete_task_push_notification_config_info_with_config(): +async def test_delete_task_push_notification_config_info_with_config( + agent_card, +): """Test on_list_task_push_notification_configs with push config+id""" mock_task_store = AsyncMock(spec=TaskStore) @@ -2350,6 +2444,7 @@ async def test_delete_task_push_notification_config_info_with_config(): agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='task_1', id='config_1' @@ -2372,7 +2467,9 @@ async def test_delete_task_push_notification_config_info_with_config(): @pytest.mark.asyncio -async def test_delete_task_push_notification_config_info_with_config_and_no_id(): +async def test_delete_task_push_notification_config_info_with_config_and_no_id( + agent_card, +): """Test on_list_task_push_notification_configs with no push config id""" mock_task_store = AsyncMock(spec=TaskStore) @@ -2391,6 +2488,7 @@ async def test_delete_task_push_notification_config_info_with_config_and_no_id() agent_executor=MockAgentExecutor(), task_store=mock_task_store, push_config_store=push_store, + agent_card=agent_card, ) params = DeleteTaskPushNotificationConfigRequest( task_id='task_1', id='task_1' @@ -2420,7 +2518,9 @@ async def test_delete_task_push_notification_config_info_with_config_and_no_id() @pytest.mark.asyncio @pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) -async def test_on_message_send_task_in_terminal_state(terminal_state): +async def test_on_message_send_task_in_terminal_state( + terminal_state, 
agent_card +): """Test on_message_send when task is already in a terminal state.""" state_name = TaskState.Name(terminal_state) task_id = f'terminal_task_{state_name}' @@ -2434,7 +2534,9 @@ async def test_on_message_send_task_in_terminal_state(terminal_state): # So we should patch that instead. request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2464,7 +2566,9 @@ async def test_on_message_send_task_in_terminal_state(terminal_state): @pytest.mark.asyncio @pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) -async def test_on_message_send_stream_task_in_terminal_state(terminal_state): +async def test_on_message_send_stream_task_in_terminal_state( + terminal_state, agent_card +): """Test on_message_send_stream when task is already in a terminal state.""" state_name = TaskState.Name(terminal_state) task_id = f'terminal_stream_task_{state_name}' @@ -2475,7 +2579,9 @@ async def test_on_message_send_stream_task_in_terminal_state(terminal_state): mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2505,7 +2611,9 @@ async def test_on_message_send_stream_task_in_terminal_state(terminal_state): @pytest.mark.asyncio @pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) -async def test_on_subscribe_to_task_in_terminal_state(terminal_state): +async def test_on_subscribe_to_task_in_terminal_state( + terminal_state, agent_card +): """Test on_subscribe_to_task when task is in a terminal state.""" state_name = TaskState.Name(terminal_state) task_id = f'resub_terminal_task_{state_name}' @@ -2520,6 +2628,7 @@ async def 
test_on_subscribe_to_task_in_terminal_state(terminal_state): agent_executor=MockAgentExecutor(), task_store=mock_task_store, queue_manager=AsyncMock(spec=QueueManager), + agent_card=agent_card, ) params = SubscribeToTaskRequest(id=f'{task_id}') @@ -2537,13 +2646,15 @@ async def test_on_subscribe_to_task_in_terminal_state(terminal_state): @pytest.mark.asyncio -async def test_on_message_send_task_id_provided_but_task_not_found(): +async def test_on_message_send_task_id_provided_but_task_not_found(agent_card): """Test on_message_send when task_id is provided but task doesn't exist.""" task_id = 'nonexistent_task' mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2573,13 +2684,17 @@ async def test_on_message_send_task_id_provided_but_task_not_found(): @pytest.mark.asyncio -async def test_on_message_send_stream_task_id_provided_but_task_not_found(): +async def test_on_message_send_stream_task_id_provided_but_task_not_found( + agent_card, +): """Test on_message_send_stream when task_id is provided but task doesn't exist.""" task_id = 'nonexistent_stream_task' mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=MockAgentExecutor(), task_store=mock_task_store + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2637,14 +2752,16 @@ async def cancel( # we should reconsider the approach. 
@pytest.mark.asyncio @pytest.mark.timeout(1) -async def test_on_message_send_error_does_not_hang(): +async def test_on_message_send_error_does_not_hang(agent_card): """Test that if the consumer raises an exception during blocking wait, the producer is cancelled and no deadlock occurs.""" agent = HelloWorldAgentExecutor() task_store = AsyncMock(spec=TaskStore) task_store.save.side_effect = RuntimeError('This is an Error!') request_handler = DefaultRequestHandler( - agent_executor=agent, task_store=task_store + agent_executor=agent, + task_store=task_store, + agent_card=agent_card, ) params = SendMessageRequest( @@ -2662,11 +2779,13 @@ async def test_on_message_send_error_does_not_hang(): @pytest.mark.asyncio -async def test_on_get_task_negative_history_length_error(): +async def test_on_get_task_negative_history_length_error(agent_card): """Test on_get_task raises error for negative history length.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) # GetTaskRequest also has history_length params = GetTaskRequest(id='task1', history_length=-1) @@ -2679,11 +2798,13 @@ async def test_on_get_task_negative_history_length_error(): @pytest.mark.asyncio -async def test_on_list_tasks_page_size_too_small(): +async def test_on_list_tasks_page_size_too_small(agent_card): """Test on_list_tasks raises error for page_size < 1.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(page_size=0) context = create_server_call_context() @@ -2695,11 +2816,13 @@ async def test_on_list_tasks_page_size_too_small(): @pytest.mark.asyncio -async 
def test_on_list_tasks_page_size_too_large(): +async def test_on_list_tasks_page_size_too_large(agent_card): """Test on_list_tasks raises error for page_size > 100.""" mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( - agent_executor=AsyncMock(spec=AgentExecutor), task_store=mock_task_store + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=agent_card, ) params = ListTasksRequest(page_size=101) context = create_server_call_context() @@ -2711,12 +2834,14 @@ async def test_on_list_tasks_page_size_too_large(): @pytest.mark.asyncio -async def test_on_message_send_negative_history_length_error(): +async def test_on_message_send_negative_history_length_error(agent_card): """Test on_message_send raises error for negative history length in configuration.""" mock_task_store = AsyncMock(spec=TaskStore) mock_agent_executor = AsyncMock(spec=AgentExecutor) request_handler = DefaultRequestHandler( - agent_executor=mock_agent_executor, task_store=mock_task_store + agent_executor=mock_agent_executor, + task_store=mock_task_store, + agent_card=agent_card, ) message_config = SendMessageConfiguration( @@ -2735,3 +2860,287 @@ async def test_on_message_send_negative_history_length_error(): await request_handler.on_message_send(params, context) assert 'history length must be non-negative' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_get_extended_agent_card_success(agent_card): + """Test on_get_extended_agent_card when extended_agent_card is supported.""" + agent_card.capabilities.extended_agent_card = True + + extended_agent_card = AgentCard( + name='Extended Agent', + description='An extended agent', + version='1.0.0', + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ), + ) + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + 
agent_card=agent_card, + extended_agent_card=extended_agent_card, + ) + + params = GetExtendedAgentCardRequest() + context = create_server_call_context() + + result = await request_handler.on_get_extended_agent_card(params, context) + + assert result == extended_agent_card + + +@pytest.mark.asyncio +async def test_on_message_send_stream_unsupported(agent_card): + """Test on_message_send_stream when streaming is unsupported.""" + agent_card.capabilities.streaming = False + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + agent_card=agent_card, + ) + + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg-unsupported', + parts=[Part(text='hi')], + ) + ) + + context = create_server_call_context() + + with pytest.raises(UnsupportedOperationError): + async for _ in request_handler.on_message_send_stream(params, context): + pass + + +@pytest.mark.asyncio +async def test_on_get_extended_agent_card_unsupported(agent_card): + """Test on_get_extended_agent_card when extended_agent_card is unsupported.""" + agent_card.capabilities.extended_agent_card = False + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + agent_card=agent_card, + ) + + params = GetExtendedAgentCardRequest() + context = create_server_call_context() + + with pytest.raises(UnsupportedOperationError): + await request_handler.on_get_extended_agent_card(params, context) + + +@pytest.mark.asyncio +async def test_on_create_task_push_notification_config_unsupported(agent_card): + """Test on_create_task_push_notification_config when push_notifications is unsupported.""" + agent_card.capabilities.push_notifications = False + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + agent_card=agent_card, + ) + + params = 
TaskPushNotificationConfig(url='http://callback.com/push') + + context = create_server_call_context() + + with pytest.raises(PushNotificationNotSupportedError): + await request_handler.on_create_task_push_notification_config( + params, context + ) + + +@pytest.mark.asyncio +async def test_on_subscribe_to_task_unsupported(agent_card): + """Test on_subscribe_to_task when streaming is unsupported.""" + agent_card.capabilities.streaming = False + + request_handler = DefaultRequestHandler( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=AsyncMock(spec=TaskStore), + agent_card=agent_card, + ) + + params = SubscribeToTaskRequest(id='some_task') + context = create_server_call_context() + + with pytest.raises(UnsupportedOperationError): + # We need to exhaust the generator to trigger the decorator evaluation + async for _ in request_handler.on_subscribe_to_task(params, context): + pass + + +class _NamedUser(User): + """Minimal authenticated test user identified by ``user_name``.""" + + def __init__(self, user_name: str) -> None: + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def _ctx(user_name: str) -> ServerCallContext: + return ServerCallContext(user=_NamedUser(user_name)) + + +@pytest.mark.asyncio +async def test_on_list_task_push_notification_configs_is_owner_scoped( + agent_card, +): + """Bob must not see Alice's configs via tasks/pushNotificationConfig/list. + + Both users have access to the shared task (the mocked TaskStore + returns it for any caller), but listing must only return the + caller's own configs. 
+ """ + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = create_sample_task(task_id='shared-task') + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + bob_ctx = _ctx('bob') + + alice_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ) + bob_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='bob-cfg', + url='http://bob.example.com/cb', + token='bob-secret', + ) + await push_store.set_info('shared-task', alice_cfg, alice_ctx) + await push_store.set_info('shared-task', bob_cfg, bob_ctx) + + request_handler = DefaultRequestHandler( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=agent_card, + ) + + alice_listing = ( + await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + alice_ctx, + ) + ) + assert {c.id for c in alice_listing.configs} == {'alice-cfg'} + # Sanity: Bob's secret is not in the response. 
+ assert all(c.token != 'bob-secret' for c in alice_listing.configs), ( + 'Listing for Alice must not expose Bob-owned tokens' + ) + + bob_listing = await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + bob_ctx, + ) + assert {c.id for c in bob_listing.configs} == {'bob-cfg'} + assert all(c.token != 'alice-secret' for c in bob_listing.configs), ( + 'Listing for Bob must not expose Alice-owned tokens' + ) + + +@pytest.mark.asyncio +async def test_on_list_task_push_notification_configs_returns_empty_for_third_user( + agent_card, +): + """A third user with task access but no registered configs sees an empty list.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = create_sample_task(task_id='shared-task') + + push_store = InMemoryPushNotificationConfigStore() + await push_store.set_info( + 'shared-task', + TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + ), + _ctx('alice'), + ) + + request_handler = DefaultRequestHandler( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=agent_card, + ) + + carol_listing = ( + await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + _ctx('carol'), + ) + ) + assert carol_listing.configs == [] + + +@pytest.mark.asyncio +async def test_on_get_task_push_notification_config_is_owner_scoped( + agent_card, +): + """Bob cannot fetch Alice's config by ID via tasks/pushNotificationConfig/get. + + Even when Bob can read the task and knows (or guesses) the + config_id, the handler must raise TaskNotFoundError because Alice's + config is not in Bob's owner partition. 
+ """ + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = create_sample_task(task_id='shared-task') + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + await push_store.set_info( + 'shared-task', + TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ), + alice_ctx, + ) + + request_handler = DefaultRequestHandler( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=agent_card, + ) + + # Alice can read her own config. + alice_view = await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + alice_ctx, + ) + assert alice_view.id == 'alice-cfg' + assert alice_view.token == 'alice-secret' + + # Bob cannot, even guessing the exact config_id. + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + _ctx('bob'), + ) diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py new file mode 100644 index 000000000..3f33516d3 --- /dev/null +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -0,0 +1,1529 @@ +import asyncio +import logging +import time +import uuid + +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest + +from a2a.auth.user import UnauthenticatedUser, User +from a2a.server.agent_execution import ( + RequestContextBuilder, + AgentExecutor, + RequestContext, + SimpleRequestContextBuilder, +) +from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry +from a2a.server.context import ServerCallContext +from a2a.server.events import EventQueue, InMemoryQueueManager, QueueManager +from 
a2a.server.request_handlers import DefaultRequestHandlerV2 +from a2a.server.tasks import ( + InMemoryPushNotificationConfigStore, + InMemoryTaskStore, + PushNotificationConfigStore, + PushNotificationSender, + TaskStore, + TaskUpdater, +) +from a2a.types import ( + InternalError, + InvalidAgentResponseError, + InvalidParamsError, + TaskNotFoundError, + PushNotificationNotSupportedError, +) +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + Artifact, + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTasksRequest, + ListTasksResponse, + Message, + Part, + Role, + SendMessageConfiguration, + SendMessageRequest, + SubscribeToTaskRequest, + Task, + TaskPushNotificationConfig, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.helpers.proto_helpers import ( + new_text_message, + new_task_from_user_message, +) + + +def create_default_agent_card(): + """Provides a standard AgentCard with streaming and push notifications enabled for tests.""" + return AgentCard( + name='test_agent', + version='1.0', + capabilities=AgentCapabilities(streaming=True, push_notifications=True), + ) + + +class MockAgentExecutor(AgentExecutor): + async def execute(self, context: RequestContext, event_queue: EventQueue): + if context.message: + await event_queue.enqueue_event( + new_task_from_user_message(context.message) + ) + + task_updater = TaskUpdater( + event_queue, + str(context.task_id or ''), + str(context.context_id or ''), + ) + + async for i in self._run(): + parts = [Part(text=f'Event {i}')] + try: + await task_updater.update_status( + TaskState.TASK_STATE_WORKING, + message=task_updater.new_agent_message(parts), + ) + except RuntimeError: + break + + async def _run(self): + for i in range(1000000): + yield i + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +def create_sample_task( + 
task_id='task1', + status_state=TaskState.TASK_STATE_SUBMITTED, + context_id='ctx1', +) -> Task: + return Task( + id=task_id, context_id=context_id, status=TaskStatus(state=status_state) + ) + + +def create_server_call_context() -> ServerCallContext: + return ServerCallContext(user=UnauthenticatedUser()) + + +def test_init_default_dependencies(): + """Test that default dependencies are created if not provided.""" + agent_executor = MockAgentExecutor() + task_store = InMemoryTaskStore() + handler = DefaultRequestHandlerV2( + agent_executor=agent_executor, + task_store=task_store, + agent_card=create_default_agent_card(), + ) + assert isinstance(handler._active_task_registry, ActiveTaskRegistry) + assert isinstance( + handler._request_context_builder, SimpleRequestContextBuilder + ) + assert handler._push_config_store is None + assert handler._push_sender is None + assert ( + handler._request_context_builder._should_populate_referred_tasks + is False + ) + assert handler._request_context_builder._task_store == task_store + + +@pytest.mark.asyncio +async def test_on_get_task_not_found(): + """Test on_get_task when task_store.get returns None.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = GetTaskRequest(id='non_existent_task') + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task(params, context) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + + +@pytest.mark.asyncio +async def test_on_list_tasks_success(): + """Test on_list_tasks successfully returns a page of tasks .""" + mock_task_store = AsyncMock(spec=TaskStore) + task2 = create_sample_task(task_id='task2') + task2.artifacts.extend( + [ + Artifact( + artifact_id='artifact1', + parts=[Part(text='Hello 
world!')], + name='conversion_result', + ) + ] + ) + mock_page = ListTasksResponse( + tasks=[create_sample_task(task_id='task1'), task2], + next_page_token='123', # noqa: S106 + ) + mock_task_store.list.return_value = mock_page + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = ListTasksRequest(include_artifacts=True, page_size=10) + context = create_server_call_context() + result = await request_handler.on_list_tasks(params, context) + mock_task_store.list.assert_awaited_once_with(params, context) + assert result.tasks == mock_page.tasks + assert result.next_page_token == mock_page.next_page_token + + +@pytest.mark.asyncio +async def test_on_list_tasks_excludes_artifacts(): + """Test on_list_tasks excludes artifacts from returned tasks.""" + mock_task_store = AsyncMock(spec=TaskStore) + task2 = create_sample_task(task_id='task2') + task2.artifacts.extend( + [ + Artifact( + artifact_id='artifact1', + parts=[Part(text='Hello world!')], + name='conversion_result', + ) + ] + ) + mock_page = ListTasksResponse( + tasks=[create_sample_task(task_id='task1'), task2], + next_page_token='123', # noqa: S106 + ) + mock_task_store.list.return_value = mock_page + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = ListTasksRequest(include_artifacts=False, page_size=10) + context = create_server_call_context() + result = await request_handler.on_list_tasks(params, context) + assert not result.tasks[1].artifacts + + +@pytest.mark.asyncio +async def test_on_list_tasks_applies_history_length(): + """Test on_list_tasks applies history length filter.""" + mock_task_store = AsyncMock(spec=TaskStore) + history = [ + new_text_message('Hello 1!'), + new_text_message('Hello 2!'), + ] + task2 = create_sample_task(task_id='task2') + 
task2.history.extend(history) + mock_page = ListTasksResponse( + tasks=[create_sample_task(task_id='task1'), task2], + next_page_token='123', # noqa: S106 + ) + mock_task_store.list.return_value = mock_page + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = ListTasksRequest(history_length=1, page_size=10) + context = create_server_call_context() + result = await request_handler.on_list_tasks(params, context) + assert result.tasks[1].history == [history[1]] + + +@pytest.mark.asyncio +async def test_on_list_tasks_negative_history_length_error(): + """Test on_list_tasks raises error for negative history length.""" + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = ListTasksRequest(history_length=-1, page_size=10) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_list_tasks(params, context) + assert 'history length must be non-negative' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_cancel_task_task_not_found(): + """Test on_cancel_task when the task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = CancelTaskRequest(id='task_not_found_for_cancel') + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_cancel_task(params, context) + mock_task_store.get.assert_awaited_once_with( + 'task_not_found_for_cancel', context + ) + + +class HelloAgentExecutor(AgentExecutor): + async def execute(self, context: 
RequestContext, event_queue: EventQueue): + task = context.current_task + if not task: + assert context.message is not None, ( + 'A message is required to create a new task' + ) + task = new_task_from_user_message(context.message) + await event_queue.enqueue_event(task) + updater = TaskUpdater(event_queue, task.id, task.context_id) + try: + parts = [Part(text='I am working')] + await updater.update_status( + TaskState.TASK_STATE_WORKING, + message=updater.new_agent_message(parts), + ) + except Exception as e: # noqa: BLE001 + logging.warning('Error: %s', e) + return + await updater.add_artifact( + [Part(text='Hello world!')], name='conversion_result' + ) + await updater.complete() + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +@pytest.mark.asyncio +async def test_on_get_task_limit_history(): + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, message_id='msg_push', parts=[Part(text='Hi')] + ), + configuration=SendMessageConfiguration( + accepted_output_modes=['text/plain'] + ), + ) + result = await request_handler.on_message_send( + params, create_server_call_context() + ) + assert result is not None + assert isinstance(result, Task) + get_task_result = await request_handler.on_get_task( + GetTaskRequest(id=result.id, history_length=1), + create_server_call_context(), + ) + assert get_task_result is not None + assert isinstance(get_task_result, Task) + assert ( + get_task_result.history is not None + and len(get_task_result.history) == 1 + ) + + +async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0): + """Await until predicate() is True or timeout elapses.""" + loop = asyncio.get_running_loop() + end = 
loop.time() + timeout + while True: + if predicate(): + return + if loop.time() >= end: + raise AssertionError('condition not met within timeout') + await asyncio.sleep(interval) + + +@pytest.mark.asyncio +async def test_set_task_push_notification_config_no_notifier(): + """Test on_create_task_push_notification_config when _push_config_store is None.""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=AsyncMock(spec=TaskStore), + push_config_store=None, + agent_card=create_default_agent_card(), + ) + params = TaskPushNotificationConfig( + task_id='task1', url='http://example.com' + ) + with pytest.raises(PushNotificationNotSupportedError): + await request_handler.on_create_task_push_notification_config( + params, create_server_call_context() + ) + + +@pytest.mark.asyncio +async def test_set_task_push_notification_config_task_not_found(): + """Test on_create_task_push_notification_config when task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + mock_push_store = AsyncMock(spec=PushNotificationConfigStore) + mock_push_sender = AsyncMock(spec=PushNotificationSender) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + push_sender=mock_push_sender, + agent_card=create_default_agent_card(), + ) + params = TaskPushNotificationConfig( + task_id='non_existent_task', url='http://example.com' + ) + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_create_task_push_notification_config( + params, context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + mock_push_store.set_info.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_no_store(): + """Test on_get_task_push_notification_config when _push_config_store is None.""" + request_handler = 
DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=AsyncMock(spec=TaskStore), + push_config_store=None, + agent_card=create_default_agent_card(), + ) + params = GetTaskPushNotificationConfigRequest( + task_id='task1', id='task_push_notification_config' + ) + with pytest.raises(PushNotificationNotSupportedError): + await request_handler.on_get_task_push_notification_config( + params, create_server_call_context() + ) + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_task_not_found(): + """Test on_get_task_push_notification_config when task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + mock_push_store = AsyncMock(spec=PushNotificationConfigStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + agent_card=create_default_agent_card(), + ) + params = GetTaskPushNotificationConfigRequest( + task_id='non_existent_task', id='task_push_notification_config' + ) + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task_push_notification_config( + params, context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + mock_push_store.get_info.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_info_not_found(): + """Test on_get_task_push_notification_config when push_config_store.get_info returns None.""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='non_existent_task') + mock_task_store.get.return_value = sample_task + mock_push_store = AsyncMock(spec=PushNotificationConfigStore) + mock_push_store.get_info.return_value = None + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + 
agent_card=create_default_agent_card(), + ) + params = GetTaskPushNotificationConfigRequest( + task_id='non_existent_task', id='task_push_notification_config' + ) + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task_push_notification_config( + params, context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + mock_push_store.get_info.assert_awaited_once_with( + 'non_existent_task', context + ) + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_info_with_config(): + """Test on_get_task_push_notification_config with valid push config id""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') + push_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + set_config_params = TaskPushNotificationConfig( + task_id='task_1', id='config_id', url='http://1.example.com' + ) + context = create_server_call_context() + await request_handler.on_create_task_push_notification_config( + set_config_params, context + ) + params = GetTaskPushNotificationConfigRequest( + task_id='task_1', id='config_id' + ) + result: TaskPushNotificationConfig = ( + await request_handler.on_get_task_push_notification_config( + params, context + ) + ) + assert result is not None + assert result.task_id == 'task_1' + assert result.url == set_config_params.url + assert result.id == 'config_id' + + +@pytest.mark.asyncio +async def test_get_task_push_notification_config_info_with_config_no_id(): + """Test on_get_task_push_notification_config with no push config id""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') + push_store = InMemoryPushNotificationConfigStore() + 
request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + set_config_params = TaskPushNotificationConfig( + task_id='task_1', url='http://1.example.com' + ) + await request_handler.on_create_task_push_notification_config( + set_config_params, create_server_call_context() + ) + params = GetTaskPushNotificationConfigRequest(task_id='task_1', id='task_1') + result: TaskPushNotificationConfig = ( + await request_handler.on_get_task_push_notification_config( + params, create_server_call_context() + ) + ) + assert result is not None + assert result.task_id == 'task_1' + assert result.url == set_config_params.url + assert result.id == 'task_1' + + +@pytest.mark.asyncio +async def test_on_subscribe_to_task_task_not_found(): + """Test on_subscribe_to_task when the task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = SubscribeToTaskRequest(id='resub_task_not_found') + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + async for _ in request_handler.on_subscribe_to_task(params, context): + pass + mock_task_store.get.assert_awaited_once_with( + 'resub_task_not_found', context + ) + + +@pytest.mark.asyncio +async def test_on_message_send_stream(): + request_handler = DefaultRequestHandlerV2( + MockAgentExecutor(), + InMemoryTaskStore(), + create_default_agent_card(), + ) + message_params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg-123', + parts=[Part(text='How are you?')], + ) + ) + + async def consume_stream(): + events = [] + async for event in request_handler.on_message_send_stream( + message_params, create_server_call_context() + ): + 
events.append(event) + if len(events) >= 3: + break + return events + + start = time.perf_counter() + events = await consume_stream() + elapsed = time.perf_counter() - start + assert len(events) == 3 + assert elapsed < 0.5 + task, event0, event1 = events + assert isinstance(task, Task) + assert task.history[0].parts[0].text == 'How are you?' + + assert isinstance(event0, TaskStatusUpdateEvent) + assert event0.status.message.parts[0].text == 'Event 0' + + assert isinstance(event1, TaskStatusUpdateEvent) + assert event1.status.message.parts[0].text == 'Event 1' + + +@pytest.mark.asyncio +async def test_list_task_push_notification_config_no_store(): + """Test on_list_task_push_notification_configs when _push_config_store is None.""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=AsyncMock(spec=TaskStore), + push_config_store=None, + agent_card=create_default_agent_card(), + ) + params = ListTaskPushNotificationConfigsRequest(task_id='task1') + with pytest.raises(PushNotificationNotSupportedError): + await request_handler.on_list_task_push_notification_configs( + params, create_server_call_context() + ) + + +@pytest.mark.asyncio +async def test_list_task_push_notification_config_task_not_found(): + """Test on_list_task_push_notification_configs when task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + mock_push_store = AsyncMock(spec=PushNotificationConfigStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + agent_card=create_default_agent_card(), + ) + params = ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_list_task_push_notification_configs( + params, context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', 
context) + mock_push_store.get_info.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_list_no_task_push_notification_config_info(): + """Test on_get_task_push_notification_config when push_config_store.get_info returns []""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='non_existent_task') + mock_task_store.get.return_value = sample_task + push_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + params = ListTaskPushNotificationConfigsRequest(task_id='non_existent_task') + result = await request_handler.on_list_task_push_notification_configs( + params, create_server_call_context() + ) + assert result.configs == [] + + +@pytest.mark.asyncio +async def test_list_task_push_notification_config_info_with_config(): + """Test on_list_task_push_notification_configs with push config+id""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='non_existent_task') + mock_task_store.get.return_value = sample_task + push_config1 = TaskPushNotificationConfig( + task_id='task_1', id='config_1', url='http://example.com' + ) + push_config2 = TaskPushNotificationConfig( + task_id='task_1', id='config_2', url='http://example.com' + ) + push_store = InMemoryPushNotificationConfigStore() + context = create_server_call_context() + await push_store.set_info('task_1', push_config1, context) + await push_store.set_info('task_1', push_config2, context) + await push_store.set_info('task_2', push_config1, context) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + params = ListTaskPushNotificationConfigsRequest(task_id='task_1') + result = await 
request_handler.on_list_task_push_notification_configs( + params, create_server_call_context() + ) + assert len(result.configs) == 2 + assert result.configs[0].task_id == 'task_1' + assert result.configs[0] == push_config1 + assert result.configs[1].task_id == 'task_1' + assert result.configs[1] == push_config2 + + +@pytest.mark.asyncio +async def test_list_task_push_notification_config_info_with_config_and_no_id(): + """Test on_list_task_push_notification_configs with no push config id""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task(id='task_1', context_id='ctx_1') + push_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + set_config_params1 = TaskPushNotificationConfig( + task_id='task_1', url='http://1.example.com' + ) + await request_handler.on_create_task_push_notification_config( + set_config_params1, create_server_call_context() + ) + set_config_params2 = TaskPushNotificationConfig( + task_id='task_1', url='http://2.example.com' + ) + await request_handler.on_create_task_push_notification_config( + set_config_params2, create_server_call_context() + ) + params = ListTaskPushNotificationConfigsRequest(task_id='task_1') + result = await request_handler.on_list_task_push_notification_configs( + params, create_server_call_context() + ) + assert len(result.configs) == 1 + assert result.configs[0].task_id == 'task_1' + assert result.configs[0].url == set_config_params2.url + assert result.configs[0].id == 'task_1' + + +@pytest.mark.asyncio +async def test_delete_task_push_notification_config_no_store(): + """Test on_delete_task_push_notification_config when _push_config_store is None.""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=AsyncMock(spec=TaskStore), + push_config_store=None, + 
agent_card=create_default_agent_card(), + ) + params = DeleteTaskPushNotificationConfigRequest( + task_id='task1', id='config1' + ) + with pytest.raises(PushNotificationNotSupportedError) as exc_info: + await request_handler.on_delete_task_push_notification_config( + params, create_server_call_context() + ) + assert isinstance(exc_info.value, PushNotificationNotSupportedError) + + +@pytest.mark.asyncio +async def test_delete_task_push_notification_config_task_not_found(): + """Test on_delete_task_push_notification_config when task is not found.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = None + mock_push_store = AsyncMock(spec=PushNotificationConfigStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=mock_push_store, + agent_card=create_default_agent_card(), + ) + params = DeleteTaskPushNotificationConfigRequest( + task_id='non_existent_task', id='config1' + ) + context = create_server_call_context() + with pytest.raises(TaskNotFoundError): + await request_handler.on_delete_task_push_notification_config( + params, context + ) + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) + mock_push_store.get_info.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_delete_no_task_push_notification_config_info(): + """Test on_delete_task_push_notification_config without config info""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='task_1') + mock_task_store.get.return_value = sample_task + push_store = InMemoryPushNotificationConfigStore() + await push_store.set_info( + 'task_2', + TaskPushNotificationConfig(id='config_1', url='http://example.com'), + create_server_call_context(), + ) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) 
+    params = DeleteTaskPushNotificationConfigRequest(
+        task_id='task_1', id='config_non_existent'
+    )
+    result = await request_handler.on_delete_task_push_notification_config(
+        params, create_server_call_context()
+    )
+    assert result is None
+    params = DeleteTaskPushNotificationConfigRequest(
+        task_id='task_2', id='config_non_existent'
+    )
+    result = await request_handler.on_delete_task_push_notification_config(
+        params, create_server_call_context()
+    )
+    assert result is None
+
+
+@pytest.mark.asyncio
+async def test_delete_task_push_notification_config_info_with_config():
+    """Test on_delete_task_push_notification_config with push config+id"""
+    mock_task_store = AsyncMock(spec=TaskStore)
+    sample_task = create_sample_task(task_id='non_existent_task')
+    mock_task_store.get.return_value = sample_task
+    push_config1 = TaskPushNotificationConfig(
+        task_id='task_1', id='config_1', url='http://example.com'
+    )
+    push_config2 = TaskPushNotificationConfig(
+        task_id='task_1', id='config_2', url='http://example.com'
+    )
+    push_store = InMemoryPushNotificationConfigStore()
+    context = create_server_call_context()
+    await push_store.set_info('task_1', push_config1, context)
+    await push_store.set_info('task_1', push_config2, context)
+    await push_store.set_info('task_2', push_config1, context)
+    request_handler = DefaultRequestHandlerV2(
+        agent_executor=MockAgentExecutor(),
+        task_store=mock_task_store,
+        push_config_store=push_store,
+        agent_card=create_default_agent_card(),
+    )
+    params = DeleteTaskPushNotificationConfigRequest(
+        task_id='task_1', id='config_1'
+    )
+    result1 = await request_handler.on_delete_task_push_notification_config(
+        params, create_server_call_context()
+    )
+    assert result1 is None
+    result2 = await request_handler.on_list_task_push_notification_configs(
+        ListTaskPushNotificationConfigsRequest(task_id='task_1'),
+        create_server_call_context(),
+    )
+    assert len(result2.configs) == 1
+    assert result2.configs[0].task_id == 'task_1'
+    assert 
result2.configs[0] == push_config2 + + +@pytest.mark.asyncio +async def test_delete_task_push_notification_config_info_with_config_and_no_id(): + """Test on_list_task_push_notification_configs with no push config id""" + mock_task_store = AsyncMock(spec=TaskStore) + sample_task = create_sample_task(task_id='non_existent_task') + mock_task_store.get.return_value = sample_task + push_config = TaskPushNotificationConfig(url='http://example.com') + push_store = InMemoryPushNotificationConfigStore() + context = create_server_call_context() + await push_store.set_info('task_1', push_config, context) + await push_store.set_info('task_1', push_config, context) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + params = DeleteTaskPushNotificationConfigRequest( + task_id='task_1', id='task_1' + ) + result = await request_handler.on_delete_task_push_notification_config( + params, create_server_call_context() + ) + assert result is None + result2 = await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='task_1'), + create_server_call_context(), + ) + assert len(result2.configs) == 0 + + +TERMINAL_TASK_STATES = { + TaskState.TASK_STATE_COMPLETED, + TaskState.TASK_STATE_CANCELED, + TaskState.TASK_STATE_FAILED, + TaskState.TASK_STATE_REJECTED, +} + + +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) +async def test_on_message_send_task_in_terminal_state(terminal_state): + """Test on_message_send when task is already in a terminal state.""" + state_name = TaskState.Name(terminal_state) + task_id = f'terminal_task_{state_name}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + 
task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_terminal', + parts=[Part(text='hello')], + task_id=task_id, + ) + ) + with ( + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=terminal_task, + ), + pytest.raises(InvalidParamsError) as exc_info, + ): + await request_handler.on_message_send( + params, create_server_call_context() + ) + assert ( + f'Task {task_id} is in terminal state: {terminal_state}' + in exc_info.value.message + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) +async def test_on_message_send_stream_task_in_terminal_state(terminal_state): + """Test on_message_send_stream when task is already in a terminal state.""" + state_name = TaskState.Name(terminal_state) + task_id = f'terminal_stream_task_{state_name}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_terminal_stream', + parts=[Part(text='hello')], + task_id=task_id, + ) + ) + with ( + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=terminal_task, + ), + pytest.raises(InvalidParamsError) as exc_info, + ): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + assert ( + f'Task {task_id} is in terminal state: {terminal_state}' + in exc_info.value.message + ) + + +@pytest.mark.asyncio +async def test_on_message_send_task_id_provided_but_task_not_found(): + """Test on_message_send when task_id is provided but task doesn't exist.""" + pass + + 
+@pytest.mark.asyncio +async def test_on_message_send_stream_task_id_provided_but_task_not_found(): + """Test on_message_send_stream when task_id is provided but task doesn't exist.""" + pass + + +class HelloWorldAgentExecutor(AgentExecutor): + """Test Agent Implementation.""" + + async def execute( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + if context.message: + await event_queue.enqueue_event( + new_task_from_user_message(context.message) + ) + updater = TaskUpdater( + event_queue, + task_id=context.task_id or str(uuid.uuid4()), + context_id=context.context_id or str(uuid.uuid4()), + ) + await updater.update_status(TaskState.TASK_STATE_WORKING) + await updater.complete() + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + raise NotImplementedError('cancel not supported') + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_error_does_not_hang(): + """Test that if the consumer raises an exception during blocking wait, the producer is cancelled and no deadlock occurs.""" + agent = HelloWorldAgentExecutor() + task_store = AsyncMock(spec=TaskStore) + task_store.get.return_value = None + task_store.save.side_effect = RuntimeError('This is an Error!') + + request_handler = DefaultRequestHandlerV2( + agent_executor=agent, + task_store=task_store, + agent_card=create_default_agent_card(), + ) + + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_error_blocking', + parts=[Part(text='Test message')], + ) + ) + with pytest.raises(RuntimeError, match='This is an Error!'): + await request_handler.on_message_send( + params, create_server_call_context() + ) + + +@pytest.mark.asyncio +async def test_on_get_task_negative_history_length_error(): + """Test on_get_task raises error for negative history length.""" + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + 
agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = GetTaskRequest(id='task1', history_length=-1) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_get_task(params, context) + assert 'history length must be non-negative' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_list_tasks_page_size_too_small(): + """Test on_list_tasks raises error for page_size < 1.""" + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = ListTasksRequest(page_size=0) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_list_tasks(params, context) + assert 'minimum page size is 1' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_list_tasks_page_size_too_large(): + """Test on_list_tasks raises error for page_size > 100.""" + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandlerV2( + agent_executor=AsyncMock(spec=AgentExecutor), + task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + params = ListTasksRequest(page_size=101) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_list_tasks(params, context) + assert 'maximum page size is 100' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_message_send_negative_history_length_error(): + """Test on_message_send raises error for negative history length in configuration.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + request_handler = DefaultRequestHandlerV2( + agent_executor=mock_agent_executor, + 
task_store=mock_task_store, + agent_card=create_default_agent_card(), + ) + message_config = SendMessageConfiguration( + history_length=-1, accepted_output_modes=['text/plain'] + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, message_id='msg1', parts=[Part(text='hello')] + ), + configuration=message_config, + ) + context = create_server_call_context() + with pytest.raises(InvalidParamsError) as exc_info: + await request_handler.on_message_send(params, context) + assert 'history length must be non-negative' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_message_send_limit_history(): + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() + + request_handler = DefaultRequestHandlerV2( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_push', + parts=[Part(text='Hi')], + ), + configuration=SendMessageConfiguration( + accepted_output_modes=['text/plain'], + history_length=1, + ), + ) + + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) + + # verify that history_length is honored + assert result is not None + assert isinstance(result, Task) + assert result.history is not None and len(result.history) == 1 + assert result.status.state == TaskState.TASK_STATE_COMPLETED + + # verify that history is still persisted to the store + task = await task_store.get(result.id, context) + assert task is not None + assert task.history is not None and len(task.history) > 1 + + +@pytest.mark.asyncio +async def test_on_message_send_stream_task_id_mismatch(): + mock_task_store = AsyncMock(spec=TaskStore) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + + context_task_id = 
'context_task_id_stream_1' + result_task_id = 'DIFFERENT_task_id_stream_1' + + mock_request_context = MagicMock() + mock_request_context.task_id = context_task_id + mock_request_context_builder.build.return_value = mock_request_context + + request_handler = DefaultRequestHandlerV2( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + request_context_builder=mock_request_context_builder, + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_id_mismatch_stream', + parts=[Part(text='hello')], + ) + ) + + mismatched_task = create_sample_task(task_id=result_task_id) + + async def mock_subscribe(request=None, include_initial_task=False): + yield mismatched_task + + mock_active_task = MagicMock() + mock_active_task.subscribe.side_effect = mock_subscribe + mock_active_task.start = AsyncMock() + mock_active_task.enqueue_request = AsyncMock() + + with ( + patch.object( + request_handler._active_task_registry, + 'get_or_create', + return_value=mock_active_task, + ), + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=None, + ), + ): + stream = request_handler.on_message_send_stream( + params, context=MagicMock() + ) + with pytest.raises(InternalError) as exc_info: + async for _ in stream: + pass + assert 'Task ID mismatch' in exc_info.value.message + + +@pytest.mark.asyncio +async def test_on_message_send_non_blocking(): + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() + + request_handler = DefaultRequestHandlerV2( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_push_non_blocking', + parts=[Part(text='Hi')], + ), + configuration=SendMessageConfiguration( + return_immediately=True, + ), + ) + + context = 
create_server_call_context() + result = await request_handler.on_message_send(params, context) + + # non-blocking should return the task immediately + assert result is not None + assert isinstance(result, Task) + assert result.status.state == TaskState.TASK_STATE_SUBMITTED + + +@pytest.mark.asyncio +async def test_on_message_send_with_push_notification(): + task_store = InMemoryTaskStore() + push_store = AsyncMock(spec=PushNotificationConfigStore) + + request_handler = DefaultRequestHandlerV2( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + push_config = TaskPushNotificationConfig(url='http://example.com/webhook') + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_push_1', + parts=[Part(text='Hi')], + ), + configuration=SendMessageConfiguration( + task_push_notification_config=push_config + ), + ) + + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) + + assert result is not None + assert isinstance(result, Task) + push_store.set_info.assert_awaited_once_with( + result.id, push_config, context + ) + + +class MultipleMessagesAgentExecutor(AgentExecutor): + """Misbehaving agent that yields more than one Message.""" + + async def execute(self, context: RequestContext, event_queue: EventQueue): + await event_queue.enqueue_event( + new_text_message('first', role=Role.ROLE_AGENT) + ) + await event_queue.enqueue_event( + new_text_message('second', role=Role.ROLE_AGENT) + ) + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +class MessageAfterTaskEventAgentExecutor(AgentExecutor): + """Misbehaving agent that yields a task-mode event then a Message.""" + + async def execute(self, context: RequestContext, event_queue: EventQueue): + task = new_task_from_user_message(context.message) + await event_queue.enqueue_event(task) + updater = 
TaskUpdater(event_queue, task.id, task.context_id) + await updater.update_status(TaskState.TASK_STATE_WORKING) + await event_queue.enqueue_event( + new_text_message('stray message', role=Role.ROLE_AGENT) + ) + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +class TaskEventAfterMessageAgentExecutor(AgentExecutor): + """Misbehaving agent that yields a Message and then a task-mode event.""" + + async def execute(self, context: RequestContext, event_queue: EventQueue): + await event_queue.enqueue_event( + new_text_message('only message', role=Role.ROLE_AGENT) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=str(context.task_id or ''), + context_id=str(context.context_id or ''), + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +class EventAfterTerminalStateAgentExecutor(AgentExecutor): + """Misbehaving agent that yields an event after reaching a terminal state.""" + + async def execute(self, context: RequestContext, event_queue: EventQueue): + task = new_task_from_user_message(context.message) + await event_queue.enqueue_event(task) + updater = TaskUpdater(event_queue, task.id, task.context_id) + await updater.complete() + await event_queue.enqueue_event( + new_text_message('after terminal', role=Role.ROLE_AGENT) + ) + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_stream_rejects_multiple_messages(): + """Stream surfaces InvalidAgentResponseError when the agent yields a + second Message after the first one (see comment in on_message_send_stream).""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MultipleMessagesAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + 
role=Role.ROLE_USER, + message_id='msg_multi_stream', + parts=[Part(text='Hi')], + ) + ) + with pytest.raises(InvalidAgentResponseError, match='Multiple Message'): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_stream_rejects_message_after_task_event(): + """Stream surfaces InvalidAgentResponseError when the agent yields a + Message after entering task mode (see comment in on_message_send_stream).""" + request_handler = DefaultRequestHandlerV2( + agent_executor=MessageAfterTaskEventAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_after_task_stream', + parts=[Part(text='Hi')], + ) + ) + with pytest.raises( + InvalidAgentResponseError, match='Message object in task mode' + ): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_stream_rejects_task_event_after_message(): + """Stream surfaces InvalidAgentResponseError when the agent yields a + task-mode event after a Message (see comment in on_message_send_stream).""" + request_handler = DefaultRequestHandlerV2( + agent_executor=TaskEventAfterMessageAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_then_task_stream', + parts=[Part(text='Hi')], + ) + ) + with pytest.raises(InvalidAgentResponseError, match='in message mode'): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + + +@pytest.mark.asyncio +@pytest.mark.timeout(1) +async def test_on_message_send_stream_rejects_event_after_terminal_state(): + """Stream 
surfaces InvalidAgentResponseError when the agent yields an event + after reaching a terminal state (see comment in on_message_send_stream).""" + request_handler = DefaultRequestHandlerV2( + agent_executor=EventAfterTerminalStateAgentExecutor(), + task_store=InMemoryTaskStore(), + agent_card=create_default_agent_card(), + ) + params = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg_after_terminal_stream', + parts=[Part(text='Hi')], + ) + ) + with pytest.raises( + InvalidAgentResponseError, match='Message object in task mode' + ): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + + +class _NamedUser(User): + """Minimal authenticated test user identified by ``user_name``.""" + + def __init__(self, user_name: str) -> None: + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def _ctx(user_name: str) -> ServerCallContext: + return ServerCallContext(user=_NamedUser(user_name)) + + +@pytest.mark.asyncio +async def test_on_list_task_push_notification_configs_is_owner_scoped(): + """v2 handler: Bob must not see Alice's configs via .../list.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task( + id='shared-task', context_id='ctx_1' + ) + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + bob_ctx = _ctx('bob') + + alice_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ) + bob_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='bob-cfg', + url='http://bob.example.com/cb', + token='bob-secret', + ) + await push_store.set_info('shared-task', alice_cfg, alice_ctx) + await push_store.set_info('shared-task', bob_cfg, bob_ctx) + + request_handler = DefaultRequestHandlerV2( + 
agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + + alice_listing = ( + await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + alice_ctx, + ) + ) + assert {c.id for c in alice_listing.configs} == {'alice-cfg'} + assert all(c.token != 'bob-secret' for c in alice_listing.configs) + + bob_listing = await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + bob_ctx, + ) + assert {c.id for c in bob_listing.configs} == {'bob-cfg'} + assert all(c.token != 'alice-secret' for c in bob_listing.configs) + + +@pytest.mark.asyncio +async def test_on_get_task_push_notification_config_is_owner_scoped(): + """v2 handler: Bob cannot fetch Alice's config by ID via .../get.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task( + id='shared-task', context_id='ctx_1' + ) + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + await push_store.set_info( + 'shared-task', + TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ), + alice_ctx, + ) + + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + + alice_view = await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + alice_ctx, + ) + assert alice_view.id == 'alice-cfg' + assert alice_view.token == 'alice-secret' + + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + _ctx('bob'), + ) diff --git 
a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 11ceaf7bb..d140d3d7b 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -53,9 +53,8 @@ def sample_agent_card() -> types.AgentCard: def grpc_handler( mock_request_handler: AsyncMock, sample_agent_card: types.AgentCard ) -> GrpcHandler: - return GrpcHandler( - agent_card=sample_agent_card, request_handler=mock_request_handler - ) + mock_request_handler._agent_card = sample_agent_card + return GrpcHandler(request_handler=mock_request_handler) # --- Test Cases --- @@ -182,13 +181,19 @@ async def test_get_extended_agent_card( grpc_handler: GrpcHandler, sample_agent_card: types.AgentCard, mock_grpc_context: AsyncMock, + mock_request_handler: AsyncMock, ) -> None: """Test GetExtendedAgentCard call.""" + + async def to_coro(*args, **kwargs): + return sample_agent_card + + mock_request_handler.on_get_extended_agent_card.side_effect = to_coro request_proto = a2a_pb2.GetExtendedAgentCardRequest() response = await grpc_handler.GetExtendedAgentCard( request_proto, mock_grpc_context ) - + mock_request_handler.on_get_extended_agent_card.assert_awaited_once() assert response.name == sample_agent_card.name assert response.version == sample_agent_card.version @@ -207,17 +212,20 @@ async def modifier(card: types.AgentCard) -> types.AgentCard: modified_card.name = 'Modified gRPC Agent' return modified_card - grpc_handler_modified = GrpcHandler( - agent_card=sample_agent_card, - request_handler=mock_request_handler, - card_modifier=modifier, - ) + # Use side_effect to ensure it returns an awaitable + async def side_effect_func(*_args, **_kwargs): + return await modifier(sample_agent_card) + mock_request_handler.on_get_extended_agent_card.side_effect = ( + side_effect_func + ) + mock_request_handler._agent_card = sample_agent_card + grpc_handler_modified = GrpcHandler(request_handler=mock_request_handler) 
request_proto = a2a_pb2.GetExtendedAgentCardRequest() response = await grpc_handler_modified.GetExtendedAgentCard( request_proto, mock_grpc_context ) - + mock_request_handler.on_get_extended_agent_card.assert_awaited_once() assert response.name == 'Modified gRPC Agent' assert response.version == sample_agent_card.version @@ -237,17 +245,17 @@ def modifier(card: types.AgentCard) -> types.AgentCard: modified_card.name = 'Modified gRPC Agent' return modified_card - grpc_handler_modified = GrpcHandler( - agent_card=sample_agent_card, - request_handler=mock_request_handler, - card_modifier=modifier, - ) + async def async_modifier(*args, **kwargs): + return modifier(sample_agent_card) + mock_request_handler.on_get_extended_agent_card.side_effect = async_modifier + mock_request_handler._agent_card = sample_agent_card + grpc_handler_modified = GrpcHandler(request_handler=mock_request_handler) request_proto = a2a_pb2.GetExtendedAgentCardRequest() response = await grpc_handler_modified.GetExtendedAgentCard( request_proto, mock_grpc_context ) - + mock_request_handler.on_get_extended_agent_card.assert_awaited_once() assert response.name == 'Modified gRPC Agent' assert response.version == sample_agent_card.version @@ -346,7 +354,7 @@ async def test_list_tasks_success( ), ], ) -async def test_abort_context_error_mapping( # noqa: PLR0913 +async def test_abort_context_error_mapping( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, @@ -413,19 +421,11 @@ async def test_send_message_with_extensions( (HTTP_EXTENSION_HEADER.lower(), 'foo'), (HTTP_EXTENSION_HEADER.lower(), 'bar'), ) - - def side_effect(request, context: ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') - return types.Task( - id='task-1', - context_id='ctx-1', - status=types.TaskStatus( - state=types.TaskState.TASK_STATE_COMPLETED - ), - ) - - mock_request_handler.on_message_send.side_effect = side_effect + 
mock_request_handler.on_message_send.return_value = types.Task( + id='task-1', + context_id='ctx-1', + status=types.TaskStatus(state=types.TaskState.TASK_STATE_COMPLETED), + ) await grpc_handler.SendMessage( a2a_pb2.SendMessageRequest(), mock_grpc_context @@ -436,15 +436,6 @@ def side_effect(request, context: ServerCallContext): assert isinstance(call_context, ServerCallContext) assert call_context.requested_extensions == {'foo', 'bar'} - mock_grpc_context.set_trailing_metadata.assert_called_once() - called_metadata = ( - mock_grpc_context.set_trailing_metadata.call_args.args[0] - ) - assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'baz'), - } - async def test_send_message_with_comma_separated_extensions( self, grpc_handler: GrpcHandler, @@ -482,8 +473,6 @@ async def test_send_streaming_message_with_extensions( ) async def side_effect(request, context: ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') yield types.Task( id='task-1', context_id='ctx-1', @@ -509,15 +498,6 @@ async def side_effect(request, context: ServerCallContext): assert isinstance(call_context, ServerCallContext) assert call_context.requested_extensions == {'foo', 'bar'} - mock_grpc_context.set_trailing_metadata.assert_called_once() - called_metadata = ( - mock_grpc_context.set_trailing_metadata.call_args.args[0] - ) - assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'baz'), - } - @pytest.mark.asyncio class TestTenantExtraction: diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py index f884bb38e..7ce73eb2e 100644 --- a/tests/server/routes/test_jsonrpc_dispatcher.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -61,7 +61,7 @@ def test_app(mock_handler): mock_agent_card.capabilities.streaming = False jsonrpc_routes = create_jsonrpc_routes( - 
agent_card=mock_agent_card, request_handler=mock_handler, rpc_url='/' + request_handler=mock_handler, rpc_url='/' ) from starlette.applications import Starlette @@ -101,7 +101,8 @@ def mock_app_params(self) -> dict: mock_handler = MagicMock(spec=RequestHandler) mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://example.com' - return {'agent_card': mock_agent_card, 'request_handler': mock_handler} + mock_handler._agent_card = mock_agent_card + return {'request_handler': mock_handler} @pytest.fixture(scope='class') def mark_pkg_starlette_not_installed(self): @@ -168,31 +169,6 @@ def test_method_added_to_call_context_state(self, client, mock_handler): call_context = mock_handler.on_message_send.call_args[0][1] assert call_context.state['method'] == 'SendMessage' - def test_response_with_activated_extensions(self, client, mock_handler): - def side_effect(request, context: ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') - return Message( - message_id='test', - role=Role.ROLE_AGENT, - parts=[Part(text='response message')], - ) - - mock_handler.on_message_send.side_effect = side_effect - - response = client.post( - '/', - json=_make_send_message_request(), - ) - response.raise_for_status() - - assert response.status_code == 200 - assert HTTP_EXTENSION_HEADER in response.headers - assert set(response.headers[HTTP_EXTENSION_HEADER].split(', ')) == { - 'foo', - 'baz', - } - class TestJsonRpcDispatcherTenant: def test_tenant_extraction_from_params(self, client, mock_handler): @@ -228,13 +204,12 @@ def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): mock_agent_card.capabilities = MagicMock() mock_agent_card.capabilities.streaming = False + mock_handler._agent_card = mock_agent_card + from starlette.applications import Starlette jsonrpc_routes = create_jsonrpc_routes( - agent_card=mock_agent_card, - request_handler=mock_handler, - enable_v0_3_compat=True, - rpc_url='/', + 
request_handler=mock_handler, enable_v0_3_compat=True, rpc_url='/' ) app = Starlette(routes=jsonrpc_routes) client = TestClient(app) @@ -328,9 +303,7 @@ def agent_card(self): @pytest.fixture def client(self, handler, agent_card): jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, rpc_url='/', ) from starlette.applications import Starlette @@ -480,11 +453,9 @@ async def capture_modifier(card, context): captured['method'] = context.state.get('method') return card + handler.on_get_extended_agent_card.return_value = agent_card jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, - extended_card_modifier=capture_modifier, rpc_url='/', ) from starlette.applications import Starlette @@ -500,7 +471,7 @@ async def capture_modifier(card, context): data = response.json() assert 'result' in data assert data['result']['name'] == 'TestAgent' - assert captured['method'] == 'GetExtendedAgentCard' + handler.on_get_extended_agent_card.assert_called_once() # --- Streaming method routing tests --- @@ -526,7 +497,6 @@ async def stream_generator(): ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, rpc_url='/', ) @@ -588,7 +558,6 @@ async def stream_generator(): ) jsonrpc_routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=handler, rpc_url='/', ) diff --git a/tests/server/routes/test_jsonrpc_routes.py b/tests/server/routes/test_jsonrpc_routes.py index 3330d14c8..ff1b81f3f 100644 --- a/tests/server/routes/test_jsonrpc_routes.py +++ b/tests/server/routes/test_jsonrpc_routes.py @@ -23,9 +23,7 @@ def mock_handler(): def test_routes_creation(agent_card, mock_handler): """Tests that create_jsonrpc_routes creates Route objects list.""" routes = create_jsonrpc_routes( - agent_card=agent_card, - request_handler=mock_handler, - rpc_url='/a2a/jsonrpc', + request_handler=mock_handler, 
rpc_url='/a2a/jsonrpc' ) assert isinstance(routes, list) @@ -41,7 +39,7 @@ def test_jsonrpc_custom_url(agent_card, mock_handler): """Tests that custom rpc_url is respected for routing.""" custom_url = '/custom/api/jsonrpc' routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=mock_handler, rpc_url=custom_url + request_handler=mock_handler, rpc_url=custom_url ) app = Starlette(routes=routes) diff --git a/tests/server/routes/test_rest_dispatcher.py b/tests/server/routes/test_rest_dispatcher.py index be5870cc4..a1d2c27cd 100644 --- a/tests/server/routes/test_rest_dispatcher.py +++ b/tests/server/routes/test_rest_dispatcher.py @@ -31,12 +31,25 @@ @pytest.fixture -def mock_handler(): +def agent_card(): + card = MagicMock(spec=AgentCard) + card.capabilities = AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ) + return card + + +@pytest.fixture +def mock_handler(agent_card): handler = AsyncMock(spec=RequestHandler) # Default success cases + handler._agent_card = agent_card handler.on_message_send.return_value = Message(message_id='test_msg') handler.on_cancel_task.return_value = Task(id='test_task') handler.on_get_task.return_value = Task(id='test_task') + handler.on_get_extended_agent_card.return_value = agent_card() handler.on_list_tasks.return_value = ListTasksResponse() handler.on_get_task_push_notification_config.return_value = ( TaskPushNotificationConfig(url='http://test') @@ -59,19 +72,8 @@ async def mock_stream(*args, **kwargs) -> AsyncIterator[Task]: @pytest.fixture -def agent_card(): - card = MagicMock(spec=AgentCard) - card.capabilities = AgentCapabilities( - streaming=True, - push_notifications=True, - extended_agent_card=True, - ) - return card - - -@pytest.fixture -def rest_dispatcher_instance(agent_card, mock_handler): - return RestDispatcher(agent_card=agent_card, request_handler=mock_handler) +def rest_dispatcher_instance(mock_handler): + return RestDispatcher(request_handler=mock_handler) 
from starlette.datastructures import Headers @@ -117,13 +119,13 @@ def mark_pkg_starlette_not_installed(self): ) def test_missing_starlette_raises_importerror( - self, mark_pkg_starlette_not_installed, agent_card, mock_handler + self, mark_pkg_starlette_not_installed, mock_handler ): with pytest.raises( ImportError, match='Packages `starlette` and `sse-starlette` are required', ): - RestDispatcher(agent_card=agent_card, request_handler=mock_handler) + RestDispatcher(request_handler=mock_handler) @pytest.mark.asyncio @@ -237,18 +239,6 @@ async def test_delete_push_notification( response = await rest_dispatcher_instance.delete_push_notification(req) assert response.status_code == 200 - async def test_set_push_notification_disabled_raises( - self, agent_card, mock_handler - ): - agent_card.capabilities.push_notifications = False - dispatcher = RestDispatcher( - agent_card=agent_card, request_handler=mock_handler - ) - req = make_mock_request(method='POST', path_params={'id': 'task1'}) - - response = await dispatcher.set_push_notification(req) - assert response.status_code == 400 # UnsupportedOperation maps to 400 - async def test_handle_authenticated_agent_card( self, rest_dispatcher_instance ): @@ -258,45 +248,9 @@ async def test_handle_authenticated_agent_card( ) assert response.status_code == 200 - async def test_handle_authenticated_agent_card_unsupported( - self, agent_card, mock_handler - ): - agent_card.capabilities.extended_agent_card = False - dispatcher = RestDispatcher( - agent_card=agent_card, request_handler=mock_handler - ) - req = make_mock_request() - - response = await dispatcher.handle_authenticated_agent_card(req) - assert response.status_code == 400 - @pytest.mark.asyncio class TestRestDispatcherStreaming: - async def test_on_message_send_stream_unsupported( - self, agent_card, mock_handler - ): - agent_card.capabilities.streaming = False - dispatcher = RestDispatcher( - agent_card=agent_card, request_handler=mock_handler - ) - req = 
make_mock_request(method='POST') - - response = await dispatcher.on_message_send_stream(req) - assert response.status_code == 400 - - async def test_on_subscribe_to_task_unsupported( - self, agent_card, mock_handler - ): - agent_card.capabilities.streaming = False - dispatcher = RestDispatcher( - agent_card=agent_card, request_handler=mock_handler - ) - req = make_mock_request(method='GET', path_params={'id': 't1'}) - - response = await dispatcher.on_subscribe_to_task(req) - assert response.status_code == 400 - async def test_on_message_send_stream_success( self, rest_dispatcher_instance ): @@ -310,9 +264,8 @@ async def test_on_message_send_stream_success( chunks.append(chunk) assert len(chunks) == 2 - # sse-starlette yields strings or bytes formatted as Server-Sent Events - assert 'chunk1' in str(chunks[0]) - assert 'chunk2' in str(chunks[1]) + assert 'chunk1' in chunks[0].data + assert 'chunk2' in chunks[1].data async def test_on_subscribe_to_task_success(self, rest_dispatcher_instance): req = make_mock_request(method='GET', path_params={'id': 'test_task'}) @@ -325,5 +278,18 @@ async def test_on_subscribe_to_task_success(self, rest_dispatcher_instance): chunks.append(chunk) assert len(chunks) == 2 - assert 'chunk1' in str(chunks[0]) - assert 'chunk2' in str(chunks[1]) + assert 'chunk1' in chunks[0].data + assert 'chunk2' in chunks[1].data + + async def test_on_message_send_stream_handler_error(self, mock_handler): + from a2a.utils.errors import UnsupportedOperationError + + mock_handler.on_message_send_stream.side_effect = ( + UnsupportedOperationError('Mocked error') + ) + + dispatcher = RestDispatcher(request_handler=mock_handler) + req = make_mock_request(method='POST') + + response = await dispatcher.on_message_send_stream(req) + assert response.status_code == 400 diff --git a/tests/server/routes/test_rest_routes.py b/tests/server/routes/test_rest_routes.py index 98bf4130d..2b3477c6b 100644 --- a/tests/server/routes/test_rest_routes.py +++ 
b/tests/server/routes/test_rest_routes.py @@ -22,26 +22,21 @@ def mock_handler(): def test_routes_creation(agent_card, mock_handler): """Tests that create_rest_routes creates Route objects list.""" - routes = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler - ) + routes = create_rest_routes(request_handler=mock_handler) assert isinstance(routes, list) assert len(routes) > 0 - assert all(isinstance(r, BaseRoute) for r in routes) + assert all((isinstance(r, BaseRoute) for r in routes)) def test_routes_creation_v03_compat(agent_card, mock_handler): """Tests that create_rest_routes creates more routes with enable_v0_3_compat.""" + mock_handler._agent_card = agent_card routes_without_compat = create_rest_routes( - agent_card=agent_card, - request_handler=mock_handler, - enable_v0_3_compat=False, + request_handler=mock_handler, enable_v0_3_compat=False ) routes_with_compat = create_rest_routes( - agent_card=agent_card, - request_handler=mock_handler, - enable_v0_3_compat=True, + request_handler=mock_handler, enable_v0_3_compat=True ) assert len(routes_with_compat) > len(routes_without_compat) @@ -51,9 +46,7 @@ def test_rest_endpoints_routing(agent_card, mock_handler): """Tests that mounted routes route to the handler endpoints.""" mock_handler.on_message_send.return_value = Task(id='123') - routes = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler - ) + routes = create_rest_routes(request_handler=mock_handler) app = Starlette(routes=routes) client = TestClient(app) @@ -70,9 +63,7 @@ def test_rest_endpoints_routing_tenant(agent_card, mock_handler): """Tests that mounted routes with {tenant} route to the handler endpoints.""" mock_handler.on_message_send.return_value = Task(id='123') - routes = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler - ) + routes = create_rest_routes(request_handler=mock_handler) app = Starlette(routes=routes) client = TestClient(app) @@ -94,9 +85,7 @@ def 
test_rest_list_tasks(agent_card, mock_handler): """Tests that list tasks endpoint is routed to the handler.""" mock_handler.on_list_tasks.return_value = ListTasksResponse() - routes = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler - ) + routes = create_rest_routes(request_handler=mock_handler) app = Starlette(routes=routes) client = TestClient(app) diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index b13a5cf55..6608d49bf 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -727,6 +727,57 @@ async def test_owner_resource_scoping( await config_store.delete_info('task1', context=context_user2) +@pytest.mark.asyncio +async def test_get_info_for_dispatch_returns_all_owners( + db_store_parameterized: DatabasePushNotificationConfigStore, +) -> None: + """get_info_for_dispatch MUST return configs across all owners. + + The dispatch path has no caller identity (the originating request + has completed by the time notifications fire). Authorization + happened at registration time. The DB query must therefore filter + on task_id only, with no owner predicate. + """ + config_store = db_store_parameterized + + alice_ctx = ServerCallContext(user=SampleUser(user_name='alice')) + bob_ctx = ServerCallContext(user=SampleUser(user_name='bob')) + + alice_cfg = TaskPushNotificationConfig( + id='alice-cfg', url='http://alice.example.com/cb' + ) + bob_cfg = TaskPushNotificationConfig( + id='bob-cfg', url='http://bob.example.com/cb' + ) + other_task_cfg = TaskPushNotificationConfig( + id='alice-other', url='http://alice.example.com/other' + ) + + await config_store.set_info('shared-task', alice_cfg, alice_ctx) + await config_store.set_info('shared-task', bob_cfg, bob_ctx) + # An unrelated config on a different task -- must NOT leak through. 
+ await config_store.set_info('other-task', other_task_cfg, alice_ctx) + + dispatched = await config_store.get_info_for_dispatch('shared-task') + + assert {c.id for c in dispatched} == {'alice-cfg', 'bob-cfg'} + assert {c.url for c in dispatched} == { + 'http://alice.example.com/cb', + 'http://bob.example.com/cb', + } + + # Sanity: user-callable get_info remains owner-scoped on the same data. + alice_view = await config_store.get_info('shared-task', alice_ctx) + assert {c.id for c in alice_view} == {'alice-cfg'} + bob_view = await config_store.get_info('shared-task', bob_ctx) + assert {c.id for c in bob_view} == {'bob-cfg'} + + # Cleanup + await config_store.delete_info('shared-task', context=alice_ctx) + await config_store.delete_info('shared-task', context=bob_ctx) + await config_store.delete_info('other-task', context=alice_ctx) + + @pytest.mark.asyncio async def test_get_0_3_push_notification_config_detailed( db_store_parameterized: DatabasePushNotificationConfigStore, diff --git a/tests/server/tasks/test_inmemory_push_notifications.py b/tests/server/tasks/test_inmemory_push_notifications.py index d8b560aae..d23bcee05 100644 --- a/tests/server/tasks/test_inmemory_push_notifications.py +++ b/tests/server/tasks/test_inmemory_push_notifications.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx + from google.protobuf.json_format import MessageToDict from a2a.auth.user import User @@ -14,9 +15,9 @@ InMemoryPushNotificationConfigStore, ) from a2a.types.a2a_pb2 import ( - TaskPushNotificationConfig, StreamResponse, Task, + TaskPushNotificationConfig, TaskState, TaskStatus, ) @@ -70,8 +71,7 @@ def setUp(self) -> None: self.notifier = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.config_store, - context=MINIMAL_CALL_CONTEXT, - ) # Corrected argument name + ) def test_constructor_stores_client(self) -> None: self.assertEqual(self.notifier._client, self.mock_httpx_client) @@ -428,5 +428,148 @@ async 
def test_owner_resource_scoping(self) -> None: await self.config_store.delete_info('task1', context=context_user2) +class TestPushNotificationDispatchAcrossOwners( + unittest.IsolatedAsyncioTestCase +): + """Dispatch-correctness tests for the registrar/dispatcher asymmetry. + + Push notifications must fire for any event on the task, regardless of + which user's action triggered the event. The dispatch path therefore + reads configs via get_info_for_dispatch (cross-owner), not + get_info (owner-scoped). + """ + + def setUp(self) -> None: + self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + self.mock_httpx_client.post.return_value = mock_response + + self.config_store = InMemoryPushNotificationConfigStore() + + self.sender = BasePushNotificationSender( + httpx_client=self.mock_httpx_client, + config_store=self.config_store, + ) + + async def test_multi_registrar_fan_out(self) -> None: + """Three users registering distinct webhooks for the same task all fire.""" + users_and_urls = [ + ('alice', 'http://alice.example.com/cb', 'tok-alice'), + ('bob', 'http://bob.example.com/cb', 'tok-bob'), + ('carol', 'http://carol.example.com/cb', 'tok-carol'), + ] + for user_name, url, token in users_and_urls: + ctx = ServerCallContext(user=SampleUser(user_name=user_name)) + cfg = TaskPushNotificationConfig( + id=f'cfg-{user_name}', url=url, token=token + ) + await self.config_store.set_info('shared-task', cfg, ctx) + + await self.sender.send_notification( + 'shared-task', _create_sample_task(task_id='shared-task') + ) + + self.assertEqual(self.mock_httpx_client.post.await_count, 3) + called_urls = { + call.args[0] for call in self.mock_httpx_client.post.call_args_list + } + self.assertEqual( + called_urls, + {url for _, url, _ in users_and_urls}, + ) + called_tokens = { + call.kwargs['headers']['X-A2A-Notification-Token'] + for call in self.mock_httpx_client.post.call_args_list + } + 
self.assertEqual( + called_tokens, + {token for _, _, token in users_and_urls}, + ) + + async def test_write_side_owner_isolation_preserved(self) -> None: + """Bob's ``delete_info`` against Alice's config is a no-op. + + After the no-op, Alice's config must still be: + (a) retrievable via the user-callable ``get_info`` for Alice, and + (b) returned by ``get_info_for_dispatch`` so that the + notification will still fire. + + Guards the write-side scoping that the design preserves + (see §9.3). + """ + alice_ctx = ServerCallContext(user=SampleUser(user_name='alice')) + bob_ctx = ServerCallContext(user=SampleUser(user_name='bob')) + + config = TaskPushNotificationConfig( + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-token', + ) + await self.config_store.set_info('shared-task', config, alice_ctx) + + # Bob attempts to delete Alice's config -- must be a no-op. + await self.config_store.delete_info( + 'shared-task', context=bob_ctx, config_id='alice-cfg' + ) + + # (a) Alice's user-callable view is unchanged. + alice_view = await self.config_store.get_info('shared-task', alice_ctx) + self.assertEqual(len(alice_view), 1) + self.assertEqual(alice_view[0].id, 'alice-cfg') + + # (b) Dispatch path still sees the config (notifications fire). + dispatched = await self.config_store.get_info_for_dispatch( + 'shared-task' + ) + self.assertEqual(len(dispatched), 1) + self.assertEqual(dispatched[0].id, 'alice-cfg') + self.assertEqual(dispatched[0].token, 'alice-token') + + # And end-to-end: the sender actually dispatches to Alice's URL. 
+ await self.sender.send_notification( + 'shared-task', _create_sample_task(task_id='shared-task') + ) + self.mock_httpx_client.post.assert_awaited_once_with( + 'http://alice.example.com/cb', + json=MessageToDict( + StreamResponse(task=_create_sample_task(task_id='shared-task')) + ), + headers={'X-A2A-Notification-Token': 'alice-token'}, + ) + + async def test_cross_user_dispatch_alice_registers_bob_triggers( + self, + ) -> None: + """Alice registers; Bob triggers; Alice's webhook receives the POST. + + The send_notification carries no identity, so there is no notion of + "who triggered this event" at the store layer. get_info_for_dispatch + returns Alice's config because Alice registered it. The fact that the + event was caused by Bob is not visible to (and not relevant for) the + dispatch path. + """ + alice_context = ServerCallContext(user=SampleUser(user_name='alice')) + config = _create_sample_push_config( + url='http://alice.example.com/cb', token='alice-token' + ) + await self.config_store.set_info('collab-task', config, alice_context) + + # No bob_context is passed anywhere -- the dispatch path never + # sees it. This is precisely the point: identity is not the + # dispatch path's concern. 
+ await self.sender.send_notification( + 'collab-task', _create_sample_task(task_id='collab-task') + ) + + self.mock_httpx_client.post.assert_awaited_once_with( + 'http://alice.example.com/cb', + json=MessageToDict( + StreamResponse(task=_create_sample_task(task_id='collab-task')) + ), + headers={'X-A2A-Notification-Token': 'alice-token'}, + ) + + if __name__ == '__main__': unittest.main() diff --git a/tests/server/tasks/test_push_notification_sender.py b/tests/server/tasks/test_push_notification_sender.py index 783e1f413..22f904a2a 100644 --- a/tests/server/tasks/test_push_notification_sender.py +++ b/tests/server/tasks/test_push_notification_sender.py @@ -6,40 +6,20 @@ from google.protobuf.json_format import MessageToDict -from a2a.auth.user import User -from a2a.server.context import ServerCallContext from a2a.server.tasks.base_push_notification_sender import ( BasePushNotificationSender, ) from a2a.types.a2a_pb2 import ( - TaskPushNotificationConfig, StreamResponse, Task, TaskArtifactUpdateEvent, + TaskPushNotificationConfig, TaskState, TaskStatus, TaskStatusUpdateEvent, ) -class SampleUser(User): - """A test implementation of the User interface.""" - - def __init__(self, user_name: str): - self._user_name = user_name - - @property - def is_authenticated(self) -> bool: - return True - - @property - def user_name(self) -> str: - return self._user_name - - -MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) - - def _create_sample_task( task_id: str = 'task123', status_state: TaskState = TaskState.TASK_STATE_COMPLETED, @@ -66,7 +46,6 @@ def setUp(self) -> None: self.sender = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.mock_config_store, - context=MINIMAL_CALL_CONTEXT, ) def test_constructor_stores_client_and_config_store(self) -> None: @@ -77,7 +56,7 @@ async def test_send_notification_success(self) -> None: task_id = 'task_send_success' task_data = _create_sample_task(task_id=task_id) config = 
_create_sample_push_config(url='http://notify.me/here') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -85,8 +64,8 @@ async def test_send_notification_success(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_data.id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_data.id ) # assert httpx_client post method got invoked with right parameters @@ -103,7 +82,7 @@ async def test_send_notification_with_token_success(self) -> None: config = _create_sample_push_config( url='http://notify.me/here', token='unique_token' ) - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -111,8 +90,8 @@ async def test_send_notification_with_token_success(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_data.id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_data.id ) # assert httpx_client post method got invoked with right parameters @@ -126,12 +105,12 @@ async def test_send_notification_with_token_success(self) -> None: async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' task_data = _create_sample_task(task_id=task_id) - self.mock_config_store.get_info.return_value = [] + self.mock_config_store.get_info_for_dispatch.return_value = [] await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + 
task_id ) self.mock_httpx_client.post.assert_not_called() @@ -142,7 +121,7 @@ async def test_send_notification_http_status_error( task_id = 'task_send_http_err' task_data = _create_sample_task(task_id=task_id) config = _create_sample_push_config(url='http://notify.me/http_error') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 404 @@ -154,8 +133,8 @@ async def test_send_notification_http_status_error( await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, @@ -173,7 +152,10 @@ async def test_send_notification_multiple_configs(self) -> None: config2 = _create_sample_push_config( url='http://notify.me/cfg2', config_id='cfg2' ) - self.mock_config_store.get_info.return_value = [config1, config2] + self.mock_config_store.get_info_for_dispatch.return_value = [ + config1, + config2, + ] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -181,8 +163,8 @@ async def test_send_notification_multiple_configs(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.assertEqual(self.mock_httpx_client.post.call_count, 2) @@ -207,7 +189,7 @@ async def test_send_notification_status_update_event(self) -> None: status=TaskStatus(state=TaskState.TASK_STATE_WORKING), ) config = _create_sample_push_config(url='http://notify.me/status') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] 
mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -215,8 +197,8 @@ async def test_send_notification_status_update_event(self) -> None: await self.sender.send_notification(task_id, event) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, @@ -231,7 +213,7 @@ async def test_send_notification_artifact_update_event(self) -> None: append=True, ) config = _create_sample_push_config(url='http://notify.me/artifact') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -239,8 +221,8 @@ async def test_send_notification_artifact_update_event(self) -> None: await self.sender.send_notification(task_id, event) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, diff --git a/tests/server/tasks/test_task_manager.py b/tests/server/tasks/test_task_manager.py index bdfbf525c..eba8d2f14 100644 --- a/tests/server/tasks/test_task_manager.py +++ b/tests/server/tasks/test_task_manager.py @@ -6,6 +6,7 @@ from a2a.auth.user import User from a2a.server.context import ServerCallContext from a2a.server.tasks import TaskManager +from a2a.server.tasks.task_manager import append_artifact_to_task from a2a.types.a2a_pb2 import ( Artifact, Message, @@ -345,3 +346,99 @@ async def test_save_task_event_no_task_existing( assert saved_task.status.state == TaskState.TASK_STATE_COMPLETED assert task_manager_without_id.task_id == 'event-task-id' assert task_manager_without_id.context_id == 'some-context' + + +def test_append_artifact_to_task(): 
+ # Prepare base task + task = create_minimal_task() + assert task.id == 'task-abc' + assert task.context_id == 'session-xyz' + assert task.status.state == TaskState.TASK_STATE_SUBMITTED + assert len(task.history) == 0 # proto repeated fields are empty, not None + assert len(task.artifacts) == 0 + + # Prepare appending artifact and event + artifact_1 = Artifact( + artifact_id='artifact-123', parts=[Part(text='Hello')] + ) + append_event_1 = TaskArtifactUpdateEvent( + artifact=artifact_1, append=False, task_id='123', context_id='123' + ) + + # Test adding a new artifact (not appending) + append_artifact_to_task(task, append_event_1) + assert len(task.artifacts) == 1 + assert task.artifacts[0].artifact_id == 'artifact-123' + assert task.artifacts[0].name == '' # proto default for string + assert len(task.artifacts[0].parts) == 1 + assert task.artifacts[0].parts[0].text == 'Hello' + + # Test replacing the artifact + artifact_2 = Artifact( + artifact_id='artifact-123', + name='updated name', + parts=[Part(text='Updated')], + metadata={'existing_key': 'existing_value'}, + ) + append_event_2 = TaskArtifactUpdateEvent( + artifact=artifact_2, append=False, task_id='123', context_id='123' + ) + append_artifact_to_task(task, append_event_2) + assert len(task.artifacts) == 1 # Should still have one artifact + assert task.artifacts[0].artifact_id == 'artifact-123' + assert task.artifacts[0].name == 'updated name' + assert len(task.artifacts[0].parts) == 1 + assert task.artifacts[0].parts[0].text == 'Updated' + assert task.artifacts[0].metadata['existing_key'] == 'existing_value' + + # Test appending parts to an existing artifact + artifact_with_parts = Artifact( + artifact_id='artifact-123', + parts=[Part(text='Part 2')], + metadata={'new_key': 'new_value'}, + ) + append_event_3 = TaskArtifactUpdateEvent( + artifact=artifact_with_parts, + append=True, + task_id='123', + context_id='123', + ) + append_artifact_to_task(task, append_event_3) + assert len(task.artifacts[0].parts) 
== 2 + assert task.artifacts[0].parts[0].text == 'Updated' + assert task.artifacts[0].parts[1].text == 'Part 2' + assert task.artifacts[0].metadata['existing_key'] == 'existing_value' + assert task.artifacts[0].metadata['new_key'] == 'new_value' + + # Test adding another new artifact + another_artifact_with_parts = Artifact( + artifact_id='new_artifact', + parts=[Part(text='new artifact Part 1')], + ) + append_event_4 = TaskArtifactUpdateEvent( + artifact=another_artifact_with_parts, + append=False, + task_id='123', + context_id='123', + ) + append_artifact_to_task(task, append_event_4) + assert len(task.artifacts) == 2 + assert task.artifacts[0].artifact_id == 'artifact-123' + assert task.artifacts[1].artifact_id == 'new_artifact' + assert len(task.artifacts[0].parts) == 2 + assert len(task.artifacts[1].parts) == 1 + + # Test appending part to a task that does not have a matching artifact + non_existing_artifact_with_parts = Artifact( + artifact_id='artifact-456', parts=[Part(text='Part 1')] + ) + append_event_5 = TaskArtifactUpdateEvent( + artifact=non_existing_artifact_with_parts, + append=True, + task_id='123', + context_id='123', + ) + append_artifact_to_task(task, append_event_5) + assert len(task.artifacts) == 2 + assert len(task.artifacts[0].parts) == 2 + assert len(task.artifacts[1].parts) == 1 diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index f879e8078..56663e7e9 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -165,9 +165,7 @@ def build( app_instance.routes.extend(card_routes) # JSON-RPC router - rpc_routes = create_jsonrpc_routes( - self.agent_card, self.handler, rpc_url=rpc_url - ) + rpc_routes = create_jsonrpc_routes(self.handler, rpc_url=rpc_url) app_instance.routes.extend(rpc_routes) return app_instance @@ -777,7 +775,7 @@ def test_dynamic_agent_card_modifier_sync( ): """Test that a synchronous card_modifier dynamically alters the public agent card.""" - def modifier(card: 
AgentCard) -> AgentCard: + async def modifier(card: AgentCard) -> AgentCard: modified_card = AgentCard() modified_card.CopyFrom(card) modified_card.name = 'Dynamically Modified Agent' @@ -820,7 +818,7 @@ def test_fastapi_dynamic_agent_card_modifier_sync( ): """Test that a synchronous card_modifier dynamically alters the public agent card for FastAPI.""" - def modifier(card: AgentCard) -> AgentCard: + async def modifier(card: AgentCard) -> AgentCard: modified_card = AgentCard() modified_card.CopyFrom(card) modified_card.name = 'Dynamically Modified Agent' diff --git a/tests/utils/test_artifact.py b/tests/utils/test_artifact.py deleted file mode 100644 index cbe8e9c91..000000000 --- a/tests/utils/test_artifact.py +++ /dev/null @@ -1,161 +0,0 @@ -import unittest -import uuid - -from unittest.mock import patch - -from google.protobuf.struct_pb2 import Struct - -from a2a.types.a2a_pb2 import ( - Artifact, - Part, -) -from a2a.utils.artifact import ( - get_artifact_text, - new_artifact, - new_data_artifact, - new_text_artifact, -) - - -class TestArtifact(unittest.TestCase): - @patch('uuid.uuid4') - def test_new_artifact_generates_id(self, mock_uuid4): - mock_uuid = uuid.UUID('abcdef12-1234-5678-1234-567812345678') - mock_uuid4.return_value = mock_uuid - artifact = new_artifact(parts=[], name='test_artifact') - self.assertEqual(artifact.artifact_id, str(mock_uuid)) - - def test_new_artifact_assigns_parts_name_description(self): - parts = [Part(text='Sample text')] - name = 'My Artifact' - description = 'This is a test artifact.' 
- artifact = new_artifact(parts=parts, name=name, description=description) - assert len(artifact.parts) == len(parts) - self.assertEqual(artifact.name, name) - self.assertEqual(artifact.description, description) - - def test_new_artifact_empty_description_if_not_provided(self): - parts = [Part(text='Another sample')] - name = 'Artifact_No_Desc' - artifact = new_artifact(parts=parts, name=name) - self.assertEqual(artifact.description, '') - - def test_new_text_artifact_creates_single_text_part(self): - text = 'This is a text artifact.' - name = 'Text_Artifact' - artifact = new_text_artifact(text=text, name=name) - self.assertEqual(len(artifact.parts), 1) - self.assertTrue(artifact.parts[0].HasField('text')) - - def test_new_text_artifact_part_contains_provided_text(self): - text = 'Hello, world!' - name = 'Greeting_Artifact' - artifact = new_text_artifact(text=text, name=name) - self.assertEqual(artifact.parts[0].text, text) - - def test_new_text_artifact_assigns_name_description(self): - text = 'Some content.' - name = 'Named_Text_Artifact' - description = 'Description for text artifact.' 
- artifact = new_text_artifact( - text=text, name=name, description=description - ) - self.assertEqual(artifact.name, name) - self.assertEqual(artifact.description, description) - - def test_new_data_artifact_creates_single_data_part(self): - sample_data = {'key': 'value', 'number': 123} - name = 'Data_Artifact' - artifact = new_data_artifact(data=sample_data, name=name) - self.assertEqual(len(artifact.parts), 1) - self.assertTrue(artifact.parts[0].HasField('data')) - - def test_new_data_artifact_part_contains_provided_data(self): - sample_data = {'content': 'test_data', 'is_valid': True} - name = 'Structured_Data_Artifact' - artifact = new_data_artifact(data=sample_data, name=name) - self.assertTrue(artifact.parts[0].HasField('data')) - # Compare via MessageToDict for proto Struct - from google.protobuf.json_format import MessageToDict - - self.assertEqual(MessageToDict(artifact.parts[0].data), sample_data) - - def test_new_data_artifact_assigns_name_description(self): - sample_data = {'info': 'some details'} - name = 'Named_Data_Artifact' - description = 'Description for data artifact.' 
- artifact = new_data_artifact( - data=sample_data, name=name, description=description - ) - self.assertEqual(artifact.name, name) - self.assertEqual(artifact.description, description) - - -class TestGetArtifactText(unittest.TestCase): - def test_get_artifact_text_single_part(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[Part(text='Hello world')], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact) - - # Verify - assert result == 'Hello world' - - def test_get_artifact_text_multiple_parts(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[ - Part(text='First line'), - Part(text='Second line'), - Part(text='Third line'), - ], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact) - - # Verify - default delimiter is newline - assert result == 'First line\nSecond line\nThird line' - - def test_get_artifact_text_custom_delimiter(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[ - Part(text='First part'), - Part(text='Second part'), - Part(text='Third part'), - ], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact, delimiter=' | ') - - # Verify - assert result == 'First part | Second part | Third part' - - def test_get_artifact_text_empty_parts(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact) - - # Verify - assert result == '' - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py deleted file mode 100644 index c157bb986..000000000 --- a/tests/utils/test_helpers.py +++ /dev/null @@ -1,490 +0,0 @@ -import uuid - -from typing import Any -from unittest.mock import patch - -import pytest - -from a2a.types import ( - AgentCapabilities, - AgentCard, - AgentCardSignature, - AgentInterface, - AgentSkill, - 
Artifact, - Message, - Part, - Role, - SendMessageRequest, - Task, - TaskArtifactUpdateEvent, - TaskState, - TaskStatus, -) -from a2a.utils.errors import UnsupportedOperationError -from a2a.utils.helpers import ( - _clean_empty, - append_artifact_to_task, - are_modalities_compatible, - build_text_artifact, - canonicalize_agent_card, - create_task_obj, - validate, -) - - -# --- Helper Functions --- -def create_test_message( - role: Role = Role.ROLE_USER, - text: str = 'Hello', - message_id: str = 'msg-123', -) -> Message: - return Message( - role=role, - parts=[Part(text=text)], - message_id=message_id, - ) - - -def create_test_task( - task_id: str = 'task-abc', - context_id: str = 'session-xyz', -) -> Task: - return Task( - id=task_id, - context_id=context_id, - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - - -SAMPLE_AGENT_CARD: dict[str, Any] = { - 'name': 'Test Agent', - 'description': 'A test agent', - 'supported_interfaces': [ - AgentInterface( - url='http://localhost', - protocol_binding='HTTP+JSON', - ) - ], - 'version': '1.0.0', - 'capabilities': AgentCapabilities( - streaming=None, - push_notifications=True, - ), - 'default_input_modes': ['text/plain'], - 'default_output_modes': ['text/plain'], - 'documentation_url': None, - 'icon_url': '', - 'skills': [ - AgentSkill( - id='skill1', - name='Test Skill', - description='A test skill', - tags=['test'], - ) - ], - 'signatures': [ - AgentCardSignature( - protected='protected_header', signature='test_signature' - ) - ], -} - - -# Test create_task_obj -def test_create_task_obj(): - message = create_test_message() - message.context_id = 'test-context' # Set context_id to test it's preserved - send_params = SendMessageRequest(message=message) - - task = create_task_obj(send_params) - assert task.id is not None - assert task.context_id == message.context_id - assert task.status.state == TaskState.TASK_STATE_SUBMITTED - assert len(task.history) == 1 - assert task.history[0] == message - - -def 
test_create_task_obj_generates_context_id(): - """Test that create_task_obj generates context_id if not present and uses it for the task.""" - # Message without context_id - message_no_context_id = Message( - role=Role.ROLE_USER, - parts=[Part(text='test')], - message_id='msg-no-ctx', - task_id='task-from-msg', # Provide a task_id to differentiate from generated task.id - ) - send_params = SendMessageRequest(message=message_no_context_id) - - # Ensure message.context_id is empty initially (proto default is empty string) - assert send_params.message.context_id == '' - - known_task_uuid = uuid.UUID('11111111-1111-1111-1111-111111111111') - known_context_uuid = uuid.UUID('22222222-2222-2222-2222-222222222222') - - # Patch uuid.uuid4 to return specific UUIDs in sequence - # The first call will be for message.context_id (if empty), the second for task.id. - with patch( - 'a2a.utils.helpers.uuid4', - side_effect=[known_context_uuid, known_task_uuid], - ) as mock_uuid4: - task = create_task_obj(send_params) - - # Assert that uuid4 was called twice (once for context_id, once for task.id) - assert mock_uuid4.call_count == 2 - - # Assert that message.context_id was set to the first generated UUID - assert send_params.message.context_id == str(known_context_uuid) - - # Assert that task.context_id is the same generated UUID - assert task.context_id == str(known_context_uuid) - - # Assert that task.id is the second generated UUID - assert task.id == str(known_task_uuid) - - # Ensure the original message in history also has the updated context_id - assert len(task.history) == 1 - assert task.history[0].context_id == str(known_context_uuid) - - -# Test append_artifact_to_task -def test_append_artifact_to_task(): - # Prepare base task - task = create_test_task() - assert task.id == 'task-abc' - assert task.context_id == 'session-xyz' - assert task.status.state == TaskState.TASK_STATE_SUBMITTED - assert len(task.history) == 0 # proto repeated fields are empty, not None - assert 
len(task.artifacts) == 0 - - # Prepare appending artifact and event - artifact_1 = Artifact( - artifact_id='artifact-123', parts=[Part(text='Hello')] - ) - append_event_1 = TaskArtifactUpdateEvent( - artifact=artifact_1, append=False, task_id='123', context_id='123' - ) - - # Test adding a new artifact (not appending) - append_artifact_to_task(task, append_event_1) - assert len(task.artifacts) == 1 - assert task.artifacts[0].artifact_id == 'artifact-123' - assert task.artifacts[0].name == '' # proto default for string - assert len(task.artifacts[0].parts) == 1 - assert task.artifacts[0].parts[0].text == 'Hello' - - # Test replacing the artifact - artifact_2 = Artifact( - artifact_id='artifact-123', - name='updated name', - parts=[Part(text='Updated')], - ) - append_event_2 = TaskArtifactUpdateEvent( - artifact=artifact_2, append=False, task_id='123', context_id='123' - ) - append_artifact_to_task(task, append_event_2) - assert len(task.artifacts) == 1 # Should still have one artifact - assert task.artifacts[0].artifact_id == 'artifact-123' - assert task.artifacts[0].name == 'updated name' - assert len(task.artifacts[0].parts) == 1 - assert task.artifacts[0].parts[0].text == 'Updated' - - # Test appending parts to an existing artifact - artifact_with_parts = Artifact( - artifact_id='artifact-123', parts=[Part(text='Part 2')] - ) - append_event_3 = TaskArtifactUpdateEvent( - artifact=artifact_with_parts, - append=True, - task_id='123', - context_id='123', - ) - append_artifact_to_task(task, append_event_3) - assert len(task.artifacts[0].parts) == 2 - assert task.artifacts[0].parts[0].text == 'Updated' - assert task.artifacts[0].parts[1].text == 'Part 2' - - # Test adding another new artifact - another_artifact_with_parts = Artifact( - artifact_id='new_artifact', - parts=[Part(text='new artifact Part 1')], - ) - append_event_4 = TaskArtifactUpdateEvent( - artifact=another_artifact_with_parts, - append=False, - task_id='123', - context_id='123', - ) - 
append_artifact_to_task(task, append_event_4) - assert len(task.artifacts) == 2 - assert task.artifacts[0].artifact_id == 'artifact-123' - assert task.artifacts[1].artifact_id == 'new_artifact' - assert len(task.artifacts[0].parts) == 2 - assert len(task.artifacts[1].parts) == 1 - - # Test appending part to a task that does not have a matching artifact - non_existing_artifact_with_parts = Artifact( - artifact_id='artifact-456', parts=[Part(text='Part 1')] - ) - append_event_5 = TaskArtifactUpdateEvent( - artifact=non_existing_artifact_with_parts, - append=True, - task_id='123', - context_id='123', - ) - append_artifact_to_task(task, append_event_5) - assert len(task.artifacts) == 2 - assert len(task.artifacts[0].parts) == 2 - assert len(task.artifacts[1].parts) == 1 - - -# Test build_text_artifact -def test_build_text_artifact(): - artifact_id = 'text_artifact' - text = 'This is a sample text' - artifact = build_text_artifact(text, artifact_id) - - assert artifact.artifact_id == artifact_id - assert len(artifact.parts) == 1 - assert artifact.parts[0].text == text - - -# Test validate decorator -def test_validate_decorator(): - class TestClass: - condition = True - - @validate(lambda self: self.condition, 'Condition not met') - def test_method(self) -> str: - return 'Success' - - obj = TestClass() - - # Test passing condition - assert obj.test_method() == 'Success' - - # Test failing condition - obj.condition = False - with pytest.raises(UnsupportedOperationError) as exc_info: - obj.test_method() - assert 'Condition not met' in str(exc_info.value) - - -# Tests for are_modalities_compatible -def test_are_modalities_compatible_client_none(): - assert ( - are_modalities_compatible( - client_output_modes=None, server_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_client_empty(): - assert ( - are_modalities_compatible( - client_output_modes=[], server_output_modes=['text/plain'] - ) - is True - ) - - -def 
test_are_modalities_compatible_server_none(): - assert ( - are_modalities_compatible( - server_output_modes=None, client_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_server_empty(): - assert ( - are_modalities_compatible( - server_output_modes=[], client_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_common_mode(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain', 'application/json'], - client_output_modes=['application/json', 'image/png'], - ) - is True - ) - - -def test_are_modalities_compatible_no_common_modes(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain'], - client_output_modes=['application/json'], - ) - is False - ) - - -def test_are_modalities_compatible_exact_match(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain'], - client_output_modes=['text/plain'], - ) - is True - ) - - -def test_are_modalities_compatible_server_more_but_common(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain', 'image/jpeg'], - client_output_modes=['text/plain'], - ) - is True - ) - - -def test_are_modalities_compatible_client_more_but_common(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain'], - client_output_modes=['text/plain', 'image/jpeg'], - ) - is True - ) - - -def test_are_modalities_compatible_both_none(): - assert ( - are_modalities_compatible( - server_output_modes=None, client_output_modes=None - ) - is True - ) - - -def test_are_modalities_compatible_both_empty(): - assert ( - are_modalities_compatible( - server_output_modes=[], client_output_modes=[] - ) - is True - ) - - -def test_canonicalize_agent_card(): - """Test canonicalize_agent_card with defaults, optionals, and exceptions. - - - extensions is omitted as it's not set and optional. - - protocolVersion is included because it's always added by canonicalize_agent_card. 
- - signatures should be omitted. - """ - agent_card = AgentCard(**SAMPLE_AGENT_CARD) - expected_jcs = ( - '{"capabilities":{"pushNotifications":true},' - '"defaultInputModes":["text/plain"],"defaultOutputModes":["text/plain"],' - '"description":"A test agent","name":"Test Agent",' - '"skills":[{"description":"A test skill","id":"skill1","name":"Test Skill","tags":["test"]}],' - '"supportedInterfaces":[{"protocolBinding":"HTTP+JSON","url":"http://localhost"}],' - '"version":"1.0.0"}' - ) - result = canonicalize_agent_card(agent_card) - assert result == expected_jcs - - -def test_canonicalize_agent_card_preserves_false_capability(): - """Regression #692: streaming=False must not be stripped from canonical JSON.""" - card = AgentCard( - **{ - **SAMPLE_AGENT_CARD, - 'capabilities': AgentCapabilities( - streaming=False, - push_notifications=True, - ), - } - ) - result = canonicalize_agent_card(card) - assert '"streaming":false' in result - - -@pytest.mark.parametrize( - 'input_val', - [ - pytest.param({'a': ''}, id='empty-string'), - pytest.param({'a': []}, id='empty-list'), - pytest.param({'a': {}}, id='empty-dict'), - pytest.param({'a': {'b': []}}, id='nested-empty'), - pytest.param({'a': '', 'b': [], 'c': {}}, id='all-empties'), - pytest.param({'a': {'b': {'c': ''}}}, id='deeply-nested'), - ], -) -def test_clean_empty_removes_empties(input_val): - """_clean_empty removes empty strings, lists, and dicts recursively.""" - assert _clean_empty(input_val) is None - - -def test_clean_empty_top_level_list_becomes_none(): - """Top-level list that becomes empty after cleaning should return None.""" - assert _clean_empty(['', {}, []]) is None - - -@pytest.mark.parametrize( - 'input_val,expected', - [ - pytest.param({'retries': 0}, {'retries': 0}, id='int-zero'), - pytest.param({'enabled': False}, {'enabled': False}, id='bool-false'), - pytest.param({'score': 0.0}, {'score': 0.0}, id='float-zero'), - pytest.param([0, 1, 2], [0, 1, 2], id='zero-in-list'), - pytest.param([False, 
True], [False, True], id='false-in-list'), - pytest.param( - {'config': {'max_retries': 0, 'name': 'agent'}}, - {'config': {'max_retries': 0, 'name': 'agent'}}, - id='nested-zero', - ), - ], -) -def test_clean_empty_preserves_falsy_values(input_val, expected): - """_clean_empty preserves legitimate falsy values (0, False, 0.0).""" - assert _clean_empty(input_val) == expected - - -@pytest.mark.parametrize( - 'input_val,expected', - [ - pytest.param( - {'count': 0, 'label': '', 'items': []}, - {'count': 0}, - id='falsy-with-empties', - ), - pytest.param( - {'a': 0, 'b': 'hello', 'c': False, 'd': ''}, - {'a': 0, 'b': 'hello', 'c': False}, - id='mixed-types', - ), - pytest.param( - {'name': 'agent', 'retries': 0, 'tags': [], 'desc': ''}, - {'name': 'agent', 'retries': 0}, - id='realistic-mixed', - ), - ], -) -def test_clean_empty_mixed(input_val, expected): - """_clean_empty handles mixed empty and falsy values correctly.""" - assert _clean_empty(input_val) == expected - - -def test_clean_empty_does_not_mutate_input(): - """_clean_empty should not mutate the original input object.""" - original = {'a': '', 'b': 1, 'c': {'d': ''}} - original_copy = { - 'a': '', - 'b': 1, - 'c': {'d': ''}, - } - - _clean_empty(original) - - assert original == original_copy diff --git a/tests/utils/test_message.py b/tests/utils/test_message.py deleted file mode 100644 index c90d422aa..000000000 --- a/tests/utils/test_message.py +++ /dev/null @@ -1,209 +0,0 @@ -import uuid - -from unittest.mock import patch - -from google.protobuf.struct_pb2 import Struct, Value - -from a2a.types.a2a_pb2 import ( - Message, - Part, - Role, -) -from a2a.utils.message import ( - get_message_text, - new_agent_parts_message, - new_agent_text_message, -) - - -class TestNewAgentTextMessage: - def test_new_agent_text_message_basic(self): - # Setup - text = "Hello, I'm an agent" - - # Exercise - with a fixed uuid for testing - with patch( - 'uuid.uuid4', - 
return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text) - - # Verify - assert message.role == Role.ROLE_AGENT - assert len(message.parts) == 1 - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.task_id == '' - assert message.context_id == '' - - def test_new_agent_text_message_with_context_id(self): - # Setup - text = 'Message with context' - context_id = 'test-context-id' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text, context_id=context_id) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.context_id == context_id - assert message.task_id == '' - - def test_new_agent_text_message_with_task_id(self): - # Setup - text = 'Message with task id' - task_id = 'test-task-id' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text, task_id=task_id) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.task_id == task_id - assert message.context_id == '' - - def test_new_agent_text_message_with_both_ids(self): - # Setup - text = 'Message with both ids' - context_id = 'test-context-id' - task_id = 'test-task-id' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message( - text, context_id=context_id, task_id=task_id - ) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert 
message.context_id == context_id - assert message.task_id == task_id - - def test_new_agent_text_message_empty_text(self): - # Setup - text = '' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == '' - assert message.message_id == '12345678-1234-5678-1234-567812345678' - - -class TestNewAgentPartsMessage: - def test_new_agent_parts_message(self): - """Test creating an agent message with multiple, mixed parts.""" - # Setup - data = Struct() - data.update({'product_id': 123, 'quantity': 2}) - parts = [ - Part(text='Here is some text.'), - Part(data=Value(struct_value=data)), - ] - context_id = 'ctx-multi-part' - task_id = 'task-multi-part' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('abcdefab-cdef-abcd-efab-cdefabcdefab'), - ): - message = new_agent_parts_message( - parts, context_id=context_id, task_id=task_id - ) - - # Verify - assert message.role == Role.ROLE_AGENT - assert len(message.parts) == len(parts) - assert message.context_id == context_id - assert message.task_id == task_id - assert message.message_id == 'abcdefab-cdef-abcd-efab-cdefabcdefab' - - -class TestGetMessageText: - def test_get_message_text_single_part(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[Part(text='Hello world')], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message) - - # Verify - assert result == 'Hello world' - - def test_get_message_text_multiple_parts(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[ - Part(text='First line'), - Part(text='Second line'), - Part(text='Third line'), - ], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message) - - # Verify - default delimiter is newline - assert result == 'First line\nSecond line\nThird line' - - def 
test_get_message_text_custom_delimiter(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[ - Part(text='First part'), - Part(text='Second part'), - Part(text='Third part'), - ], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message, delimiter=' | ') - - # Verify - assert result == 'First part | Second part | Third part' - - def test_get_message_text_empty_parts(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message) - - # Verify - assert result == '' diff --git a/tests/utils/test_parts.py b/tests/utils/test_parts.py deleted file mode 100644 index a7a24e225..000000000 --- a/tests/utils/test_parts.py +++ /dev/null @@ -1,184 +0,0 @@ -from google.protobuf.struct_pb2 import Struct, Value -from a2a.types.a2a_pb2 import ( - Part, -) -from a2a.utils.parts import ( - get_data_parts, - get_file_parts, - get_text_parts, -) - - -class TestGetTextParts: - def test_get_text_parts_single_text_part(self): - # Setup - parts = [Part(text='Hello world')] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == ['Hello world'] - - def test_get_text_parts_multiple_text_parts(self): - # Setup - parts = [ - Part(text='First part'), - Part(text='Second part'), - Part(text='Third part'), - ] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == ['First part', 'Second part', 'Third part'] - - def test_get_text_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == [] - - -class TestGetDataParts: - def test_get_data_parts_single_data_part(self): - # Setup - data = Struct() - data.update({'key': 'value'}) - parts = [Part(data=Value(struct_value=data))] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key': 'value'}] - - def test_get_data_parts_multiple_data_parts(self): - # Setup - 
data1 = Struct() - data1.update({'key1': 'value1'}) - data2 = Struct() - data2.update({'key2': 'value2'}) - parts = [ - Part(data=Value(struct_value=data1)), - Part(data=Value(struct_value=data2)), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key1': 'value1'}, {'key2': 'value2'}] - - def test_get_data_parts_mixed_parts(self): - # Setup - data1 = Struct() - data1.update({'key1': 'value1'}) - data2 = Struct() - data2.update({'key2': 'value2'}) - parts = [ - Part(text='some text'), - Part(data=Value(struct_value=data1)), - Part(data=Value(struct_value=data2)), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key1': 'value1'}, {'key2': 'value2'}] - - def test_get_data_parts_no_data_parts(self): - # Setup - parts = [ - Part(text='some text'), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [] - - def test_get_data_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [] - - -class TestGetFileParts: - def test_get_file_parts_single_file_part(self): - # Setup - parts = [Part(url='file://path/to/file', media_type='text/plain')] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert len(result) == 1 - assert result[0].url == 'file://path/to/file' - assert result[0].media_type == 'text/plain' - - def test_get_file_parts_multiple_file_parts(self): - # Setup - parts = [ - Part(url='file://path/to/file1', media_type='text/plain'), - Part(raw=b'file content', media_type='application/octet-stream'), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert len(result) == 2 - assert result[0].url == 'file://path/to/file1' - assert result[1].raw == b'file content' - - def test_get_file_parts_mixed_parts(self): - # Setup - parts = [ - Part(text='some text'), - Part(url='file://path/to/file', media_type='text/plain'), - ] - - # Exercise - result = 
get_file_parts(parts) - - # Verify - assert len(result) == 1 - assert result[0].url == 'file://path/to/file' - - def test_get_file_parts_no_file_parts(self): - # Setup - data = Struct() - data.update({'key': 'value'}) - parts = [ - Part(text='some text'), - Part(data=Value(struct_value=data)), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [] - - def test_get_file_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [] diff --git a/tests/utils/test_signing.py b/tests/utils/test_signing.py index 162f28e28..2a09943fe 100644 --- a/tests/utils/test_signing.py +++ b/tests/utils/test_signing.py @@ -178,3 +178,111 @@ def test_signer_and_verifier_asymmetric(sample_agent_card: AgentCard): ) with pytest.raises(signing.InvalidSignaturesError): verifier_wrong_key(signed_card) + + +def test_canonicalize_agent_card(sample_agent_card: AgentCard): + """Test canonicalize_agent_card with defaults, optionals, and exceptions. + + - extensions is omitted as it's not set and optional. + - protocolVersion is included because it's always added by canonicalize_agent_card. + - signatures should be omitted. 
+ """ + expected_jcs = ( + '{"capabilities":{"pushNotifications":true},' + '"defaultInputModes":["text/plain"],"defaultOutputModes":["text/plain"],' + '"description":"A test agent","name":"Test Agent",' + '"skills":[{"description":"A test skill","id":"skill1","name":"Test Skill","tags":["test"]}],' + '"supportedInterfaces":[{"protocolBinding":"HTTP+JSON","url":"http://localhost"}],' + '"version":"1.0.0"}' + ) + result = signing._canonicalize_agent_card(sample_agent_card) + assert result == expected_jcs + + +def test_canonicalize_agent_card_preserves_false_capability( + sample_agent_card: AgentCard, +): + """Regression #692: streaming=False must not be stripped from canonical JSON.""" + sample_agent_card.capabilities.streaming = False + result = signing._canonicalize_agent_card(sample_agent_card) + assert '"streaming":false' in result + + +@pytest.mark.parametrize( + 'input_val', + [ + pytest.param({'a': ''}, id='empty-string'), + pytest.param({'a': []}, id='empty-list'), + pytest.param({'a': {}}, id='empty-dict'), + pytest.param({'a': {'b': []}}, id='nested-empty'), + pytest.param({'a': '', 'b': [], 'c': {}}, id='all-empties'), + pytest.param({'a': {'b': {'c': ''}}}, id='deeply-nested'), + ], +) +def test_clean_empty_removes_empties(input_val): + """_clean_empty removes empty strings, lists, and dicts recursively.""" + assert signing._clean_empty(input_val) is None + + +def test_clean_empty_top_level_list_becomes_none(): + """Top-level list that becomes empty after cleaning should return None.""" + assert signing._clean_empty(['', {}, []]) is None + + +@pytest.mark.parametrize( + 'input_val,expected', + [ + pytest.param({'retries': 0}, {'retries': 0}, id='int-zero'), + pytest.param({'enabled': False}, {'enabled': False}, id='bool-false'), + pytest.param({'score': 0.0}, {'score': 0.0}, id='float-zero'), + pytest.param([0, 1, 2], [0, 1, 2], id='zero-in-list'), + pytest.param([False, True], [False, True], id='false-in-list'), + pytest.param( + {'config': 
{'max_retries': 0, 'name': 'agent'}}, + {'config': {'max_retries': 0, 'name': 'agent'}}, + id='nested-zero', + ), + ], +) +def test_clean_empty_preserves_falsy_values(input_val, expected): + """_clean_empty preserves legitimate falsy values (0, False, 0.0).""" + assert signing._clean_empty(input_val) == expected + + +@pytest.mark.parametrize( + 'input_val,expected', + [ + pytest.param( + {'count': 0, 'label': '', 'items': []}, + {'count': 0}, + id='falsy-with-empties', + ), + pytest.param( + {'a': 0, 'b': 'hello', 'c': False, 'd': ''}, + {'a': 0, 'b': 'hello', 'c': False}, + id='mixed-types', + ), + pytest.param( + {'name': 'agent', 'retries': 0, 'tags': [], 'desc': ''}, + {'name': 'agent', 'retries': 0}, + id='realistic-mixed', + ), + ], +) +def test_clean_empty_mixed(input_val, expected): + """_clean_empty handles mixed empty and falsy values correctly.""" + assert signing._clean_empty(input_val) == expected + + +def test_clean_empty_does_not_mutate_input(): + """_clean_empty should not mutate the original input object.""" + original = {'a': '', 'b': 1, 'c': {'d': ''}} + original_copy = { + 'a': '', + 'b': 1, + 'c': {'d': ''}, + } + + signing._clean_empty(original) + + assert original == original_copy diff --git a/tests/utils/test_task.py b/tests/utils/test_task.py index 3e1f3c058..55dc8ed4f 100644 --- a/tests/utils/test_task.py +++ b/tests/utils/test_task.py @@ -14,197 +14,16 @@ GetTaskRequest, SendMessageConfiguration, ) +from a2a.helpers.proto_helpers import new_task from a2a.utils.task import ( apply_history_length, - completed_task, decode_page_token, encode_page_token, - new_task, ) from a2a.utils.errors import InvalidParamsError class TestTask(unittest.TestCase): - def test_new_task_status(self): - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - task = new_task(message) - self.assertEqual(task.status.state, TaskState.TASK_STATE_SUBMITTED) - - @patch('uuid.uuid4') - def 
test_new_task_generates_ids(self, mock_uuid4): - mock_uuid = uuid.UUID('12345678-1234-5678-1234-567812345678') - mock_uuid4.return_value = mock_uuid - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - task = new_task(message) - self.assertEqual(task.id, str(mock_uuid)) - self.assertEqual(task.context_id, str(mock_uuid)) - - def test_new_task_uses_provided_ids(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) - task = new_task(message) - self.assertEqual(task.id, task_id) - self.assertEqual(task.context_id, context_id) - - def test_new_task_initial_message_in_history(self): - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - task = new_task(message) - self.assertEqual(len(task.history), 1) - self.assertEqual(task.history[0], message) - - def test_completed_task_status(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - task = completed_task( - task_id=task_id, - context_id=context_id, - artifacts=artifacts, - history=[], - ) - self.assertEqual(task.status.state, TaskState.TASK_STATE_COMPLETED) - - def test_completed_task_assigns_ids_and_artifacts(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - task = completed_task( - task_id=task_id, - context_id=context_id, - artifacts=artifacts, - history=[], - ) - self.assertEqual(task.id, task_id) - self.assertEqual(task.context_id, context_id) - self.assertEqual(len(task.artifacts), len(artifacts)) - - def test_completed_task_empty_history_if_not_provided(self): - task_id = 
str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - task = completed_task( - task_id=task_id, context_id=context_id, artifacts=artifacts - ) - self.assertEqual(len(task.history), 0) - - def test_completed_task_uses_provided_history(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - history = [ - Message( - role=Role.ROLE_USER, - parts=[Part(text='Hello')], - message_id=str(uuid.uuid4()), - ), - Message( - role=Role.ROLE_AGENT, - parts=[Part(text='Hi there')], - message_id=str(uuid.uuid4()), - ), - ] - task = completed_task( - task_id=task_id, - context_id=context_id, - artifacts=artifacts, - history=history, - ) - self.assertEqual(len(task.history), len(history)) - - def test_new_task_invalid_message_empty_parts(self): - with self.assertRaises(ValueError): - new_task( - Message( - role=Role.ROLE_USER, - parts=[], - message_id=str(uuid.uuid4()), - ) - ) - - def test_new_task_invalid_message_empty_content(self): - with self.assertRaises(ValueError): - new_task( - Message( - role=Role.ROLE_USER, - parts=[Part(text='')], - message_id=str(uuid.uuid4()), - ) - ) - - def test_new_task_invalid_message_none_role(self): - # Proto messages always have a default role (ROLE_UNSPECIFIED = 0) - # Testing with unspecified role - msg = Message( - role=Role.ROLE_UNSPECIFIED, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - with self.assertRaises((TypeError, ValueError)): - new_task(msg) - - def test_completed_task_empty_artifacts(self): - with pytest.raises( - ValueError, - match='artifacts must be a non-empty list of Artifact objects', - ): - completed_task( - task_id='task-123', - context_id='ctx-456', - artifacts=[], - history=[], - ) - - def test_completed_task_invalid_artifact_type(self): - with pytest.raises( - ValueError, - 
match='artifacts must be a non-empty list of Artifact objects', - ): - completed_task( - task_id='task-123', - context_id='ctx-456', - artifacts=['not an artifact'], # type: ignore[arg-type] - history=[], - ) - page_token = 'd47a95ba-0f39-4459-965b-3923cdd2ff58' encoded_page_token = 'ZDQ3YTk1YmEtMGYzOS00NDU5LTk2NWItMzkyM2NkZDJmZjU4' # base64 for 'd47a95ba-0f39-4459-965b-3923cdd2ff58' @@ -234,9 +53,10 @@ def setUp(self): for i in range(5) ] artifacts = [Artifact(artifact_id='a1', parts=[Part(text='a')])] - self.task = completed_task( + self.task = new_task( task_id='t1', context_id='c1', + state=TaskState.TASK_STATE_COMPLETED, artifacts=artifacts, history=self.history, ) diff --git a/tests/utils/test_helpers_validation.py b/tests/utils/test_version_validation.py similarity index 98% rename from tests/utils/test_helpers_validation.py rename to tests/utils/test_version_validation.py index 571f8ae9b..b2ae0594e 100644 --- a/tests/utils/test_helpers_validation.py +++ b/tests/utils/test_version_validation.py @@ -6,7 +6,7 @@ from a2a.server.context import ServerCallContext from a2a.utils import constants from a2a.utils.errors import VersionNotSupportedError -from a2a.utils.helpers import validate_version +from a2a.utils.version_validator import validate_version class TestHandler: diff --git a/uv.lock b/uv.lock index dc87a7b6d..3bf3ba333 100644 --- a/uv.lock +++ b/uv.lock @@ -27,7 +27,6 @@ dependencies = [ all = [ { name = "alembic" }, { name = "cryptography" }, - { name = "google-cloud-aiplatform" }, { name = "grpcio" }, { name = "grpcio-reflection" }, { name = "grpcio-status" }, @@ -74,9 +73,6 @@ telemetry = [ { name = "opentelemetry-api" }, { name = "opentelemetry-sdk" }, ] -vertex = [ - { name = "google-cloud-aiplatform" }, -] [package.dev-dependencies] dev = [ @@ -109,8 +105,6 @@ requires-dist = [ { name = "cryptography", marker = "extra == 'encryption'", specifier = ">=43.0.0" }, { name = "culsans", marker = "python_full_version < '3.13'", specifier = ">=0.11.0" }, { 
name = "google-api-core", specifier = ">=1.26.0" }, - { name = "google-cloud-aiplatform", marker = "extra == 'all'", specifier = ">=1.140.0" }, - { name = "google-cloud-aiplatform", marker = "extra == 'vertex'", specifier = ">=1.140.0" }, { name = "googleapis-common-protos", specifier = ">=1.70.0" }, { name = "grpcio", marker = "extra == 'all'", specifier = ">=1.60" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.60" }, @@ -128,7 +122,7 @@ requires-dist = [ { name = "opentelemetry-sdk", marker = "extra == 'all'", specifier = ">=1.33.0" }, { name = "opentelemetry-sdk", marker = "extra == 'telemetry'", specifier = ">=1.33.0" }, { name = "packaging", specifier = ">=24.0" }, - { name = "protobuf", specifier = ">=5.29.5" }, + { name = "protobuf", specifier = ">=5.29.5,<7" }, { name = "pydantic", specifier = ">=2.11.3" }, { name = "pyjwt", marker = "extra == 'all'", specifier = ">=2.0.0" }, { name = "pyjwt", marker = "extra == 'signing'", specifier = ">=2.0.0" }, @@ -146,7 +140,7 @@ requires-dist = [ { name = "starlette", marker = "extra == 'all'" }, { name = "starlette", marker = "extra == 'http-server'" }, ] -provides-extras = ["all", "db-cli", "encryption", "grpc", "http-server", "mysql", "postgresql", "signing", "sql", "sqlite", "telemetry", "vertex"] +provides-extras = ["all", "db-cli", "encryption", "grpc", "http-server", "mysql", "postgresql", "signing", "sql", "sqlite", "telemetry"] [package.metadata.requires-dev] dev = [ @@ -685,62 +679,62 @@ toml = [ [[package]] name = "cryptography" -version = "46.0.5" +version = "46.0.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/04/ee2a9e8542e4fa2773b81771ff8349ff19cdd56b7258a0cc442639052edb/cryptography-46.0.5.tar.gz", hash = 
"sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d", size = 750064, upload-time = "2026-02-10T19:18:38.255Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/81/b0bb27f2ba931a65409c6b8a8b358a7f03c0e46eceacddff55f7c84b1f3b/cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad", size = 7176289, upload-time = "2026-02-10T19:17:08.274Z" }, - { url = "https://files.pythonhosted.org/packages/ff/9e/6b4397a3e3d15123de3b1806ef342522393d50736c13b20ec4c9ea6693a6/cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b", size = 4275637, upload-time = "2026-02-10T19:17:10.53Z" }, - { url = "https://files.pythonhosted.org/packages/63/e7/471ab61099a3920b0c77852ea3f0ea611c9702f651600397ac567848b897/cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b", size = 4424742, upload-time = "2026-02-10T19:17:12.388Z" }, - { url = "https://files.pythonhosted.org/packages/37/53/a18500f270342d66bf7e4d9f091114e31e5ee9e7375a5aba2e85a91e0044/cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263", size = 4277528, upload-time = "2026-02-10T19:17:13.853Z" }, - { url = "https://files.pythonhosted.org/packages/22/29/c2e812ebc38c57b40e7c583895e73c8c5adb4d1e4a0cc4c5a4fdab2b1acc/cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d", size = 4947993, upload-time = "2026-02-10T19:17:15.618Z" }, - { url = "https://files.pythonhosted.org/packages/6b/e7/237155ae19a9023de7e30ec64e5d99a9431a567407ac21170a046d22a5a3/cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = 
"sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed", size = 4456855, upload-time = "2026-02-10T19:17:17.221Z" }, - { url = "https://files.pythonhosted.org/packages/2d/87/fc628a7ad85b81206738abbd213b07702bcbdada1dd43f72236ef3cffbb5/cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2", size = 3984635, upload-time = "2026-02-10T19:17:18.792Z" }, - { url = "https://files.pythonhosted.org/packages/84/29/65b55622bde135aedf4565dc509d99b560ee4095e56989e815f8fd2aa910/cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2", size = 4277038, upload-time = "2026-02-10T19:17:20.256Z" }, - { url = "https://files.pythonhosted.org/packages/bc/36/45e76c68d7311432741faf1fbf7fac8a196a0a735ca21f504c75d37e2558/cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0", size = 4912181, upload-time = "2026-02-10T19:17:21.825Z" }, - { url = "https://files.pythonhosted.org/packages/6d/1a/c1ba8fead184d6e3d5afcf03d569acac5ad063f3ac9fb7258af158f7e378/cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731", size = 4456482, upload-time = "2026-02-10T19:17:25.133Z" }, - { url = "https://files.pythonhosted.org/packages/f9/e5/3fb22e37f66827ced3b902cf895e6a6bc1d095b5b26be26bd13c441fdf19/cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82", size = 4405497, upload-time = "2026-02-10T19:17:26.66Z" }, - { url = "https://files.pythonhosted.org/packages/1a/df/9d58bb32b1121a8a2f27383fabae4d63080c7ca60b9b5c88be742be04ee7/cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1", size = 
4667819, upload-time = "2026-02-10T19:17:28.569Z" }, - { url = "https://files.pythonhosted.org/packages/ea/ed/325d2a490c5e94038cdb0117da9397ece1f11201f425c4e9c57fe5b9f08b/cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48", size = 3028230, upload-time = "2026-02-10T19:17:30.518Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5a/ac0f49e48063ab4255d9e3b79f5def51697fce1a95ea1370f03dc9db76f6/cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4", size = 3480909, upload-time = "2026-02-10T19:17:32.083Z" }, - { url = "https://files.pythonhosted.org/packages/00/13/3d278bfa7a15a96b9dc22db5a12ad1e48a9eb3d40e1827ef66a5df75d0d0/cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2", size = 7119287, upload-time = "2026-02-10T19:17:33.801Z" }, - { url = "https://files.pythonhosted.org/packages/67/c8/581a6702e14f0898a0848105cbefd20c058099e2c2d22ef4e476dfec75d7/cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678", size = 4265728, upload-time = "2026-02-10T19:17:35.569Z" }, - { url = "https://files.pythonhosted.org/packages/dd/4a/ba1a65ce8fc65435e5a849558379896c957870dd64fecea97b1ad5f46a37/cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87", size = 4408287, upload-time = "2026-02-10T19:17:36.938Z" }, - { url = "https://files.pythonhosted.org/packages/f8/67/8ffdbf7b65ed1ac224d1c2df3943553766914a8ca718747ee3871da6107e/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee", size = 4270291, upload-time = "2026-02-10T19:17:38.748Z" }, - { 
url = "https://files.pythonhosted.org/packages/f8/e5/f52377ee93bc2f2bba55a41a886fd208c15276ffbd2569f2ddc89d50e2c5/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981", size = 4927539, upload-time = "2026-02-10T19:17:40.241Z" }, - { url = "https://files.pythonhosted.org/packages/3b/02/cfe39181b02419bbbbcf3abdd16c1c5c8541f03ca8bda240debc467d5a12/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9", size = 4442199, upload-time = "2026-02-10T19:17:41.789Z" }, - { url = "https://files.pythonhosted.org/packages/c0/96/2fcaeb4873e536cf71421a388a6c11b5bc846e986b2b069c79363dc1648e/cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648", size = 3960131, upload-time = "2026-02-10T19:17:43.379Z" }, - { url = "https://files.pythonhosted.org/packages/d8/d2/b27631f401ddd644e94c5cf33c9a4069f72011821cf3dc7309546b0642a0/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4", size = 4270072, upload-time = "2026-02-10T19:17:45.481Z" }, - { url = "https://files.pythonhosted.org/packages/f4/a7/60d32b0370dae0b4ebe55ffa10e8599a2a59935b5ece1b9f06edb73abdeb/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0", size = 4892170, upload-time = "2026-02-10T19:17:46.997Z" }, - { url = "https://files.pythonhosted.org/packages/d2/b9/cf73ddf8ef1164330eb0b199a589103c363afa0cf794218c24d524a58eab/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663", size = 4441741, upload-time = "2026-02-10T19:17:48.661Z" }, - { url = 
"https://files.pythonhosted.org/packages/5f/eb/eee00b28c84c726fe8fa0158c65afe312d9c3b78d9d01daf700f1f6e37ff/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826", size = 4396728, upload-time = "2026-02-10T19:17:50.058Z" }, - { url = "https://files.pythonhosted.org/packages/65/f4/6bc1a9ed5aef7145045114b75b77c2a8261b4d38717bd8dea111a63c3442/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d", size = 4652001, upload-time = "2026-02-10T19:17:51.54Z" }, - { url = "https://files.pythonhosted.org/packages/86/ef/5d00ef966ddd71ac2e6951d278884a84a40ffbd88948ef0e294b214ae9e4/cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a", size = 3003637, upload-time = "2026-02-10T19:17:52.997Z" }, - { url = "https://files.pythonhosted.org/packages/b7/57/f3f4160123da6d098db78350fdfd9705057aad21de7388eacb2401dceab9/cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4", size = 3469487, upload-time = "2026-02-10T19:17:54.549Z" }, - { url = "https://files.pythonhosted.org/packages/e2/fa/a66aa722105ad6a458bebd64086ca2b72cdd361fed31763d20390f6f1389/cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31", size = 7170514, upload-time = "2026-02-10T19:17:56.267Z" }, - { url = "https://files.pythonhosted.org/packages/0f/04/c85bdeab78c8bc77b701bf0d9bdcf514c044e18a46dcff330df5448631b0/cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18", size = 4275349, upload-time = "2026-02-10T19:17:58.419Z" }, - { url = 
"https://files.pythonhosted.org/packages/5c/32/9b87132a2f91ee7f5223b091dc963055503e9b442c98fc0b8a5ca765fab0/cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235", size = 4420667, upload-time = "2026-02-10T19:18:00.619Z" }, - { url = "https://files.pythonhosted.org/packages/a1/a6/a7cb7010bec4b7c5692ca6f024150371b295ee1c108bdc1c400e4c44562b/cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a", size = 4276980, upload-time = "2026-02-10T19:18:02.379Z" }, - { url = "https://files.pythonhosted.org/packages/8e/7c/c4f45e0eeff9b91e3f12dbd0e165fcf2a38847288fcfd889deea99fb7b6d/cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76", size = 4939143, upload-time = "2026-02-10T19:18:03.964Z" }, - { url = "https://files.pythonhosted.org/packages/37/19/e1b8f964a834eddb44fa1b9a9976f4e414cbb7aa62809b6760c8803d22d1/cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614", size = 4453674, upload-time = "2026-02-10T19:18:05.588Z" }, - { url = "https://files.pythonhosted.org/packages/db/ed/db15d3956f65264ca204625597c410d420e26530c4e2943e05a0d2f24d51/cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229", size = 3978801, upload-time = "2026-02-10T19:18:07.167Z" }, - { url = "https://files.pythonhosted.org/packages/41/e2/df40a31d82df0a70a0daf69791f91dbb70e47644c58581d654879b382d11/cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1", size = 4276755, upload-time = "2026-02-10T19:18:09.813Z" }, - { url = 
"https://files.pythonhosted.org/packages/33/45/726809d1176959f4a896b86907b98ff4391a8aa29c0aaaf9450a8a10630e/cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d", size = 4901539, upload-time = "2026-02-10T19:18:11.263Z" }, - { url = "https://files.pythonhosted.org/packages/99/0f/a3076874e9c88ecb2ecc31382f6e7c21b428ede6f55aafa1aa272613e3cd/cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c", size = 4452794, upload-time = "2026-02-10T19:18:12.914Z" }, - { url = "https://files.pythonhosted.org/packages/02/ef/ffeb542d3683d24194a38f66ca17c0a4b8bf10631feef44a7ef64e631b1a/cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4", size = 4404160, upload-time = "2026-02-10T19:18:14.375Z" }, - { url = "https://files.pythonhosted.org/packages/96/93/682d2b43c1d5f1406ed048f377c0fc9fc8f7b0447a478d5c65ab3d3a66eb/cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9", size = 4667123, upload-time = "2026-02-10T19:18:15.886Z" }, - { url = "https://files.pythonhosted.org/packages/45/2d/9c5f2926cb5300a8eefc3f4f0b3f3df39db7f7ce40c8365444c49363cbda/cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72", size = 3010220, upload-time = "2026-02-10T19:18:17.361Z" }, - { url = "https://files.pythonhosted.org/packages/48/ef/0c2f4a8e31018a986949d34a01115dd057bf536905dca38897bacd21fac3/cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595", size = 3467050, upload-time = "2026-02-10T19:18:18.899Z" }, - { url = 
"https://files.pythonhosted.org/packages/eb/dd/2d9fdb07cebdf3d51179730afb7d5e576153c6744c3ff8fded23030c204e/cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c", size = 3476964, upload-time = "2026-02-10T19:18:20.687Z" }, - { url = "https://files.pythonhosted.org/packages/e9/6f/6cc6cc9955caa6eaf83660b0da2b077c7fe8ff9950a3c5e45d605038d439/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a", size = 4218321, upload-time = "2026-02-10T19:18:22.349Z" }, - { url = "https://files.pythonhosted.org/packages/3e/5d/c4da701939eeee699566a6c1367427ab91a8b7088cc2328c09dbee940415/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356", size = 4381786, upload-time = "2026-02-10T19:18:24.529Z" }, - { url = "https://files.pythonhosted.org/packages/ac/97/a538654732974a94ff96c1db621fa464f455c02d4bb7d2652f4edc21d600/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da", size = 4217990, upload-time = "2026-02-10T19:18:25.957Z" }, - { url = "https://files.pythonhosted.org/packages/ae/11/7e500d2dd3ba891197b9efd2da5454b74336d64a7cc419aa7327ab74e5f6/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257", size = 4381252, upload-time = "2026-02-10T19:18:27.496Z" }, - { url = "https://files.pythonhosted.org/packages/bc/58/6b3d24e6b9bc474a2dcdee65dfd1f008867015408a271562e4b690561a4d/cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7", size = 3407605, upload-time = "2026-02-10T19:18:29.233Z" }, +sdist = { url = 
"https://files.pythonhosted.org/packages/47/93/ac8f3d5ff04d54bc814e961a43ae5b0b146154c89c61b47bb07557679b18/cryptography-46.0.7.tar.gz", hash = "sha256:e4cfd68c5f3e0bfdad0d38e023239b96a2fe84146481852dffbcca442c245aa5", size = 750652, upload-time = "2026-04-08T01:57:54.692Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/5d/4a8f770695d73be252331e60e526291e3df0c9b27556a90a6b47bccca4c2/cryptography-46.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:ea42cbe97209df307fdc3b155f1b6fa2577c0defa8f1f7d3be7d31d189108ad4", size = 7179869, upload-time = "2026-04-08T01:56:17.157Z" }, + { url = "https://files.pythonhosted.org/packages/5f/45/6d80dc379b0bbc1f9d1e429f42e4cb9e1d319c7a8201beffd967c516ea01/cryptography-46.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b36a4695e29fe69215d75960b22577197aca3f7a25b9cf9d165dcfe9d80bc325", size = 4275492, upload-time = "2026-04-08T01:56:19.36Z" }, + { url = "https://files.pythonhosted.org/packages/4a/9a/1765afe9f572e239c3469f2cb429f3ba7b31878c893b246b4b2994ffe2fe/cryptography-46.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ad9ef796328c5e3c4ceed237a183f5d41d21150f972455a9d926593a1dcb308", size = 4426670, upload-time = "2026-04-08T01:56:21.415Z" }, + { url = "https://files.pythonhosted.org/packages/8f/3e/af9246aaf23cd4ee060699adab1e47ced3f5f7e7a8ffdd339f817b446462/cryptography-46.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:73510b83623e080a2c35c62c15298096e2a5dc8d51c3b4e1740211839d0dea77", size = 4280275, upload-time = "2026-04-08T01:56:23.539Z" }, + { url = "https://files.pythonhosted.org/packages/0f/54/6bbbfc5efe86f9d71041827b793c24811a017c6ac0fd12883e4caa86b8ed/cryptography-46.0.7-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cbd5fb06b62bd0721e1170273d3f4d5a277044c47ca27ee257025146c34cbdd1", size = 4928402, upload-time = "2026-04-08T01:56:25.624Z" }, + { url = 
"https://files.pythonhosted.org/packages/2d/cf/054b9d8220f81509939599c8bdbc0c408dbd2bdd41688616a20731371fe0/cryptography-46.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:420b1e4109cc95f0e5700eed79908cef9268265c773d3a66f7af1eef53d409ef", size = 4459985, upload-time = "2026-04-08T01:56:27.309Z" }, + { url = "https://files.pythonhosted.org/packages/f9/46/4e4e9c6040fb01c7467d47217d2f882daddeb8828f7df800cb806d8a2288/cryptography-46.0.7-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:24402210aa54baae71d99441d15bb5a1919c195398a87b563df84468160a65de", size = 3990652, upload-time = "2026-04-08T01:56:29.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/5f/313586c3be5a2fbe87e4c9a254207b860155a8e1f3cca99f9910008e7d08/cryptography-46.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8a469028a86f12eb7d2fe97162d0634026d92a21f3ae0ac87ed1c4a447886c83", size = 4279805, upload-time = "2026-04-08T01:56:30.928Z" }, + { url = "https://files.pythonhosted.org/packages/69/33/60dfc4595f334a2082749673386a4d05e4f0cf4df8248e63b2c3437585f2/cryptography-46.0.7-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9694078c5d44c157ef3162e3bf3946510b857df5a3955458381d1c7cfc143ddb", size = 4892883, upload-time = "2026-04-08T01:56:32.614Z" }, + { url = "https://files.pythonhosted.org/packages/c7/0b/333ddab4270c4f5b972f980adef4faa66951a4aaf646ca067af597f15563/cryptography-46.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:42a1e5f98abb6391717978baf9f90dc28a743b7d9be7f0751a6f56a75d14065b", size = 4459756, upload-time = "2026-04-08T01:56:34.306Z" }, + { url = "https://files.pythonhosted.org/packages/d2/14/633913398b43b75f1234834170947957c6b623d1701ffc7a9600da907e89/cryptography-46.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91bbcb08347344f810cbe49065914fe048949648f6bd5c2519f34619142bbe85", size = 4410244, upload-time = "2026-04-08T01:56:35.977Z" }, + { url = 
"https://files.pythonhosted.org/packages/10/f2/19ceb3b3dc14009373432af0c13f46aa08e3ce334ec6eff13492e1812ccd/cryptography-46.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5d1c02a14ceb9148cc7816249f64f623fbfee39e8c03b3650d842ad3f34d637e", size = 4674868, upload-time = "2026-04-08T01:56:38.034Z" }, + { url = "https://files.pythonhosted.org/packages/1a/bb/a5c213c19ee94b15dfccc48f363738633a493812687f5567addbcbba9f6f/cryptography-46.0.7-cp311-abi3-win32.whl", hash = "sha256:d23c8ca48e44ee015cd0a54aeccdf9f09004eba9fc96f38c911011d9ff1bd457", size = 3026504, upload-time = "2026-04-08T01:56:39.666Z" }, + { url = "https://files.pythonhosted.org/packages/2b/02/7788f9fefa1d060ca68717c3901ae7fffa21ee087a90b7f23c7a603c32ae/cryptography-46.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:397655da831414d165029da9bc483bed2fe0e75dde6a1523ec2fe63f3c46046b", size = 3488363, upload-time = "2026-04-08T01:56:41.893Z" }, + { url = "https://files.pythonhosted.org/packages/7b/56/15619b210e689c5403bb0540e4cb7dbf11a6bf42e483b7644e471a2812b3/cryptography-46.0.7-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:d151173275e1728cf7839aaa80c34fe550c04ddb27b34f48c232193df8db5842", size = 7119671, upload-time = "2026-04-08T01:56:44Z" }, + { url = "https://files.pythonhosted.org/packages/74/66/e3ce040721b0b5599e175ba91ab08884c75928fbeb74597dd10ef13505d2/cryptography-46.0.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:db0f493b9181c7820c8134437eb8b0b4792085d37dbb24da050476ccb664e59c", size = 4268551, upload-time = "2026-04-08T01:56:46.071Z" }, + { url = "https://files.pythonhosted.org/packages/03/11/5e395f961d6868269835dee1bafec6a1ac176505a167f68b7d8818431068/cryptography-46.0.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ebd6daf519b9f189f85c479427bbd6e9c9037862cf8fe89ee35503bd209ed902", size = 4408887, upload-time = "2026-04-08T01:56:47.718Z" }, + { url = 
"https://files.pythonhosted.org/packages/40/53/8ed1cf4c3b9c8e611e7122fb56f1c32d09e1fff0f1d77e78d9ff7c82653e/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:b7b412817be92117ec5ed95f880defe9cf18a832e8cafacf0a22337dc1981b4d", size = 4271354, upload-time = "2026-04-08T01:56:49.312Z" }, + { url = "https://files.pythonhosted.org/packages/50/46/cf71e26025c2e767c5609162c866a78e8a2915bbcfa408b7ca495c6140c4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:fbfd0e5f273877695cb93baf14b185f4878128b250cc9f8e617ea0c025dfb022", size = 4905845, upload-time = "2026-04-08T01:56:50.916Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ea/01276740375bac6249d0a971ebdf6b4dc9ead0ee0a34ef3b5a88c1a9b0d4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ffca7aa1d00cf7d6469b988c581598f2259e46215e0140af408966a24cf086ce", size = 4444641, upload-time = "2026-04-08T01:56:52.882Z" }, + { url = "https://files.pythonhosted.org/packages/3d/4c/7d258f169ae71230f25d9f3d06caabcff8c3baf0978e2b7d65e0acac3827/cryptography-46.0.7-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:60627cf07e0d9274338521205899337c5d18249db56865f943cbe753aa96f40f", size = 3967749, upload-time = "2026-04-08T01:56:54.597Z" }, + { url = "https://files.pythonhosted.org/packages/b5/2a/2ea0767cad19e71b3530e4cad9605d0b5e338b6a1e72c37c9c1ceb86c333/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:80406c3065e2c55d7f49a9550fe0c49b3f12e5bfff5dedb727e319e1afb9bf99", size = 4270942, upload-time = "2026-04-08T01:56:56.416Z" }, + { url = "https://files.pythonhosted.org/packages/41/3d/fe14df95a83319af25717677e956567a105bb6ab25641acaa093db79975d/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:c5b1ccd1239f48b7151a65bc6dd54bcfcc15e028c8ac126d3fada09db0e07ef1", size = 4871079, upload-time = "2026-04-08T01:56:58.31Z" }, + { url = 
"https://files.pythonhosted.org/packages/9c/59/4a479e0f36f8f378d397f4eab4c850b4ffb79a2f0d58704b8fa0703ddc11/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:d5f7520159cd9c2154eb61eb67548ca05c5774d39e9c2c4339fd793fe7d097b2", size = 4443999, upload-time = "2026-04-08T01:57:00.508Z" }, + { url = "https://files.pythonhosted.org/packages/28/17/b59a741645822ec6d04732b43c5d35e4ef58be7bfa84a81e5ae6f05a1d33/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fcd8eac50d9138c1d7fc53a653ba60a2bee81a505f9f8850b6b2888555a45d0e", size = 4399191, upload-time = "2026-04-08T01:57:02.654Z" }, + { url = "https://files.pythonhosted.org/packages/59/6a/bb2e166d6d0e0955f1e9ff70f10ec4b2824c9cfcdb4da772c7dd69cc7d80/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:65814c60f8cc400c63131584e3e1fad01235edba2614b61fbfbfa954082db0ee", size = 4655782, upload-time = "2026-04-08T01:57:04.592Z" }, + { url = "https://files.pythonhosted.org/packages/95/b6/3da51d48415bcb63b00dc17c2eff3a651b7c4fed484308d0f19b30e8cb2c/cryptography-46.0.7-cp314-cp314t-win32.whl", hash = "sha256:fdd1736fed309b4300346f88f74cd120c27c56852c3838cab416e7a166f67298", size = 3002227, upload-time = "2026-04-08T01:57:06.91Z" }, + { url = "https://files.pythonhosted.org/packages/32/a8/9f0e4ed57ec9cebe506e58db11ae472972ecb0c659e4d52bbaee80ca340a/cryptography-46.0.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e06acf3c99be55aa3b516397fe42f5855597f430add9c17fa46bf2e0fb34c9bb", size = 3475332, upload-time = "2026-04-08T01:57:08.807Z" }, + { url = "https://files.pythonhosted.org/packages/a7/7f/cd42fc3614386bc0c12f0cb3c4ae1fc2bbca5c9662dfed031514911d513d/cryptography-46.0.7-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:462ad5cb1c148a22b2e3bcc5ad52504dff325d17daf5df8d88c17dda1f75f2a4", size = 7165618, upload-time = "2026-04-08T01:57:10.645Z" }, + { url = 
"https://files.pythonhosted.org/packages/a5/d0/36a49f0262d2319139d2829f773f1b97ef8aef7f97e6e5bd21455e5a8fb5/cryptography-46.0.7-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:84d4cced91f0f159a7ddacad249cc077e63195c36aac40b4150e7a57e84fffe7", size = 4270628, upload-time = "2026-04-08T01:57:12.885Z" }, + { url = "https://files.pythonhosted.org/packages/8a/6c/1a42450f464dda6ffbe578a911f773e54dd48c10f9895a23a7e88b3e7db5/cryptography-46.0.7-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:128c5edfe5e5938b86b03941e94fac9ee793a94452ad1365c9fc3f4f62216832", size = 4415405, upload-time = "2026-04-08T01:57:14.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/4ed714dbe93a066dc1f4b4581a464d2d7dbec9046f7c8b7016f5286329e2/cryptography-46.0.7-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5e51be372b26ef4ba3de3c167cd3d1022934bc838ae9eaad7e644986d2a3d163", size = 4272715, upload-time = "2026-04-08T01:57:16.638Z" }, + { url = "https://files.pythonhosted.org/packages/b7/e6/a26b84096eddd51494bba19111f8fffe976f6a09f132706f8f1bf03f51f7/cryptography-46.0.7-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cdf1a610ef82abb396451862739e3fc93b071c844399e15b90726ef7470eeaf2", size = 4918400, upload-time = "2026-04-08T01:57:19.021Z" }, + { url = "https://files.pythonhosted.org/packages/c7/08/ffd537b605568a148543ac3c2b239708ae0bd635064bab41359252ef88ed/cryptography-46.0.7-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1d25aee46d0c6f1a501adcddb2d2fee4b979381346a78558ed13e50aa8a59067", size = 4450634, upload-time = "2026-04-08T01:57:21.185Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/0cd51dd86ab5b9befe0d031e276510491976c3a80e9f6e31810cce46c4ad/cryptography-46.0.7-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:cdfbe22376065ffcf8be74dc9a909f032df19bc58a699456a21712d6e5eabfd0", size = 3985233, upload-time = "2026-04-08T01:57:22.862Z" }, + { url = 
"https://files.pythonhosted.org/packages/92/49/819d6ed3a7d9349c2939f81b500a738cb733ab62fbecdbc1e38e83d45e12/cryptography-46.0.7-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:abad9dac36cbf55de6eb49badd4016806b3165d396f64925bf2999bcb67837ba", size = 4271955, upload-time = "2026-04-08T01:57:24.814Z" }, + { url = "https://files.pythonhosted.org/packages/80/07/ad9b3c56ebb95ed2473d46df0847357e01583f4c52a85754d1a55e29e4d0/cryptography-46.0.7-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:935ce7e3cfdb53e3536119a542b839bb94ec1ad081013e9ab9b7cfd478b05006", size = 4879888, upload-time = "2026-04-08T01:57:26.88Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c7/201d3d58f30c4c2bdbe9b03844c291feb77c20511cc3586daf7edc12a47b/cryptography-46.0.7-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:35719dc79d4730d30f1c2b6474bd6acda36ae2dfae1e3c16f2051f215df33ce0", size = 4449961, upload-time = "2026-04-08T01:57:29.068Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ef/649750cbf96f3033c3c976e112265c33906f8e462291a33d77f90356548c/cryptography-46.0.7-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bbc6ccf49d05ac8f7d7b5e2e2c33830d4fe2061def88210a126d130d7f71a85", size = 4401696, upload-time = "2026-04-08T01:57:31.029Z" }, + { url = "https://files.pythonhosted.org/packages/41/52/a8908dcb1a389a459a29008c29966c1d552588d4ae6d43f3a1a4512e0ebe/cryptography-46.0.7-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a1529d614f44b863a7b480c6d000fe93b59acee9c82ffa027cfadc77521a9f5e", size = 4664256, upload-time = "2026-04-08T01:57:33.144Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fa/f0ab06238e899cc3fb332623f337a7364f36f4bb3f2534c2bb95a35b132c/cryptography-46.0.7-cp38-abi3-win32.whl", hash = "sha256:f247c8c1a1fb45e12586afbb436ef21ff1e80670b2861a90353d9b025583d246", size = 3013001, upload-time = "2026-04-08T01:57:34.933Z" }, + { url = 
"https://files.pythonhosted.org/packages/d2/f1/00ce3bde3ca542d1acd8f8cfa38e446840945aa6363f9b74746394b14127/cryptography-46.0.7-cp38-abi3-win_amd64.whl", hash = "sha256:506c4ff91eff4f82bdac7633318a526b1d1309fc07ca76a3ad182cb5b686d6d3", size = 3472985, upload-time = "2026-04-08T01:57:36.714Z" }, + { url = "https://files.pythonhosted.org/packages/63/0c/dca8abb64e7ca4f6b2978769f6fea5ad06686a190cec381f0a796fdcaaba/cryptography-46.0.7-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:fc9ab8856ae6cf7c9358430e49b368f3108f050031442eaeb6b9d87e4dcf4e4f", size = 3476879, upload-time = "2026-04-08T01:57:38.664Z" }, + { url = "https://files.pythonhosted.org/packages/3a/ea/075aac6a84b7c271578d81a2f9968acb6e273002408729f2ddff517fed4a/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d3b99c535a9de0adced13d159c5a9cf65c325601aa30f4be08afd680643e9c15", size = 4219700, upload-time = "2026-04-08T01:57:40.625Z" }, + { url = "https://files.pythonhosted.org/packages/6c/7b/1c55db7242b5e5612b29fc7a630e91ee7a6e3c8e7bf5406d22e206875fbd/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d02c738dacda7dc2a74d1b2b3177042009d5cab7c7079db74afc19e56ca1b455", size = 4385982, upload-time = "2026-04-08T01:57:42.725Z" }, + { url = "https://files.pythonhosted.org/packages/cb/da/9870eec4b69c63ef5925bf7d8342b7e13bc2ee3d47791461c4e49ca212f4/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:04959522f938493042d595a736e7dbdff6eb6cc2339c11465b3ff89343b65f65", size = 4219115, upload-time = "2026-04-08T01:57:44.939Z" }, + { url = "https://files.pythonhosted.org/packages/f4/72/05aa5832b82dd341969e9a734d1812a6aadb088d9eb6f0430fc337cc5a8f/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3986ac1dee6def53797289999eabe84798ad7817f3e97779b5061a95b0ee4968", size = 4385479, upload-time = "2026-04-08T01:57:46.86Z" }, + { url = 
"https://files.pythonhosted.org/packages/20/2a/1b016902351a523aa2bd446b50a5bc1175d7a7d1cf90fe2ef904f9b84ebc/cryptography-46.0.7-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:258514877e15963bd43b558917bc9f54cf7cf866c38aa576ebf47a77ddbc43a4", size = 3412829, upload-time = "2026-04-08T01:57:48.874Z" }, ] [[package]] @@ -765,24 +759,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] -[[package]] -name = "distro" -version = "1.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722, upload-time = "2023-12-24T09:54:32.31Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, -] - -[[package]] -name = "docstring-parser" -version = "0.17.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = 
"sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, -] - [[package]] name = "dunamai" version = "1.26.0" @@ -857,12 +833,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/09c33d67f7e0dcf06d7ac17d196594e66989299374bfb0d4331d1038e76b/google_api_core-2.30.0-py3-none-any.whl", hash = "sha256:80be49ee937ff9aba0fd79a6eddfde35fe658b9953ab9b79c57dd7061afa8df5", size = 173288, upload-time = "2026-02-18T20:28:10.367Z" }, ] -[package.optional-dependencies] -grpc = [ - { name = "grpcio" }, - { name = "grpcio-status" }, -] - [[package]] name = "google-auth" version = "2.49.1" @@ -876,167 +846,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/eb/c6c2478d8a8d633460be40e2a8a6f8f429171997a35a96f81d3b680dec83/google_auth-2.49.1-py3-none-any.whl", hash = "sha256:195ebe3dca18eddd1b3db5edc5189b76c13e96f29e73043b923ebcf3f1a860f7", size = 240737, upload-time = "2026-03-12T19:30:53.159Z" }, ] -[package.optional-dependencies] -requests = [ - { name = "requests" }, -] - -[[package]] -name = "google-cloud-aiplatform" -version = "1.141.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "docstring-parser" }, - { name = "google-api-core", extra = ["grpc"] }, - { name = "google-auth" }, - { name = "google-cloud-bigquery" }, - { name = "google-cloud-resource-manager" }, - { name = "google-cloud-storage" }, - { name = "google-genai" }, - { name = "packaging" }, - { name = "proto-plus" }, - { name = "protobuf" }, - { name = "pydantic" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ac/dc/1209c7aab43bd7233cf631165a3b1b4284d22fc7fe7387c66228d07868ab/google_cloud_aiplatform-1.141.0.tar.gz", hash = "sha256:e3b1cdb28865dd862aac9c685dfc5ac076488705aba0a5354016efadcddd59c6", size = 10152688, upload-time = "2026-03-10T22:20:08.692Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/6a/fc/428af69a69ff2e477e7f5e12d227b31fe5790f1a8234aacd54297f49c836/google_cloud_aiplatform-1.141.0-py2.py3-none-any.whl", hash = "sha256:6bd25b4d514c40b8181ca703e1b313ad6d0454ab8006fc9907fb3e9f672f31d1", size = 8358409, upload-time = "2026-03-10T22:20:04.871Z" }, -] - -[[package]] -name = "google-cloud-bigquery" -version = "3.40.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core", extra = ["grpc"] }, - { name = "google-auth" }, - { name = "google-cloud-core" }, - { name = "google-resumable-media" }, - { name = "packaging" }, - { name = "python-dateutil" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/11/0c/153ee546c288949fcc6794d58811ab5420f3ecad5fa7f9e73f78d9512a6e/google_cloud_bigquery-3.40.1.tar.gz", hash = "sha256:75afcfb6e007238fe1deefb2182105249321145ff921784fe7b1de2b4ba24506", size = 511761, upload-time = "2026-02-12T18:44:18.958Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7c/f5/081cf5b90adfe524ae0d671781b0d497a75a0f2601d075af518828e22d8f/google_cloud_bigquery-3.40.1-py3-none-any.whl", hash = "sha256:9082a6b8193aba87bed6a2c79cf1152b524c99bb7e7ac33a785e333c09eac868", size = 262018, upload-time = "2026-02-12T18:44:16.913Z" }, -] - -[[package]] -name = "google-cloud-core" -version = "2.5.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a6/03/ef0bc99d0e0faf4fdbe67ac445e18cdaa74824fd93cd069e7bb6548cb52d/google_cloud_core-2.5.0.tar.gz", hash = "sha256:7c1b7ef5c92311717bd05301aa1a91ffbc565673d3b0b4163a52d8413a186963", size = 36027, upload-time = "2025-10-29T23:17:39.513Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/89/20/bfa472e327c8edee00f04beecc80baeddd2ab33ee0e86fd7654da49d45e9/google_cloud_core-2.5.0-py3-none-any.whl", hash = 
"sha256:67d977b41ae6c7211ee830c7912e41003ea8194bff15ae7d72fd6f51e57acabc", size = 29469, upload-time = "2025-10-29T23:17:38.548Z" }, -] - -[[package]] -name = "google-cloud-resource-manager" -version = "1.16.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core", extra = ["grpc"] }, - { name = "google-auth" }, - { name = "grpc-google-iam-v1" }, - { name = "grpcio" }, - { name = "proto-plus" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4e/7f/db00b2820475793a52958dc55fe9ec2eb8e863546e05fcece9b921f86ebe/google_cloud_resource_manager-1.16.0.tar.gz", hash = "sha256:cc938f87cc36c2672f062b1e541650629e0d954c405a4dac35ceedee70c267c3", size = 459840, upload-time = "2026-01-15T13:04:07.726Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/94/ff/4b28bcc791d9d7e4ac8fea00fbd90ccb236afda56746a3b4564d2ae45df3/google_cloud_resource_manager-1.16.0-py3-none-any.whl", hash = "sha256:fb9a2ad2b5053c508e1c407ac31abfd1a22e91c32876c1892830724195819a28", size = 400218, upload-time = "2026-01-15T13:02:47.378Z" }, -] - -[[package]] -name = "google-cloud-storage" -version = "3.10.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, - { name = "google-cloud-core" }, - { name = "google-crc32c" }, - { name = "google-resumable-media" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/7a/e3/747759eebc72e420c25903d6bc231d0ceb110b66ac7e6ee3f350417152cd/google_cloud_storage-3.10.0.tar.gz", hash = "sha256:1aeebf097c27d718d84077059a28d7e87f136f3700212215f1ceeae1d1c5d504", size = 17309829, upload-time = "2026-03-18T15:54:11.875Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/29/e2/d58442f4daee5babd9255cf492a1f3d114357164072f8339a22a3ad460a2/google_cloud_storage-3.10.0-py3-none-any.whl", hash = 
"sha256:0072e7783b201e45af78fd9779894cdb6bec2bf922ee932f3fcc16f8bce9b9a3", size = 324382, upload-time = "2026-03-18T15:54:10.091Z" }, -] - -[[package]] -name = "google-crc32c" -version = "1.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/95/ac/6f7bc93886a823ab545948c2dd48143027b2355ad1944c7cf852b338dc91/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0470b8c3d73b5f4e3300165498e4cf25221c7eb37f1159e221d1825b6df8a7ff", size = 31296, upload-time = "2025-12-16T00:19:07.261Z" }, - { url = "https://files.pythonhosted.org/packages/f7/97/a5accde175dee985311d949cfcb1249dcbb290f5ec83c994ea733311948f/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:119fcd90c57c89f30040b47c211acee231b25a45d225e3225294386f5d258288", size = 30870, upload-time = "2025-12-16T00:29:17.669Z" }, - { url = "https://files.pythonhosted.org/packages/3d/63/bec827e70b7a0d4094e7476f863c0dbd6b5f0f1f91d9c9b32b76dcdfeb4e/google_crc32c-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6f35aaffc8ccd81ba3162443fabb920e65b1f20ab1952a31b13173a67811467d", size = 33214, upload-time = "2025-12-16T00:40:19.618Z" }, - { url = "https://files.pythonhosted.org/packages/63/bc/11b70614df04c289128d782efc084b9035ef8466b3d0a8757c1b6f5cf7ac/google_crc32c-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:864abafe7d6e2c4c66395c1eb0fe12dc891879769b52a3d56499612ca93b6092", size = 33589, upload-time = "2025-12-16T00:40:20.7Z" }, - { url = 
"https://files.pythonhosted.org/packages/3e/00/a08a4bc24f1261cc5b0f47312d8aebfbe4b53c2e6307f1b595605eed246b/google_crc32c-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:db3fe8eaf0612fc8b20fa21a5f25bd785bc3cd5be69f8f3412b0ac2ffd49e733", size = 34437, upload-time = "2025-12-16T00:35:19.437Z" }, - { url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" }, - { url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" }, - { url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" }, - { url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" }, - { url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" }, - { url = 
"https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, - { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, - { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, - { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, - { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, - { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, - { url = 
"https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, - { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, - { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, - { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, - { url = "https://files.pythonhosted.org/packages/21/88/8ecf3c2b864a490b9e7010c84fd203ec8cf3b280651106a3a74dd1b0ca72/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697", size = 31301, upload-time = "2025-12-16T00:24:48.527Z" }, - { url = "https://files.pythonhosted.org/packages/36/c6/f7ff6c11f5ca215d9f43d3629163727a272eabc356e5c9b2853df2bfe965/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651", size = 30868, upload-time = "2025-12-16T00:48:12.163Z" }, - { url = 
"https://files.pythonhosted.org/packages/56/15/c25671c7aad70f8179d858c55a6ae8404902abe0cdcf32a29d581792b491/google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2", size = 33381, upload-time = "2025-12-16T00:40:26.268Z" }, - { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, - { url = "https://files.pythonhosted.org/packages/08/a5/7b059810934a09fb3ccb657e0843813c1fee1183d3bc2c8041800374aa2c/google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2", size = 34878, upload-time = "2025-12-16T00:35:23.142Z" }, - { url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" }, - { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, -] - -[[package]] -name = "google-genai" -version = "1.68.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "distro" }, - { name = "google-auth", extra = ["requests"] }, - { name = "httpx" }, - { 
name = "pydantic" }, - { name = "requests" }, - { name = "sniffio" }, - { name = "tenacity" }, - { name = "typing-extensions" }, - { name = "websockets" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9c/2c/f059982dbcb658cc535c81bbcbe7e2c040d675f4b563b03cdb01018a4bc3/google_genai-1.68.0.tar.gz", hash = "sha256:ac30c0b8bc630f9372993a97e4a11dae0e36f2e10d7c55eacdca95a9fa14ca96", size = 511285, upload-time = "2026-03-18T01:03:18.243Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/84/de/7d3ee9c94b74c3578ea4f88d45e8de9405902f857932334d81e89bce3dfa/google_genai-1.68.0-py3-none-any.whl", hash = "sha256:a1bc9919c0e2ea2907d1e319b65471d3d6d58c54822039a249fe1323e4178d15", size = 750912, upload-time = "2026-03-18T01:03:15.983Z" }, -] - -[[package]] -name = "google-resumable-media" -version = "2.8.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-crc32c" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/64/d7/520b62a35b23038ff005e334dba3ffc75fcf583bee26723f1fd8fd4b6919/google_resumable_media-2.8.0.tar.gz", hash = "sha256:f1157ed8b46994d60a1bc432544db62352043113684d4e030ee02e77ebe9a1ae", size = 2163265, upload-time = "2025-11-17T15:38:06.659Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/0b/93afde9cfe012260e9fe1522f35c9b72d6ee222f316586b1f23ecf44d518/google_resumable_media-2.8.0-py3-none-any.whl", hash = "sha256:dd14a116af303845a8d932ddae161a26e86cc229645bc98b39f026f9b1717582", size = 81340, upload-time = "2025-11-17T15:38:05.594Z" }, -] - [[package]] name = "googleapis-common-protos" version = "1.73.0" @@ -1049,11 +858,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/69/28/23eea8acd65972bbfe295ce3666b28ac510dfcb115fac089d3edb0feb00a/googleapis_common_protos-1.73.0-py3-none-any.whl", hash = "sha256:dfdaaa2e860f242046be561e6d6cb5c5f1541ae02cfbcb034371aadb2942b4e8", size = 297578, upload-time = "2026-03-06T21:52:33.933Z" }, ] 
-[package.optional-dependencies] -grpc = [ - { name = "grpcio" }, -] - [[package]] name = "greenlet" version = "3.3.2" @@ -1114,20 +918,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/4b/45d90626aef8e65336bed690106d1382f7a43665e2249017e9527df8823b/greenlet-3.3.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c04c5e06ec3e022cbfe2cd4a846e1d4e50087444f875ff6d2c2ad8445495cf1a", size = 237086, upload-time = "2026-02-20T20:20:45.786Z" }, ] -[[package]] -name = "grpc-google-iam-v1" -version = "0.14.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "googleapis-common-protos", extra = ["grpc"] }, - { name = "grpcio" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/76/1e/1011451679a983f2f5c6771a1682542ecb027776762ad031fd0d7129164b/grpc_google_iam_v1-0.14.3.tar.gz", hash = "sha256:879ac4ef33136c5491a6300e27575a9ec760f6cdf9a2518798c1b8977a5dc389", size = 23745, upload-time = "2025-10-15T21:14:53.318Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4a/bd/330a1bbdb1afe0b96311249e699b6dc9cfc17916394fd4503ac5aca2514b/grpc_google_iam_v1-0.14.3-py3-none-any.whl", hash = "sha256:7a7f697e017a067206a3dfef44e4c634a34d3dee135fe7d7a4613fe3e59217e6", size = 32690, upload-time = "2025-10-15T21:14:51.72Z" }, -] - [[package]] name = "grpcio" version = "1.78.0" @@ -1986,7 +1776,7 @@ wheels = [ [[package]] name = "pytest" -version = "9.0.2" +version = "9.0.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, @@ -1997,9 +1787,9 @@ dependencies = [ { name = "pygments" }, { name = "tomli", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } 
+sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, + { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, ] [[package]] @@ -2067,18 +1857,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, ] -[[package]] -name = "python-dateutil" -version = "2.9.0.post0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" 
}, -] - [[package]] name = "python-discovery" version = "1.2.0" @@ -2158,7 +1936,7 @@ wheels = [ [[package]] name = "requests" -version = "2.32.5" +version = "2.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -2166,9 +1944,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" }, ] [[package]] @@ -2217,15 +1995,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" }, ] -[[package]] -name = "six" -version = "1.17.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, -] - [[package]] name = "sniffio" version = "1.3.1" @@ -2348,15 +2117,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, ] -[[package]] -name = "tenacity" -version = "9.1.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, -] - [[package]] name = "tomli" version = "2.4.0" @@ -2543,74 +2303,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/59/7d02447a55b2e55755011a647479041bc92a82e143f96a8195cb33bd0a1c/virtualenv-21.2.0-py3-none-any.whl", hash = "sha256:1bd755b504931164a5a496d217c014d098426cddc79363ad66ac78125f9d908f", size = 5825084, upload-time = 
"2026-03-09T17:24:35.378Z" }, ] -[[package]] -name = "websockets" -version = "16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/20/74/221f58decd852f4b59cc3354cccaf87e8ef695fede361d03dc9a7396573b/websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a", size = 177343, upload-time = "2026-01-10T09:22:21.28Z" }, - { url = "https://files.pythonhosted.org/packages/19/0f/22ef6107ee52ab7f0b710d55d36f5a5d3ef19e8a205541a6d7ffa7994e5a/websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0", size = 175021, upload-time = "2026-01-10T09:22:22.696Z" }, - { url = "https://files.pythonhosted.org/packages/10/40/904a4cb30d9b61c0e278899bf36342e9b0208eb3c470324a9ecbaac2a30f/websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957", size = 175320, upload-time = "2026-01-10T09:22:23.94Z" }, - { url = "https://files.pythonhosted.org/packages/9d/2f/4b3ca7e106bc608744b1cdae041e005e446124bebb037b18799c2d356864/websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72", size = 183815, upload-time = "2026-01-10T09:22:25.469Z" }, - { url = "https://files.pythonhosted.org/packages/86/26/d40eaa2a46d4302becec8d15b0fc5e45bdde05191e7628405a19cf491ccd/websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde", size = 185054, upload-time = "2026-01-10T09:22:27.101Z" }, - { url = "https://files.pythonhosted.org/packages/b0/ba/6500a0efc94f7373ee8fefa8c271acdfd4dca8bd49a90d4be7ccabfc397e/websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3", size = 184565, upload-time = "2026-01-10T09:22:28.293Z" }, - { url = "https://files.pythonhosted.org/packages/04/b4/96bf2cee7c8d8102389374a2616200574f5f01128d1082f44102140344cc/websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3", size = 183848, upload-time = "2026-01-10T09:22:30.394Z" }, - { url = "https://files.pythonhosted.org/packages/02/8e/81f40fb00fd125357814e8c3025738fc4ffc3da4b6b4a4472a82ba304b41/websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9", size = 178249, upload-time = "2026-01-10T09:22:32.083Z" }, - { url = "https://files.pythonhosted.org/packages/b4/5f/7e40efe8df57db9b91c88a43690ac66f7b7aa73a11aa6a66b927e44f26fa/websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35", size = 178685, upload-time = "2026-01-10T09:22:33.345Z" }, - { url = "https://files.pythonhosted.org/packages/f2/db/de907251b4ff46ae804ad0409809504153b3f30984daf82a1d84a9875830/websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8", size = 177340, upload-time = "2026-01-10T09:22:34.539Z" }, - { url = "https://files.pythonhosted.org/packages/f3/fa/abe89019d8d8815c8781e90d697dec52523fb8ebe308bf11664e8de1877e/websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad", size = 175022, upload-time = "2026-01-10T09:22:36.332Z" }, - { 
url = "https://files.pythonhosted.org/packages/58/5d/88ea17ed1ded2079358b40d31d48abe90a73c9e5819dbcde1606e991e2ad/websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d", size = 175319, upload-time = "2026-01-10T09:22:37.602Z" }, - { url = "https://files.pythonhosted.org/packages/d2/ae/0ee92b33087a33632f37a635e11e1d99d429d3d323329675a6022312aac2/websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe", size = 184631, upload-time = "2026-01-10T09:22:38.789Z" }, - { url = "https://files.pythonhosted.org/packages/c8/c5/27178df583b6c5b31b29f526ba2da5e2f864ecc79c99dae630a85d68c304/websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b", size = 185870, upload-time = "2026-01-10T09:22:39.893Z" }, - { url = "https://files.pythonhosted.org/packages/87/05/536652aa84ddc1c018dbb7e2c4cbcd0db884580bf8e95aece7593fde526f/websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5", size = 185361, upload-time = "2026-01-10T09:22:41.016Z" }, - { url = "https://files.pythonhosted.org/packages/6d/e2/d5332c90da12b1e01f06fb1b85c50cfc489783076547415bf9f0a659ec19/websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64", size = 184615, upload-time = "2026-01-10T09:22:42.442Z" }, - { url = "https://files.pythonhosted.org/packages/77/fb/d3f9576691cae9253b51555f841bc6600bf0a983a461c79500ace5a5b364/websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6", size = 178246, upload-time = "2026-01-10T09:22:43.654Z" }, - { url = 
"https://files.pythonhosted.org/packages/54/67/eaff76b3dbaf18dcddabc3b8c1dba50b483761cccff67793897945b37408/websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac", size = 178684, upload-time = "2026-01-10T09:22:44.941Z" }, - { url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" }, - { url = "https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" }, - { url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" }, - { url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" }, - { url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" }, - { url = 
"https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" }, - { url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" }, - { url = "https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" }, - { url = "https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" }, - { url = "https://files.pythonhosted.org/packages/cc/9c/baa8456050d1c1b08dd0ec7346026668cbc6f145ab4e314d707bb845bf0d/websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9", size = 177364, upload-time = "2026-01-10T09:22:59.333Z" }, - { url = "https://files.pythonhosted.org/packages/7e/0c/8811fc53e9bcff68fe7de2bcbe75116a8d959ac699a3200f4847a8925210/websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230", size = 175039, upload-time = "2026-01-10T09:23:01.171Z" }, - { url = 
"https://files.pythonhosted.org/packages/aa/82/39a5f910cb99ec0b59e482971238c845af9220d3ab9fa76dd9162cda9d62/websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c", size = 175323, upload-time = "2026-01-10T09:23:02.341Z" }, - { url = "https://files.pythonhosted.org/packages/bd/28/0a25ee5342eb5d5f297d992a77e56892ecb65e7854c7898fb7d35e9b33bd/websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5", size = 184975, upload-time = "2026-01-10T09:23:03.756Z" }, - { url = "https://files.pythonhosted.org/packages/f9/66/27ea52741752f5107c2e41fda05e8395a682a1e11c4e592a809a90c6a506/websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82", size = 186203, upload-time = "2026-01-10T09:23:05.01Z" }, - { url = "https://files.pythonhosted.org/packages/37/e5/8e32857371406a757816a2b471939d51c463509be73fa538216ea52b792a/websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8", size = 185653, upload-time = "2026-01-10T09:23:06.301Z" }, - { url = "https://files.pythonhosted.org/packages/9b/67/f926bac29882894669368dc73f4da900fcdf47955d0a0185d60103df5737/websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f", size = 184920, upload-time = "2026-01-10T09:23:07.492Z" }, - { url = "https://files.pythonhosted.org/packages/3c/a1/3d6ccdcd125b0a42a311bcd15a7f705d688f73b2a22d8cf1c0875d35d34a/websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a", size = 178255, upload-time = "2026-01-10T09:23:09.245Z" }, - { url = 
"https://files.pythonhosted.org/packages/6b/ae/90366304d7c2ce80f9b826096a9e9048b4bb760e44d3b873bb272cba696b/websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156", size = 178689, upload-time = "2026-01-10T09:23:10.483Z" }, - { url = "https://files.pythonhosted.org/packages/f3/1d/e88022630271f5bd349ed82417136281931e558d628dd52c4d8621b4a0b2/websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0", size = 177406, upload-time = "2026-01-10T09:23:12.178Z" }, - { url = "https://files.pythonhosted.org/packages/f2/78/e63be1bf0724eeb4616efb1ae1c9044f7c3953b7957799abb5915bffd38e/websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904", size = 175085, upload-time = "2026-01-10T09:23:13.511Z" }, - { url = "https://files.pythonhosted.org/packages/bb/f4/d3c9220d818ee955ae390cf319a7c7a467beceb24f05ee7aaaa2414345ba/websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4", size = 175328, upload-time = "2026-01-10T09:23:14.727Z" }, - { url = "https://files.pythonhosted.org/packages/63/bc/d3e208028de777087e6fb2b122051a6ff7bbcca0d6df9d9c2bf1dd869ae9/websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e", size = 185044, upload-time = "2026-01-10T09:23:15.939Z" }, - { url = "https://files.pythonhosted.org/packages/ad/6e/9a0927ac24bd33a0a9af834d89e0abc7cfd8e13bed17a86407a66773cc0e/websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4", size = 186279, upload-time = "2026-01-10T09:23:17.148Z" }, - { url = 
"https://files.pythonhosted.org/packages/b9/ca/bf1c68440d7a868180e11be653c85959502efd3a709323230314fda6e0b3/websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1", size = 185711, upload-time = "2026-01-10T09:23:18.372Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f8/fdc34643a989561f217bb477cbc47a3a07212cbda91c0e4389c43c296ebf/websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3", size = 184982, upload-time = "2026-01-10T09:23:19.652Z" }, - { url = "https://files.pythonhosted.org/packages/dd/d1/574fa27e233764dbac9c52730d63fcf2823b16f0856b3329fc6268d6ae4f/websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8", size = 177915, upload-time = "2026-01-10T09:23:21.458Z" }, - { url = "https://files.pythonhosted.org/packages/8a/f1/ae6b937bf3126b5134ce1f482365fde31a357c784ac51852978768b5eff4/websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d", size = 178381, upload-time = "2026-01-10T09:23:22.715Z" }, - { url = "https://files.pythonhosted.org/packages/06/9b/f791d1db48403e1f0a27577a6beb37afae94254a8c6f08be4a23e4930bc0/websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244", size = 177737, upload-time = "2026-01-10T09:23:24.523Z" }, - { url = "https://files.pythonhosted.org/packages/bd/40/53ad02341fa33b3ce489023f635367a4ac98b73570102ad2cdd770dacc9a/websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e", size = 175268, upload-time = "2026-01-10T09:23:25.781Z" }, - { url = 
"https://files.pythonhosted.org/packages/74/9b/6158d4e459b984f949dcbbb0c5d270154c7618e11c01029b9bbd1bb4c4f9/websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641", size = 175486, upload-time = "2026-01-10T09:23:27.033Z" }, - { url = "https://files.pythonhosted.org/packages/e5/2d/7583b30208b639c8090206f95073646c2c9ffd66f44df967981a64f849ad/websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8", size = 185331, upload-time = "2026-01-10T09:23:28.259Z" }, - { url = "https://files.pythonhosted.org/packages/45/b0/cce3784eb519b7b5ad680d14b9673a31ab8dcb7aad8b64d81709d2430aa8/websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e", size = 186501, upload-time = "2026-01-10T09:23:29.449Z" }, - { url = "https://files.pythonhosted.org/packages/19/60/b8ebe4c7e89fb5f6cdf080623c9d92789a53636950f7abacfc33fe2b3135/websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944", size = 186062, upload-time = "2026-01-10T09:23:31.368Z" }, - { url = "https://files.pythonhosted.org/packages/88/a8/a080593f89b0138b6cba1b28f8df5673b5506f72879322288b031337c0b8/websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206", size = 185356, upload-time = "2026-01-10T09:23:32.627Z" }, - { url = "https://files.pythonhosted.org/packages/c2/b6/b9afed2afadddaf5ebb2afa801abf4b0868f42f8539bfe4b071b5266c9fe/websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6", size = 178085, upload-time = "2026-01-10T09:23:33.816Z" }, - { url = 
"https://files.pythonhosted.org/packages/9f/3e/28135a24e384493fa804216b79a6a6759a38cc4ff59118787b9fb693df93/websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd", size = 178531, upload-time = "2026-01-10T09:23:35.016Z" }, - { url = "https://files.pythonhosted.org/packages/72/07/c98a68571dcf256e74f1f816b8cc5eae6eb2d3d5cfa44d37f801619d9166/websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d", size = 174947, upload-time = "2026-01-10T09:23:36.166Z" }, - { url = "https://files.pythonhosted.org/packages/7e/52/93e166a81e0305b33fe416338be92ae863563fe7bce446b0f687b9df5aea/websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03", size = 175260, upload-time = "2026-01-10T09:23:37.409Z" }, - { url = "https://files.pythonhosted.org/packages/56/0c/2dbf513bafd24889d33de2ff0368190a0e69f37bcfa19009ef819fe4d507/websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da", size = 176071, upload-time = "2026-01-10T09:23:39.158Z" }, - { url = "https://files.pythonhosted.org/packages/a5/8f/aea9c71cc92bf9b6cc0f7f70df8f0b420636b6c96ef4feee1e16f80f75dd/websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c", size = 176968, upload-time = "2026-01-10T09:23:41.031Z" }, - { url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" }, - { url = 
"https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, -] - [[package]] name = "wrapt" version = "2.1.2"