diff --git a/.github/scripts/upload-integ-test-metrics.py b/.github/scripts/upload-integ-test-metrics.py new file mode 100644 index 000000000..28595d647 --- /dev/null +++ b/.github/scripts/upload-integ-test-metrics.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +import sys +import xml.etree.ElementTree as ET +from datetime import datetime +from dataclasses import dataclass +from typing import Any, Literal, TypedDict +import os +import boto3 + +STRANDS_METRIC_NAMESPACE = 'Strands/Tests' + + + +class Dimension(TypedDict): + Name: str + Value: str + + +class MetricDatum(TypedDict): + MetricName: str + Dimensions: list[Dimension] + Value: float + Unit: str + Timestamp: datetime + + +@dataclass +class TestResult: + name: str + classname: str + duration: float + outcome: Literal['failed', 'skipped', 'passed'] + + +def parse_junit_xml(xml_file_path: str) -> list[TestResult]: + try: + tree = ET.parse(xml_file_path) + except FileNotFoundError: + print(f"Warning: XML file not found: {xml_file_path}") + return [] + except ET.ParseError as e: + print(f"Warning: Failed to parse XML: {e}") + return [] + + results = [] + root = tree.getroot() + + for testcase in root.iter('testcase'): + name = testcase.get('name') + classname = testcase.get('classname') + duration = float(testcase.get('time', 0.0)) + + if not name or not classname: + continue + + if testcase.find('failure') is not None or testcase.find('error') is not None: + outcome = 'failed' + elif testcase.find('skipped') is not None: + outcome = 'skipped' + else: + outcome = 'passed' + + results.append(TestResult(name, classname, duration, outcome)) + + return results + + +def build_metric_data(test_results: list[TestResult], repository: str) -> list[MetricDatum]: + metrics: list[MetricDatum] = [] + timestamp = datetime.utcnow() + + for test in test_results: + test_name = f"{test.classname}.{test.name}" + dimensions: list[Dimension] = [ + Dimension(Name='TestName', Value=test_name), + Dimension(Name='Repository', Value=repository) + ] + + metrics.append(MetricDatum( + MetricName='TestPassed', + Dimensions=dimensions, + Value=1.0 if test.outcome == 'passed' else 0.0, + Unit='Count', + Timestamp=timestamp + )) + + metrics.append(MetricDatum( + MetricName='TestFailed', + Dimensions=dimensions, + Value=1.0 if test.outcome == 'failed' else 0.0, + Unit='Count', + Timestamp=timestamp + )) + + metrics.append(MetricDatum( + MetricName='TestSkipped', + Dimensions=dimensions, + Value=1.0 if test.outcome == 'skipped' else 0.0, + Unit='Count', + Timestamp=timestamp + )) + + metrics.append(MetricDatum( + MetricName='TestDuration', + Dimensions=dimensions, + Value=test.duration, + Unit='Seconds', + Timestamp=timestamp + )) + + return metrics + + +def publish_metrics(metric_data: list[dict[str, Any]], region: str): + cloudwatch = boto3.client('cloudwatch', region_name=region) + + batch_size = 1000 + for i in range(0, len(metric_data), batch_size): + batch = metric_data[i:i + batch_size] + try: + cloudwatch.put_metric_data(Namespace=STRANDS_METRIC_NAMESPACE, MetricData=batch) + print(f"Published {len(batch)} metrics to CloudWatch") + except Exception as e: + print(f"Warning: Failed to publish metrics batch: {e}") + + +def main(): + if len(sys.argv) != 3: + print("Usage: python upload-integ-test-metrics.py ") + sys.exit(0) + + xml_file = sys.argv[1] + repository = sys.argv[2] + region = os.environ.get('AWS_REGION', 'us-east-1') + + test_results = parse_junit_xml(xml_file) + if not test_results: + print("No test results found") + sys.exit(1) + + print(f"Found {len(test_results)} test results") + metric_data = build_metric_data(test_results, repository) + publish_metrics(metric_data, region) + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/LAMDBA_LAYERS_SOP.md b/.github/workflows/LAMDBA_LAYERS_SOP.md new file mode 100644 index 000000000..1cf58a614 --- /dev/null +++ b/.github/workflows/LAMDBA_LAYERS_SOP.md @@ -0,0 +1,31 @@ +# Lambda Layers Standard Operating Procedures (SOP) + +## Overview + +This document defines the standard operating procedures for managing Strands Agents Lambda layers across all AWS regions, Python versions, and architectures. + +**Total: 136 individual Lambda layers** (17 regions × 2 architectures × 4 Python versions). All variants must maintain the same layer version number for each PyPI package version, with only one row per PyPI version appearing in documentation. + +## Deployment Process + +### 1. Initial Deployment +1. Run workflow with ALL options selected (default) +2. Specify PyPI package version +3. Type "Create Lambda Layer {package_version}" to confirm +4. All 136 individual layers deploy in parallel (4 Python × 2 arch × 17 regions) +5. Each layer gets its own unique name: `strands-agents-py{PYTHON_VERSION}-{ARCH}` + +### 2. Version Buffering for New Variants +When adding new variants (new Python version, architecture, or region): + +1. **Determine target layer version**: Check existing variants to find the highest layer version +2. **Buffer deployment**: Deploy new variants multiple times until layer version matches existing variants +3. **Example**: If existing variants are at layer version 5, deploy new variant 5 times to reach version 5 + +### 3. Handling Transient Failures +When some regions fail during deployment: + +1. **Identify failed regions**: Check which combinations didn't complete successfully +2. **Targeted redeployment**: Use specific region/arch/Python inputs to redeploy failed combinations +3. **Version alignment**: Continue deploying until all variants reach the same layer version +4. **Verification**: Confirm all combinations have identical layer versions before updating docs \ No newline at end of file diff --git a/.github/workflows/auto-close.yml b/.github/workflows/auto-close.yml index dc9b577a0..be31606d9 100644 --- a/.github/workflows/auto-close.yml +++ b/.github/workflows/auto-close.yml @@ -24,13 +24,13 @@ jobs: include: - label: 'autoclose in 3 days' days: 3 - issue_types: 'issues' #issues/pulls/both + issue_types: 'both' #issues/pulls/both replacement_label: '' closure_message: 'This issue has been automatically closed as it was marked for auto-closure by the team and no additional responses was received within 3 days.' dry_run: 'false' - label: 'autoclose in 7 days' days: 7 - issue_types: 'issues' # issues/pulls/both + issue_types: 'both' # issues/pulls/both replacement_label: '' closure_message: 'This issue has been automatically closed as it was marked for auto-closure by the team and no additional responses was received within 7 days.' dry_run: 'false' diff --git a/.github/workflows/auto-strands-review.yml b/.github/workflows/auto-strands-review.yml new file mode 100644 index 000000000..ebcbc1870 --- /dev/null +++ b/.github/workflows/auto-strands-review.yml @@ -0,0 +1,49 @@ +name: Auto Strands Review + +on: + pull_request_target: + branches: [main] + types: [opened, synchronize, reopened, ready_for_review] + +jobs: + authorization-check: + name: Check access + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.auth.outputs.approval-env }} + steps: + - name: Check Authorization + id: auth + uses: strands-agents/devtools/authorization-check@main + with: + skip-check: false + username: ${{ github.event.pull_request.user.login || 'invalid' }} + allowed-roles: 'triage,write,admin' + + trigger-review: + name: Trigger Strands Review + needs: authorization-check + environment: ${{ needs.authorization-check.outputs.approval-env }} + permissions: + actions: write + contents: read + runs-on: ubuntu-latest + steps: + - name: Trigger Strands Command Workflow + uses: actions/github-script@v8 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + await github.rest.actions.createWorkflowDispatch({ + owner: context.repo.owner, + repo: context.repo.repo, + workflow_id: 'strands-command.yml', + ref: 'main', + inputs: { + issue_id: String(context.payload.pull_request.number), + command: 'review', + session_id: '' + } + }); + console.log(`Triggered /strands review for PR #${context.payload.pull_request.number}`); diff --git a/.github/workflows/check-markdown-links.yml b/.github/workflows/check-markdown-links.yml new file mode 100644 index 000000000..2ac596190 --- /dev/null +++ b/.github/workflows/check-markdown-links.yml @@ -0,0 +1,51 @@ +name: Check Markdown Links + +on: + schedule: + - cron: '0 9 * * 1' # Every Monday at 9am UTC + workflow_dispatch: # Allow manual trigger + +jobs: + check-links: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: gaurav-nelson/github-action-markdown-link-check@3c3b66f1f7d0900e37b71eca45b63ea9eedfce31 # v1.0.17 + id: link-check + with: + use-quiet-mode: 'yes' + use-verbose-mode: 'yes' + config-file: '.markdown-link-check.json' + continue-on-error: true + + - name: Create issue if links are broken + if: steps.link-check.outcome == 'failure' + uses: actions/github-script@v7 + with: + script: | + const title = '🔗 Broken markdown links detected'; + const label = 'broken-links'; + + // Check for existing open issue to avoid duplicates + const existing = await github.rest.issues.listForRepo({ + owner: context.repo.owner, + repo: context.repo.repo, + state: 'open', + labels: label, + }); + + if (existing.data.length > 0) { + console.log(`Issue already exists: #${existing.data[0].number}`); + return; + } + + const runUrl = `${context.serverUrl}/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}`; + + await github.rest.issues.create({ + owner: context.repo.owner, + repo: context.repo.repo, + title, + body: `The weekly markdown link check found broken links.\n\nSee the [workflow run](${runUrl}) for details.`, + labels: [label], + }); diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 7496e45ef..5f7dd20d9 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -2,40 +2,26 @@ name: Secure Integration test on: pull_request_target: - branches: main + branches: [main] + merge_group: # Run tests in merge queue + types: [checks_requested] jobs: authorization-check: + name: Check access permissions: read-all runs-on: ubuntu-latest outputs: - approval-env: ${{ steps.collab-check.outputs.result }} + approval-env: ${{ steps.auth.outputs.approval-env }} steps: - - name: Collaborator Check - uses: actions/github-script@v8 - id: collab-check + - name: Check Authorization + id: auth + uses: strands-agents/devtools/authorization-check@main with: - result-encoding: string - script: | - try { - const permissionResponse = await github.rest.repos.getCollaboratorPermissionLevel({ - owner: context.repo.owner, - repo: context.repo.repo, - username: context.payload.pull_request.user.login, - }); - const permission = permissionResponse.data.permission; - const hasWriteAccess = ['write', 'admin'].includes(permission); - if (!hasWriteAccess) { - console.log(`User ${context.payload.pull_request.user.login} does not have write access to the repository (permission: ${permission})`); - return "manual-approval" - } else { - console.log(`Verifed ${context.payload.pull_request.user.login} has write access. Auto Approving PR Checks.`) - return "auto-approve" - } - } catch (error) { - console.log(`${context.payload.pull_request.user.login} does not have write access. Requiring Manual Approval to run PR Checks.`) - return "manual-approval" - } + skip-check: ${{ github.event_name == 'merge_group' }} + username: ${{ github.event.pull_request.user.login || 'invalid' }} + allowed-roles: 'maintain,triage,write,admin' + check-access-and-checkout: runs-on: ubuntu-latest needs: authorization-check @@ -46,13 +32,14 @@ jobs: contents: read steps: - name: Configure Credentials - uses: aws-actions/configure-aws-credentials@v5 + uses: aws-actions/configure-aws-credentials@v6 with: role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} aws-region: us-east-1 mask-aws-account-id: true + - name: Checkout head commit - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo persist-credentials: false # Don't persist credentials for subsequent actions @@ -61,8 +48,10 @@ jobs: with: python-version: '3.10' - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | - pip install --no-cache-dir hatch + pip install --no-cache-dir hatch 'virtualenv<21' - name: Run integration tests env: AWS_REGION: us-east-1 @@ -71,3 +60,43 @@ jobs: id: tests run: | hatch test tests_integ + + - name: Upload test results + if: always() + uses: actions/upload-artifact@v7 + with: + name: test-results + path: ./build/test-results.xml + + upload-metrics: + runs-on: ubuntu-latest + needs: check-access-and-checkout + if: always() + permissions: + id-token: write + contents: read + steps: + - name: Configure Credentials + uses: aws-actions/configure-aws-credentials@v6 + with: + role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} + aws-region: us-east-1 + mask-aws-account-id: true + + - name: Checkout main + uses: actions/checkout@v6 + with: + ref: main + sparse-checkout: | + .github/scripts + persist-credentials: false + + - name: Download test results + uses: actions/download-artifact@v8 + with: + name: test-results + + - name: Publish test metrics to CloudWatch + run: | + pip install --no-cache-dir boto3 + python .github/scripts/upload-integ-test-metrics.py test-results.xml ${{ github.event.repository.name }} diff --git a/.github/workflows/issue-responder.yml b/.github/workflows/issue-responder.yml new file mode 100644 index 000000000..5b3ad7305 --- /dev/null +++ b/.github/workflows/issue-responder.yml @@ -0,0 +1,66 @@ +name: Issue Responder + +on: + issues: + types: [opened] + +permissions: + id-token: write + contents: read + +jobs: + respond-to-issue: + runs-on: ubuntu-latest + + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v6 + with: + role-to-assume: ${{ secrets.STRANDS_AGENTCORE_ACTIONS_ROLE }} + aws-region: us-east-1 + - name: Invoke AgentCore with issue details + env: + GH_ISSUE_AGENTCORE_RUNTIME_ARN: ${{ secrets.GH_ISSUE_AGENTCORE_RUNTIME_ARN }} + ISSUE_NUMBER: ${{ github.event.issue.number }} + ISSUE_TITLE: ${{ github.event.issue.title }} + ISSUE_BODY: ${{ github.event.issue.body }} + ISSUE_URL: ${{ github.event.issue.html_url }} + ISSUE_AUTHOR: ${{ github.event.issue.user.login }} + REPO: ${{ github.repository }} + run: | + npm install @aws-sdk/client-bedrock-agentcore + node - <<'JSEOF' + const { BedrockAgentCoreClient, InvokeAgentRuntimeCommand } = require("@aws-sdk/client-bedrock-agentcore"); + + const payload = JSON.stringify({ + source: "github", + action: "issue_opened", + issue: { + number: parseInt(process.env.ISSUE_NUMBER), + title: process.env.ISSUE_TITLE, + body: process.env.ISSUE_BODY, + url: process.env.ISSUE_URL, + author: process.env.ISSUE_AUTHOR, + repo: process.env.REPO + } + }); + + console.log("Invoking AgentCore with payload:"); + console.log(JSON.stringify(JSON.parse(payload), null, 2)); + + const client = new BedrockAgentCoreClient({ region: "us-east-1" }); + + const sessionId = `github-issue-${process.env.ISSUE_NUMBER}-${Date.now()}-${Math.random().toString(36).slice(2)}`; + + const command = new InvokeAgentRuntimeCommand({ + agentRuntimeArn: process.env.GH_ISSUE_AGENTCORE_RUNTIME_ARN, + runtimeSessionId: sessionId, + payload: Buffer.from(payload) + }); + + (async () => { + const response = await client.send(command); + const textResponse = await response.response.transformToString(); + console.log("Response:", textResponse); + })(); + JSEOF diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index b558943dd..d2af9f956 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -17,3 +17,24 @@ jobs: contents: read with: ref: ${{ github.event.pull_request.head.sha }} + secrets: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + + check-api: + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + fetch-depth: 0 # We the need the full Git history. + - name: Setup uv + uses: astral-sh/setup-uv@v7 + - name: Check API breaking changes + run: | + if ! uvx griffe check --search src --format github strands --against "main"; then + echo "Potential API changes detected (review if actually breaking)" + exit 1 + fi + diff --git a/.github/workflows/pr-title.yml b/.github/workflows/pr-title.yml new file mode 100644 index 000000000..ada75b746 --- /dev/null +++ b/.github/workflows/pr-title.yml @@ -0,0 +1,37 @@ +name: PR Title Conventional Commits + +on: + pull_request: + branches: [main] + types: [opened, edited, synchronize, reopened] + +jobs: + validate-pr-title: + runs-on: ubuntu-latest + permissions: + pull-requests: read + steps: + - name: Check PR title follows conventional commits + uses: amannn/action-semantic-pull-request@v6 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + types: | + feat + fix + docs + style + refactor + perf + test + build + ci + chore + revert + requireScope: false + subjectPattern: ^[a-z].+$ + subjectPatternError: | + The subject "{subject}" must start with a lowercase letter. + ignoreLabels: | + bot + dependencies diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml new file mode 100644 index 000000000..73252f0ff --- /dev/null +++ b/.github/workflows/publish-lambda-layer.yml @@ -0,0 +1,168 @@ +name: Publish PyPI Package to Lambda Layer + +on: + workflow_dispatch: + inputs: + package_version: + description: 'Package version to download' + required: true + type: string + layer_version: + description: 'Layer version' + required: true + type: string + python_version: + description: 'Python version' + required: true + default: 'ALL' + type: choice + options: ['ALL', '3.10', '3.11', '3.12', '3.13'] + architecture: + description: 'Architecture' + required: true + default: 'ALL' + type: choice + options: ['ALL', 'x86_64', 'aarch64'] + region: + description: 'AWS region' + required: true + default: 'ALL' + type: choice + # Only non opt-in regions included for now + options: ['ALL', 'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', 'ap-south-1', 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-southeast-1', 'ap-southeast-2', 'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-west-3', 'eu-north-1', 'sa-east-1'] + confirm: + description: 'Type "Create Lambda Layer {PyPI version}-layer{layer version}" to confirm publishing the layer' + required: true + type: string + +jobs: + validate: + runs-on: ubuntu-latest + permissions: {} + steps: + - name: Validate confirmation + run: | + CONFIRM="${{ inputs.confirm }}" + EXPECTED="Create Lambda Layer ${{ inputs.package_version }}-layer${{ inputs.layer_version }}" + if [ "$CONFIRM" != "$EXPECTED" ]; then + echo "Confirmation failed. You must type exactly '$EXPECTED' to proceed." + exit 1 + fi + echo "Confirmation validated" + + package-and-upload: + needs: validate + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + + steps: + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v6 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Create layer directory structure + run: | + mkdir -p layer/python + + - name: Download and install package + run: | + pip install strands-agents==${{ inputs.package_version }} \ + --python-version ${{ matrix.python-version }} \ + --platform manylinux2014_${{ matrix.architecture }} \ + -t layer/python/ \ + --only-binary=:all: + + - name: Create layer zip + run: | + cd layer + zip -r ../lambda-layer.zip . + + - name: Upload to S3 + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" + REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + BUCKET_NAME="strands-layer-${ACCOUNT_ID}-${{ secrets.STRANDS_LAMBDA_LAYER_BUCKET_SALT }}-${REGION}" + LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" + + aws s3 cp lambda-layer.zip "s3://$BUCKET_NAME/$LAYER_KEY" --region "$REGION" + echo "Uploaded layer to s3://$BUCKET_NAME/$LAYER_KEY" + + publish-layer: + needs: package-and-upload + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v6 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Publish layer + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" + REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + REGION_BUCKET="strands-layer-${ACCOUNT_ID}-${{ secrets.STRANDS_LAMBDA_LAYER_BUCKET_SALT }}-${REGION}" + LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" + + DESCRIPTION="PyPI package: strands-agents v${{ inputs.package_version }} (Python $PYTHON_VERSION, $ARCH)" + + # Set compatible architecture based on matrix architecture + if [ "$ARCH" = "x86_64" ]; then + COMPATIBLE_ARCH="x86_64" + else + COMPATIBLE_ARCH="arm64" + fi + + LAYER_OUTPUT=$(aws lambda publish-layer-version \ + --layer-name $LAYER_NAME \ + --description "$DESCRIPTION" \ + --content S3Bucket=$REGION_BUCKET,S3Key=$LAYER_KEY \ + --compatible-runtimes python${{ matrix.python-version }} \ + --compatible-architectures $COMPATIBLE_ARCH \ + --region "$REGION" \ + --license-info Apache-2.0 \ + --output json) + + LAYER_ARN=$(echo "$LAYER_OUTPUT" | jq -r '.LayerArn') + LAYER_VERSION=$(echo "$LAYER_OUTPUT" | jq -r '.Version') + + echo "Published layer version $LAYER_VERSION with ARN: $LAYER_ARN in region $REGION" + + aws lambda add-layer-version-permission \ + --layer-name $LAYER_NAME \ + --version-number $LAYER_VERSION \ + --statement-id public \ + --action lambda:GetLayerVersion \ + --principal '*' \ + --region "$REGION" + + echo "Successfully published layer version $LAYER_VERSION in region $REGION" \ No newline at end of file diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index ff19e46b1..4601d4069 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -12,6 +12,8 @@ jobs: contents: read with: ref: ${{ github.event.release.target_commitish }} + secrets: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} build: name: Build distribution 📦 @@ -22,7 +24,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: persist-credentials: false @@ -32,9 +34,11 @@ jobs: python-version: '3.10' - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | python -m pip install --upgrade pip - pip install hatch twine + pip install hatch twine 'virtualenv<21' - name: Validate version run: | @@ -52,7 +56,7 @@ jobs: hatch build - name: Store the distribution packages - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: python-package-distributions path: dist/ @@ -74,7 +78,7 @@ jobs: steps: - name: Download all the dists - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v8 with: name: python-package-distributions path: dist/ diff --git a/.github/workflows/strands-command.yml b/.github/workflows/strands-command.yml new file mode 100644 index 000000000..496ce025b --- /dev/null +++ b/.github/workflows/strands-command.yml @@ -0,0 +1,92 @@ +name: Strands Command Handler + +on: + issue_comment: + types: [created] + workflow_dispatch: + inputs: + issue_id: + description: 'Issue ID to process (can be issue or PR number)' + required: true + type: string + command: + description: 'Strands command to execute' + required: false + type: string + default: '' + session_id: + description: 'Optional session ID to use' + required: false + type: string + default: '' + +jobs: + authorization-check: + if: startsWith(github.event.comment.body, '/strands') || github.event_name == 'workflow_dispatch' + name: Check access + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.auth.outputs.approval-env }} + steps: + - name: Check Authorization + id: auth + uses: strands-agents/devtools/authorization-check@main + with: + skip-check: ${{ github.event_name == 'workflow_dispatch' }} + username: ${{ github.event.comment.user.login || 'invalid' }} + allowed-roles: 'maintain,triage,write,admin' + + setup-and-process: + needs: [authorization-check] + environment: ${{ needs.authorization-check.outputs.approval-env }} + permissions: + # Needed to create a branch for the Implementer Agent + contents: write + # These both are needed to add the `strands-running` label to issues and prs + issues: write + pull-requests: write + runs-on: ubuntu-latest + steps: + - name: Parse input + id: parse + uses: strands-agents/devtools/strands-command/actions/strands-input-parser@main + with: + issue_id: ${{ inputs.issue_id }} + command: ${{ inputs.command }} + session_id: ${{ inputs.session_id }} + + execute-readonly-agent: + needs: [setup-and-process] + permissions: + contents: read + issues: read + pull-requests: read + id-token: write # Required for OIDC + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + + # Add any steps here to set up the environment for the Agent in your repo + # setup node, setup python, or any other dependencies + + - name: Run Strands Agent + id: agent-runner + uses: strands-agents/devtools/strands-command/actions/strands-agent-runner@main + with: + aws_role_arn: ${{ secrets.AWS_ROLE_ARN }} + sessions_bucket: ${{ secrets.AGENT_SESSIONS_BUCKET }} + write_permission: 'false' + + finalize: + if: always() && (startsWith(github.event.comment.body, '/strands') || github.event_name == 'workflow_dispatch') + needs: [setup-and-process, execute-readonly-agent] + permissions: + contents: write + issues: write + pull-requests: write + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Execute write operations + uses: strands-agents/devtools/strands-command/actions/strands-finalize@main diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 4986acf1f..5f5aa6fcd 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -6,6 +6,9 @@ on: ref: required: true type: string + secrets: + CODECOV_TOKEN: + required: false jobs: unit-test: @@ -28,6 +31,9 @@ jobs: - os: ubuntu-latest os-name: 'linux' python-version: "3.13" + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.14" # Windows - os: windows-latest os-name: 'windows' @@ -41,17 +47,20 @@ jobs: - os: windows-latest os-name: 'windows' python-version: "3.13" + - os: windows-latest + os-name: 'windows' + python-version: "3.14" # MacOS - latest only; not enough runners for macOS - os: macos-latest os-name: 'macOS' - python-version: "3.13" + python-version: "3.14" fail-fast: true runs-on: ${{ matrix.os }} env: LOG_LEVEL: DEBUG steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: ref: ${{ inputs.ref }} # Explicitly define which commit to check out persist-credentials: false # Don't persist credentials for subsequent actions @@ -74,8 +83,10 @@ jobs: # Windows typically has audio libraries available by default echo "Windows audio dependencies handled by PyAudio wheels" - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | - pip install --no-cache-dir hatch + pip install --no-cache-dir hatch 'virtualenv<21' - name: Run Unit tests id: tests run: hatch test tests --cover @@ -92,7 +103,7 @@ jobs: contents: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: ref: ${{ inputs.ref }} persist-credentials: false @@ -109,8 +120,10 @@ jobs: sudo apt-get install -y portaudio19-dev libasound2-dev - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | - pip install --no-cache-dir hatch + pip install --no-cache-dir hatch 'virtualenv<21' - name: Run lint id: lint diff --git a/.gitignore b/.gitignore index 8b0fd989c..0b1375b50 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ repl_state .kiro uv.lock .audio_cache +CLAUDE.md diff --git a/.markdown-link-check.json b/.markdown-link-check.json new file mode 100644 index 000000000..a03e7e0a9 --- /dev/null +++ b/.markdown-link-check.json @@ -0,0 +1,6 @@ +{ + "retryOn429": true, + "retryCount": 3, + "fallbackRetryDelay": "30s", + "aliveStatusCodes": [200, 206] +} diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..daddbbb2d --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,570 @@ +# AGENTS.md + +This document provides context, patterns, and guidelines for AI coding assistants working in this repository. For human contributors, see [CONTRIBUTING.md](./CONTRIBUTING.md). + +## Product Overview + +Strands Agents is an open-source Python SDK for building AI agents with a model-driven approach. It provides a lightweight, flexible framework that scales from simple conversational assistants to complex autonomous workflows. + +**Core Features:** +- Model Agnostic: Multiple model providers (Amazon Bedrock, Anthropic, OpenAI, Gemini, Ollama, etc.) +- Python-Based Tools: Simple `@tool` decorator with hot reloading +- MCP Integration: Native Model Context Protocol support +- Multi-Agent Systems: Agent-to-agent, swarms, and graph patterns +- Streaming Support: Real-time response streaming +- Hooks: Event-driven extensibility for agent lifecycle +- Session Management: Pluggable session managers (file, S3, custom) +- Observability: OpenTelemetry tracing and metrics + +## Directory Structure + +``` +strands-agents/ +│ +├── src/strands/ # Main package source code +│ ├── agent/ # Core agent implementation +│ │ ├── agent.py # Main Agent class +│ │ ├── agent_result.py # Agent execution results +│ │ ├── base.py # AgentBase protocol (agent interface) +│ │ ├── a2a_agent.py # A2AAgent client for remote A2A agents +│ │ ├── state.py # Agent state management +│ │ └── conversation_manager/ # Message history strategies +│ │ ├── conversation_manager.py # Base conversation manager +│ │ ├── null_conversation_manager.py # No-op manager +│ │ ├── sliding_window_conversation_manager.py # Window-based +│ │ └── summarizing_conversation_manager.py # Summarization-based +│ │ +│ ├── event_loop/ # Agent execution loop +│ │ ├── event_loop.py # Main loop logic +│ │ ├── streaming.py # Streaming response handling +│ │ └── _recover_message_on_max_tokens_reached.py +│ │ +│ ├── models/ # Model provider implementations +│ │ ├── model.py # Base model interface +│ │ ├── bedrock.py # Amazon Bedrock +│ │ ├── anthropic.py # Anthropic Claude +│ │ ├── openai.py # OpenAI +│ │ ├── gemini.py # Google Gemini +│ │ ├── ollama.py # Ollama local models +│ │ ├── litellm.py # LiteLLM unified interface +│ │ ├── mistral.py # Mistral AI +│ │ ├── llamaapi.py # LlamaAPI +│ │ ├── llamacpp.py # llama.cpp local +│ │ ├── sagemaker.py # AWS SageMaker +│ │ ├── writer.py # Writer AI +│ │ └── _validation.py # Validation utilities +│ │ +│ ├── tools/ # Tool system +│ │ ├── decorator.py # @tool decorator +│ │ ├── tools.py # Tool base classes +│ │ ├── tool_provider.py # ToolProvider interface +│ │ ├── registry.py # Tool registration +│ │ ├── loader.py # Dynamic tool loading +│ │ ├── watcher.py # Hot reload +│ │ ├── _caller.py # Tool invocation +│ │ ├── _validator.py # Tool validation +│ │ ├── _tool_helpers.py # Helper utilities +│ │ ├── executors/ # Tool execution environments +│ │ │ ├── _executor.py # Base executor +│ │ │ ├── concurrent.py # Thread/process pool +│ │ │ └── sequential.py # Sequential execution +│ │ ├── mcp/ # Model Context Protocol +│ │ │ ├── mcp_client.py # MCP client implementation +│ │ │ ├── mcp_agent_tool.py # MCP tool wrapper +│ │ │ ├── mcp_types.py # MCP type definitions +│ │ │ ├── mcp_tasks.py # Task-augmented execution config +│ │ │ └── mcp_instrumentation.py # MCP telemetry +│ │ └── structured_output/ # Structured output handling +│ │ ├── structured_output_tool.py +│ │ ├── structured_output_utils.py +│ │ └── _structured_output_context.py +│ │ +│ ├── multiagent/ # Multi-agent patterns +│ │ ├── base.py # Base multi-agent classes +│ │ ├── graph.py # Graph-based orchestration +│ │ ├── swarm.py # Swarm pattern +│ │ ├── a2a/ # Agent-to-agent protocol +│ │ │ ├── executor.py # A2A executor +│ │ │ ├── server.py # A2A server +│ │ │ └── converters.py # Strands/A2A type converters +│ │ └── nodes/ # Graph node implementations +│ │ +│ ├── types/ # Type definitions +│ │ ├── content.py # Content types (text, images, etc.) +│ │ ├── tools.py # Tool-related types +│ │ ├── streaming.py # Streaming event types +│ │ ├── exceptions.py # Custom exceptions +│ │ ├── agent.py # Agent types +│ │ ├── session.py # Session types +│ │ ├── multiagent.py # Multi-agent types +│ │ ├── guardrails.py # Guardrail types +│ │ ├── interrupt.py # Interrupt types +│ │ ├── media.py # Media types +│ │ ├── citations.py # Citation types +│ │ ├── traces.py # Trace types +│ │ ├── event_loop.py # Event loop types +│ │ ├── json_dict.py # JSON dict utilities +│ │ ├── collections.py # Collection types +│ │ ├── _snapshot.py # Snapshot types and helpers +│ │ ├── _events.py # Internal event types +│ │ ├── a2a.py # A2A protocol types +│ │ └── models/ # Model-specific types +│ │ +│ ├── session/ # Session management +│ │ ├── session_manager.py # Base interface +│ │ ├── file_session_manager.py # File-based storage +│ │ ├── s3_session_manager.py # S3 storage +│ │ ├── repository_session_manager.py # Repository pattern +│ │ └── session_repository.py # Storage interface +│ │ +│ ├── telemetry/ # Observability (OpenTelemetry) +│ │ ├── tracer.py # Tracing +│ │ ├── metrics.py # Metrics collection +│ │ ├── metrics_constants.py # Metric definitions +│ │ └── config.py # Configuration +│ │ +│ ├── hooks/ # Event hooks system +│ │ ├── events.py # Hook event definitions +│ │ ├── registry.py # Hook registration +│ │ └── _type_inference.py # Event type inference from type hints +│ │ +│ ├── plugins/ # Plugin system +│ │ ├── plugin.py # Plugin base class +│ │ ├── multiagent_plugin.py # MultiAgentPlugin base class +│ │ ├── decorator.py # @hook decorator +│ │ ├── registry.py # PluginRegistry for tracking agent plugins +│ │ ├── multiagent_registry.py # Registry for tracking orchestrator plugins +│ │ └── _discovery.py # Shared hook/tool discovery utilities +│ │ +│ ├── handlers/ # Event handlers +│ │ └── callback_handler.py # Callback handling +│ │ +│ ├── vended_plugins/ # Production plugin implementations +│ │ ├── steering/ # Agent steering system +│ │ │ ├── context_providers/ # Context data providers (e.g., ledger) +│ │ │ ├── core/ # Base classes, actions, context +│ │ │ └── handlers/ # Handler implementations (e.g., LLM) +│ │ ├── skills/ # AgentSkills.io integration (Skill, AgentSkills) +│ │ └── context_offloader/ # Large tool result offloading plugin +│ │ +│ ├── experimental/ # Experimental features (API may change) +│ │ ├── agent_config.py # Experimental agent config +│ │ ├── bidi/ # Bidirectional streaming +│ │ │ ├── agent/ # Bidi agent implementation +│ │ │ ├── io/ # Input/output handling +│ │ │ ├── models/ # Bidi model providers +│ │ │ ├── tools/ # Bidi tools +│ │ │ ├── types/ # Bidi types +│ │ │ └── _async/ # Async utilities +│ │ ├── checkpoint/ # Durable agent execution checkpoints +│ │ │ └── checkpoint.py # Checkpoint dataclass and serialization +│ │ ├── hooks/ # Experimental hooks +│ │ │ ├── events.py +│ │ │ └── multiagent/ +│ │ ├── steering/ # Deprecated aliases for vended_plugins/steering +│ │ └── tools/ # Deprecated aliases for strands.tools +│ │ +│ ├── __init__.py # Public API exports +│ ├── interrupt.py # Interrupt handling +│ ├── _async.py # Async utilities +│ ├── _exception_notes.py # Exception helpers +│ ├── _identifier.py # ID generation +│ └── py.typed # PEP 561 marker +│ +├── tests/ # Unit tests (mirrors src/) +│ ├── conftest.py # Pytest fixtures +│ ├── fixtures/ # Test fixtures +│ │ ├── mocked_model_provider.py # Mock model for testing +│ │ ├── mock_agent_tool.py +│ │ ├── mock_hook_provider.py +│ │ └── ... +│ └── strands/ # Tests mirror src/strands/ +│ ├── agent/ +│ ├── event_loop/ +│ ├── models/ +│ ├── tools/ +│ ├── multiagent/ +│ ├── types/ +│ ├── session/ +│ ├── telemetry/ +│ ├── hooks/ +│ ├── plugins/ +│ ├── handlers/ +│ ├── experimental/ +│ └── utils/ +│ +├── tests_integ/ # Integration tests +│ ├── conftest.py +│ ├── models/ # Model provider tests +│ │ ├── test_model_bedrock.py +│ │ ├── test_model_anthropic.py +│ │ ├── test_model_openai.py +│ │ ├── test_model_gemini.py +│ │ ├── test_model_ollama.py +│ │ └── ... +│ ├── mcp/ # MCP integration tests +│ │ ├── test_mcp_client.py +│ │ ├── echo_server.py +│ │ └── ... +│ ├── tools/ # Tool system tests +│ ├── hooks/ # Hook tests +│ ├── interrupts/ # Interrupt tests +│ ├── steering/ # Steering tests +│ ├── bidi/ # Bidirectional streaming tests +│ ├── a2a/ # A2A agent integration tests +│ ├── test_multiagent_graph.py +│ ├── test_multiagent_swarm.py +│ ├── test_stream_agent.py +│ ├── test_session.py +│ └── ... +│ +├── docs/ # Developer documentation +│ ├── README.md # Docs folder overview +│ ├── STYLE_GUIDE.md # Code style conventions +│ ├── HOOKS.md # Hooks system guide +│ ├── PR.md # PR description guidelines +│ └── MCP_CLIENT_ARCHITECTURE.md # MCP threading architecture +│ +├── pyproject.toml # Project config (build, deps, tools) +├── AGENTS.md # This file +└── CONTRIBUTING.md # Human contributor guidelines +``` + +### Directory Purposes + +- **`src/strands/`**: All production code +- **`tests/`**: Unit tests mirroring src/ structure +- **`tests_integ/`**: Integration tests with real model providers +- **`docs/`**: Developer documentation for contributors + +**IMPORTANT**: After making changes that affect the directory structure (adding new directories, moving files, or adding significant new files), you MUST update this directory structure section to reflect the current state of the repository. + +## Development Workflow + +### 1. Environment Setup + +```bash +hatch shell # Enter dev environment +pre-commit install -t pre-commit -t commit-msg # Install hooks +``` + +### 2. Making Changes + +1. Create feature branch +2. Implement changes following the patterns below +3. Run quality checks before committing +4. Commit with conventional commits (`feat:`, `fix:`, `docs:`, `refactor:`, `test:`, `chore:`) +5. Push and open PR + +### 3. Pull Request Guidelines + +When creating pull requests, you MUST follow the guidelines in PR.md. Key principles: + +Focus on WHY: Explain motivation and user impact, not implementation details +Document public API changes: Show before/after code examples +Be concise: Use prose over bullet lists; avoid exhaustive checklists +Target senior engineers: Assume familiarity with the SDK +Exclude implementation details: Leave these to code comments and diffs +See PR.md for the complete guidance and template. + +### 4. Quality Gates + +Pre-commit hooks run automatically on commit: +- Formatting (ruff) +- Linting (ruff + mypy) +- Tests (pytest) +- Commit message validation (commitizen) + +All checks must pass before commit is allowed. + +## Coding Patterns and Best Practices + +### Logging Style + +Use structured logging with field-value pairs followed by human-readable messages: + +```python +logger.debug("field1=<%s>, field2=<%s> | human readable message", field1, field2) +``` + +**Guidelines:** +- Add context as `FIELD=` pairs at the beginning +- Separate pairs with commas +- Enclose values in `<>` for readability (especially for empty values) +- Use `%s` string interpolation (not f-strings) for performance +- Use lowercase messages, no punctuation +- Separate multiple statements with pipe `|` + +**Good:** +```python +logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action) +logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration) +logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts) +``` + +**Bad:** +```python +logger.debug(f"User {user_id} performed action {action}") # Don't use f-strings +logger.info("Request completed in %d ms.", duration) # Don't add punctuation +``` + +### Type Annotations + +All code must include type annotations: +- Function parameters and return types required +- No implicit optional types +- Use `typing` or `typing_extensions` for complex types +- Mypy strict mode enforced + +```python +def process_message(content: str, max_tokens: int | None = None) -> AgentResult: + ... +``` + +### Docstrings + +Use Google-style docstrings for all public functions, classes, and modules: + +```python +def example_function(param1: str, param2: int) -> bool: + """Brief description of function. + + Longer description if needed. This docstring is used by LLMs + to understand the function's purpose when used as a tool. + + Args: + param1: Description of param1 + param2: Description of param2 + + Returns: + Description of return value + + Raises: + ValueError: When invalid input is provided + """ + pass +``` + +### Import Organization + +Imports must be at the top of the file. + +Imports are automatically organized by ruff/isort: +1. Standard library imports +2. Third-party imports +3. Local application imports + +Use absolute imports for cross-package references, relative imports within packages. + +```python +# Standard library +import logging +from typing import Any + +# Third-party +import boto3 +from pydantic import BaseModel + +# Local +from strands.agent import Agent +from .tools import Tool +``` + +### File Organization + +- Each major feature in its own directory +- Base classes and interfaces defined first +- Implementation-specific code in separate files +- Private modules prefixed with `_` +- Test files prefixed with `test_` + +### Naming Conventions + +- **Variables/Functions**: `snake_case` +- **Classes**: `PascalCase` +- **Constants**: `UPPER_SNAKE_CASE` +- **Private members**: Prefix with `_` + +### Error Handling + +- Use custom exceptions from `strands.types.exceptions` +- Provide clear error messages with context +- Don't swallow exceptions silently + +## Testing Patterns + +### Unit Tests (`tests/`) + +- Mirror the `src/strands/` structure exactly +- Focus on isolated component testing +- Use mocking for external dependencies (models, AWS services) +- Use fixtures from `tests/fixtures/` (e.g., `mocked_model_provider.py`) + +```python +# tests/strands/agent/test_agent.py mirrors src/strands/agent/agent.py +``` + +### Integration Tests (`tests_integ/`) + +- End-to-end testing with real model providers +- Require credentials/API keys (set via environment variables) +- Organized by feature area + +### Test File Naming + +- Unit tests: `test_{module}.py` in `tests/strands/{path}/` +- Integration tests: `test_{feature}.py` in `tests_integ/` + +### Running Tests + +```bash +hatch test # Run unit tests +hatch test -c # Run with coverage +hatch run test-integ # Run integration tests +hatch test tests/strands/agent/ # Run specific directory +hatch test --all # Test all Python versions (3.10-3.13) +``` + +### Writing Tests + +- Use pytest fixtures for setup/teardown +- Use `moto` for mocking AWS services +- Use `pytest.mark.asyncio` for async tests +- Keep tests focused and independent +- Import packages at the top of the test files + +## MCP Tasks (Experimental) + +The SDK supports MCP task-augmented execution for long-running tools. This feature is experimental and aligns with the MCP specification 2025-11-25. + +### Overview + +Task-augmented execution allows tools to run asynchronously with a workflow: +1. Create task via `call_tool_as_task` +2. Poll for completion via `poll_task` +3. Get result via `get_task_result` + +### Configuration + +Enable tasks by passing a `TasksConfig` to `MCPClient`: + +```python +from datetime import timedelta +from strands.tools.mcp import MCPClient, TasksConfig + +# Enable with defaults (ttl=1min, poll_timeout=5min) +client = MCPClient(transport, tasks_config={}) + +# Or configure explicitly +client = MCPClient( + transport, + tasks_config=TasksConfig( + ttl=timedelta(minutes=2), # Task time-to-live + poll_timeout=timedelta(minutes=10), # Polling timeout + ), +) +``` + +### Tool Support Levels + +MCP tools declare their task support via `execution.taskSupport`: +- `TASK_REQUIRED`: Tool must use task-augmented execution +- `TASK_OPTIONAL`: Tool can use tasks if client opts in +- `TASK_FORBIDDEN`: Tool does not support tasks (default) + +### Decision Logic + +Task-augmented execution is used when ALL conditions are met: +1. Client opts in via `tasks_config` (not None) +2. Server advertises task capability (`tasks.requests.tools.call`) +3. Tool's `taskSupport` is `required` or `optional` + +### Key Files + +- `src/strands/tools/mcp/mcp_tasks.py` - `TasksConfig` and defaults +- `src/strands/tools/mcp/mcp_client.py` - Task execution logic (`_call_tool_as_task_and_poll_async`) +- `tests/strands/tools/mcp/test_mcp_client_tasks.py` - Unit tests +- `tests_integ/mcp/test_mcp_client_tasks.py` - Integration tests +- `tests_integ/mcp/task_echo_server.py` - Test server with task support + +## Things to Do + +- Use explicit return types for all functions +- Write Google-style docstrings for public APIs +- Use structured logging format +- Add type annotations everywhere +- Use relative imports within packages +- Mirror src/ structure in tests/ +- Run `hatch fmt --formatter` and `hatch fmt --linter` before committing +- Follow conventional commits (`feat:`, `fix:`, `docs:`, etc.) + +## Things NOT to Do + +- Don't use f-strings in logging calls +- Don't use `Any` type without good reason +- Don't skip type annotations +- Don't put unit tests outside `tests/strands/` structure +- Don't commit without running pre-commit hooks +- Don't add punctuation to log messages +- Don't use implicit optional types + +## Development Commands + +```bash +# Environment +hatch shell # Enter dev environment + +# Formatting & Linting +hatch fmt --formatter # Format code +hatch fmt --linter # Run linters (ruff + mypy) + +# Testing +hatch test # Run unit tests +hatch test -c # Run with coverage +hatch run test-integ # Run integration tests +hatch test --all # Test all Python versions + +# Pre-commit +pre-commit run --all-files # Run all hooks manually + +# Readiness Check +hatch run prepare # Run all checks (format, lint, test) + +# Build +hatch build # Build package +``` + +## Agent-Specific Notes + +### Writing Code + +- Make the SMALLEST reasonable changes to achieve the desired outcome +- Prefer simple, clean, maintainable solutions over clever ones +- Reduce code duplication, even if refactoring takes extra effort +- Match the style and formatting of surrounding code +- Fix broken things immediately when you find them + +### Code Comments + +- Comments should explain WHAT the code does or WHY it exists +- NEVER add comments about what used to be there or how something changed +- NEVER refer to temporal context ("recently refactored", "moved") +- Keep comments concise and evergreen + +### Code Review Considerations + +- Address all review comments +- Test changes thoroughly +- Update documentation if behavior changes +- Maintain test coverage +- Follow conventional commit format for fix commits + +## Additional Resources + +- [Strands Agents Documentation](https://strandsagents.com/) +- [CONTRIBUTING.md](./CONTRIBUTING.md) - Human contributor guidelines +- [docs/](./docs/) - Developer documentation + - [STYLE_GUIDE.md](./docs/STYLE_GUIDE.md) - Code style conventions + - [HOOKS.md](./docs/HOOKS.md) - Hooks system guide + - [PR.md](./docs/PR.md) - PR description guidelines + - [MCP_CLIENT_ARCHITECTURE.md](./docs/MCP_CLIENT_ARCHITECTURE.md) - MCP threading design diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index be83ff85b..86691a2d7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -122,7 +122,7 @@ hatch fmt --linter If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically. -For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md). +For additional details on styling, please see our dedicated [Style Guide](./docs/STYLE_GUIDE.md). ## Contributing via Pull Requests @@ -132,6 +132,8 @@ Contributions via pull requests are much appreciated. Before sending us a pull r 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. +For guidance on writing effective PR descriptions, see our [PR Description Guidelines](./docs/PR.md). + To send us a pull request, please: 1. Create a branch. diff --git a/README.md b/README.md index e7d1b2a7e..7e1612858 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ License PyPI version Python versions + Strands Discord

@@ -169,20 +170,21 @@ response = agent("Tell me about Agentic AI") ``` Built-in providers: - - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) - - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) - - [Gemini](https://strandsagents.com/latest/user-guide/concepts/model-providers/gemini/) - - [Cohere](https://strandsagents.com/latest/user-guide/concepts/model-providers/cohere/) - - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) - - [llama.cpp](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamacpp/) - - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - - [MistralAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/mistral/) - - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) - - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) - - [SageMaker](https://strandsagents.com/latest/user-guide/concepts/model-providers/sagemaker/) - - [Writer](https://strandsagents.com/latest/user-guide/concepts/model-providers/writer/) - -Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/) + - [Amazon Bedrock](https://strandsagents.com/docs/user-guide/concepts/model-providers/amazon-bedrock/) + - [Anthropic](https://strandsagents.com/docs/user-guide/concepts/model-providers/anthropic/) + - [Gemini](https://strandsagents.com/docs/user-guide/concepts/model-providers/gemini/) + - [Cohere](https://strandsagents.com/docs/user-guide/concepts/model-providers/cohere/) + - [LiteLLM](https://strandsagents.com/docs/user-guide/concepts/model-providers/litellm/) + - [llama.cpp](https://strandsagents.com/docs/user-guide/concepts/model-providers/llamacpp/) + - [LlamaAPI](https://strandsagents.com/docs/user-guide/concepts/model-providers/llamaapi/) + - [MistralAI](https://strandsagents.com/docs/user-guide/concepts/model-providers/mistral/) + - [Ollama](https://strandsagents.com/docs/user-guide/concepts/model-providers/ollama/) + - [OpenAI](https://strandsagents.com/docs/user-guide/concepts/model-providers/openai/) + - [OpenAI Responses API](https://strandsagents.com/docs/user-guide/concepts/model-providers/openai/) + - [SageMaker](https://strandsagents.com/docs/user-guide/concepts/model-providers/sagemaker/) + - [Writer](https://strandsagents.com/docs/user-guide/concepts/model-providers/writer/) + +Custom providers can be implemented using [Custom Providers](https://strandsagents.com/docs/user-guide/concepts/model-providers/custom_model_provider/) ### Example tools @@ -201,12 +203,22 @@ It's also available on GitHub via [strands-agents/tools](https://github.com/stra > **⚠️ Experimental Feature**: Bidirectional streaming is currently in experimental status. APIs may change in future releases as we refine the feature based on user feedback and evolving model capabilities. -Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart](https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart) guide. +Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart](https://strandsagents.com/docs/user-guide/concepts/bidirectional-streaming/quickstart/) guide. **Supported Model Providers:** -- Amazon Nova Sonic (`amazon.nova-sonic-v1:0`) -- Google Gemini Live (`gemini-2.5-flash-native-audio-preview-09-2025`) -- OpenAI Realtime API (`gpt-realtime`) +- Amazon Nova Sonic (v1, v2) +- Google Gemini Live +- OpenAI Realtime API + +**Installation:** + +```bash +# Server-side only (no audio I/O dependencies) +pip install strands-agents[bidi] + +# With audio I/O support (includes PyAudio dependency) +pip install strands-agents[bidi,bidi-io] +``` **Quick Example:** @@ -219,11 +231,11 @@ from strands.experimental.bidi.tools import stop_conversation from strands_tools import calculator async def main(): - # Create bidirectional agent with audio model + # Create bidirectional agent with Nova Sonic v2 model = BidiNovaSonicModel() agent = BidiAgent(model=model, tools=[calculator, stop_conversation]) - # Setup audio and text I/O + # Setup audio and text I/O (requires bidi-io extra) audio_io = BidiAudioIO() text_io = BidiTextIO() @@ -238,10 +250,14 @@ if __name__ == "__main__": asyncio.run(main()) ``` +> **Note**: `BidiAudioIO` and `BidiTextIO` require the `bidi-io` extra. For server-side deployments where audio I/O is handled by clients (browsers, mobile apps), install only `strands-agents[bidi]` and implement custom input/output handlers using the `BidiInput` and `BidiOutput` protocols. + **Configuration Options:** ```python -# Configure audio settings +from strands.experimental.bidi.models import BidiNovaSonicModel + +# Configure audio settings and turn detection (v2 only) model = BidiNovaSonicModel( provider_config={ "audio": { @@ -249,6 +265,9 @@ model = BidiNovaSonicModel( "output_rate": 16000, "voice": "matthew" }, + "turn_detection": { + "endpointingSensitivity": "MEDIUM" # HIGH, MEDIUM, or LOW + }, "inference": { "max_tokens": 2048, "temperature": 0.7 @@ -263,6 +282,19 @@ audio_io = BidiAudioIO( input_buffer_size=10, output_buffer_size=10 ) + +# Text input mode (type messages instead of speaking) +text_io = BidiTextIO() +await agent.run( + inputs=[text_io.input()], # Use text input + outputs=[audio_io.output(), text_io.output()] +) + +# Multi-modal: Both audio and text input +await agent.run( + inputs=[audio_io.input(), text_io.input()], # Speak OR type + outputs=[audio_io.output(), text_io.output()] +) ``` ## Documentation @@ -270,11 +302,11 @@ audio_io = BidiAudioIO( For detailed guidance & examples, explore our documentation: - [User Guide](https://strandsagents.com/) -- [Quick Start Guide](https://strandsagents.com/latest/user-guide/quickstart/) -- [Agent Loop](https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/) -- [Examples](https://strandsagents.com/latest/examples/) -- [API Reference](https://strandsagents.com/latest/api-reference/agent/) -- [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/) +- [Quick Start Guide](https://strandsagents.com/docs/user-guide/quickstart/) +- [Agent Loop](https://strandsagents.com/docs/user-guide/concepts/agents/agent-loop/) +- [Examples](https://strandsagents.com/docs/examples/) +- [API Reference](https://strandsagents.com/docs/api/python/strands.agent.agent/) +- [Production & Deployment Guide](https://strandsagents.com/docs/user-guide/deploy/operating-agents-in-production/) ## Contributing ❤️ @@ -285,6 +317,9 @@ We welcome contributions! See our [Contributing Guide](CONTRIBUTING.md) for deta - Code of Conduct - Reporting of security issues +## Stay in touch with the team +Come meet the Strands team and other users on [**Discord**](https://discord.com/invite/strands) + ## License This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..b520ee1fb --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,20 @@ +# Security Policy + +## Supported Versions + +| Version | Supported | +| ------- | ------------------ | +| 1.x.x | :white_check_mark: | +| < 1.0 | :x: | + +## Reporting Security Issues + +Amazon Web Services (AWS) is dedicated to the responsible disclosure of security vulnerabilities. + +We kindly ask that you **do not** open a public GitHub issue to report security concerns. + +Instead, please submit the issue to the AWS Vulnerability Disclosure Program via [HackerOne](https://hackerone.com/aws_vdp) or send your report via [email](mailto:aws-security@amazon.com). + +For more details, visit the [AWS Vulnerability Reporting Page](http://aws.amazon.com/security/vulnerability-reporting/). + +Thank you in advance for collaborating with us to help protect our customers. diff --git a/docs/HOOKS.md b/docs/HOOKS.md new file mode 100644 index 000000000..b447c6400 --- /dev/null +++ b/docs/HOOKS.md @@ -0,0 +1,24 @@ +# Hooks System + +The hooks system enables extensible agent functionality through strongly-typed event callbacks. + +## Terminology + +- **Paired events**: Events that denote the beginning and end of an operation +- **Hook callback**: A function that receives a strongly-typed event argument +- **Hook provider**: An object implementing `HookProvider` that registers callbacks via `register_hooks()` + +## Naming Conventions + +- All hook events have a suffix of `Event` +- Paired events follow `Before{Action}Event` and `After{Action}Event` +- Action words come after the lifecycle indicator (e.g., `BeforeToolCallEvent` not `BeforeToolEvent`) + +## Paired Events + +- For every `Before` event there is a corresponding `After` event, even if an exception occurs +- `After` events invoke callbacks in reverse registration order (for proper cleanup) + +## Writable Properties + +Some events have writable properties that modify agent behavior. Values are re-read after callbacks complete. For example, `BeforeToolCallEvent.selected_tool` is writable - after invoking the callback, the modified `selected_tool` takes effect for the tool call. diff --git a/_MCP_CLIENT_ARCHITECTURE.md b/docs/MCP_CLIENT_ARCHITECTURE.md similarity index 100% rename from _MCP_CLIENT_ARCHITECTURE.md rename to docs/MCP_CLIENT_ARCHITECTURE.md diff --git a/docs/PR.md b/docs/PR.md new file mode 100644 index 000000000..b4778f2b1 --- /dev/null +++ b/docs/PR.md @@ -0,0 +1,201 @@ +# Pull Request Description Guidelines + +Good PR descriptions help reviewers understand the context and impact of your changes. They enable faster reviews, better decision-making, and serve as valuable historical documentation. + +When creating a PR, follow the [GitHub PR template](../.github/PULL_REQUEST_TEMPLATE.md) and use these guidelines to fill it out effectively. + +## Who's Reading Your PR? + +Write for senior engineers familiar with the SDK. Assume your reader: + +- Understands the SDK's architecture and patterns +- Has context about the broader system +- Can read code diffs to understand implementation details +- Values concise, focused communication + +## What to Include + +Every PR description should have: + +1. **Motivation** — Why is this change needed? +2. **Public API Changes** — What changes to the public API (with code snippets)? +3. **Use Cases** (optional) — When would developers use this feature? Only include for non-obvious functionality; skip for trivial changes or obvious fixes. +4. **Breaking Changes** (if applicable) — What breaks and how to migrate? + +## Writing Principles + +**Focus on WHY, not HOW:** + +- ✅ "Hook providers need access to the agent's result to perform post-invocation actions like logging or analytics" +- ❌ "Added result field to AfterInvocationEvent dataclass" + +**Document public API changes with example code snippets:** + +- ✅ Show before/after code snippets for API changes +- ❌ List every file or line changed + +**Be concise:** + +- ✅ Use prose over bullet lists when possible +- ❌ Create exhaustive implementation checklists + +**Emphasize user impact:** + +- ✅ "Enables hooks to log conversation outcomes or trigger follow-up actions based on the result" +- ❌ "Updated AfterInvocationEvent to include optional AgentResult field" + +## What to Skip + +Leave these out of your PR description: + +- **Implementation details** — Code comments and commit messages cover this +- **Test coverage notes** — CI will catch issues; assume tests are comprehensive +- **Line-by-line change lists** — The diff provides this +- **Build/lint/coverage status** — CI handles verification +- **Commit hashes** — GitHub links commits automatically + +## Anti-patterns + +❌ **Over-detailed checklists:** + +```markdown +### Type Definition Updates + +- Added result field to AfterInvocationEvent dataclass +- Updated Agent._run_loop to capture and pass AgentResult +``` + +❌ **Implementation notes reviewers don't need:** + +```markdown +## Implementation Notes + +- Result field defaults to None +- AgentResult is captured from EventLoopStopEvent before invoking hooks +``` + +❌ **Test coverage bullets:** + +```markdown +### Test Coverage + +- Added test: AfterInvocationEvent includes AgentResult +- Added test: result is None when structured_output is used +``` + +## Good Examples + +✅ **Motivation section:** + +```markdown +## Motivation + +Hook providers often need to perform actions based on the outcome of an agent's +invocation, such as logging results, updating metrics, or triggering follow-up +workflows. Currently, the `AfterInvocationEvent` doesn't provide access to the +`AgentResult`, forcing hook implementations to track state externally or miss +this information entirely. +``` + +✅ **Public API Changes section:** + +````markdown +## Public API Changes + +`AfterInvocationEvent` now includes an optional `result` attribute containing +the `AgentResult`: + +```python +# Before: no access to result +class MyHook(HookProvider): + def on_after_invocation(self, event: AfterInvocationEvent) -> None: + # Could only access event.agent, no result available + logger.info("Invocation completed") + +# After: result available for inspection +class MyHook(HookProvider): + def on_after_invocation(self, event: AfterInvocationEvent) -> None: + if event.result: + logger.info(f"Completed with stop_reason: {event.result.stop_reason}") +``` + +The `result` field is `None` when invoked from `structured_output` methods. + +```` + +✅ **Use Cases section:** + +```markdown +## Use Cases + +- **Result logging**: Log conversation outcomes including stop reasons and token usage +- **Analytics**: Track agent performance metrics based on invocation results +- **Conditional workflows**: Trigger follow-up actions based on how the agent completed +```` + +## Template + +````markdown +## Motivation + +[Explain WHY this change is needed. What problem does it solve? What limitation +does it address? What user need does it fulfill?] + +Resolves: #[issue-number] + +## Public API Changes + +[Document changes to public APIs with before/after code snippets. If no public +API changes, state "No public API changes."] + +```python +# Before +[existing API usage] + +# After +[new API usage] +``` + +[Explain behavior, parameters, return values, and backward compatibility.] + +## Use Cases (optional) + +[Only include for non-obvious functionality. Provide 1-3 concrete use cases +showing when developers would use this feature. Skip for trivial changes obvious fixes..] + +## Breaking Changes (if applicable) + +[If this is a breaking change, explain what breaks and provide migration guidance.] + +### Migration + +```python +# Before +[old code] + +# After +[new code] +``` + +```` + +## Why These Guidelines? + +**Focus on WHY over HOW** because code diffs show implementation details, commit messages document granular changes, and PR descriptions provide the broader context reviewers need. + +**Skip test/lint/coverage details** because CI pipelines verify these automatically. Including them adds noise without value. + +**Write for senior engineers** to enable concise, technical communication without redundant explanations. + +## References + +- [Conventional Commits](https://www.conventionalcommits.org/) +- [Google's Code Review Guidelines](https://google.github.io/eng-practices/review/) + +## Checklist Items + + - [ ] Does the PR description target a Senior Engineer familiar with the project? + - [ ] Does the PR description give an overview of the feature being implemented, including any notes on key implementation decisions + - [ ] Does the PR include a "Resolves #" in the body and is not bolded? + - [ ] Does the PR contain the motivation or use-cases behind the change? + - [ ] Does the PR omit irrelevant details not needed for historical reference? \ No newline at end of file diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..857edc4c4 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,16 @@ +# Developer Documentation + +This folder contains documentation for contributors and developers working on the Strands Agents SDK. + +## Contents + +- [STYLE_GUIDE.md](./STYLE_GUIDE.md) - Code style conventions and formatting guidelines +- [HOOKS.md](./HOOKS.md) - Hooks system rules and usage guide +- [PR.md](./PR.md) - Pull request description guidelines +- [MCP_CLIENT_ARCHITECTURE.md](./MCP_CLIENT_ARCHITECTURE.md) - MCP client threading architecture and design decisions + +## Related Documentation + +- [AGENTS.md](../AGENTS.md) - Guidance for AI agents and LLMs working with this codebase +- [CONTRIBUTING.md](../CONTRIBUTING.md) - Contribution guidelines for human contributors +- [strandsagents.com](https://strandsagents.com/) - User-facing documentation diff --git a/STYLE_GUIDE.md b/docs/STYLE_GUIDE.md similarity index 65% rename from STYLE_GUIDE.md rename to docs/STYLE_GUIDE.md index 51dc0a73a..82ee51847 100644 --- a/STYLE_GUIDE.md +++ b/docs/STYLE_GUIDE.md @@ -57,3 +57,33 @@ logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, m ``` By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient. + +## Type Annotations + +### Avoid `Callable` for Extensible Interfaces + +Do not use `Callable` for function type annotations that may need additional parameters in the future. `Callable` signatures are fixed and cannot be expanded without breaking existing implementations. + +```python +# Bad: Cannot add parameters later without breaking all existing implementations +EdgeCondition = Callable[[GraphState], bool] + +# Good: Protocol allows adding optional keyword arguments in the future +class EdgeCondition(Protocol): + def __call__(self, state: GraphState, **kwargs: Any) -> bool: ... +``` + +Using `Protocol` with `**kwargs` allows the interface to evolve by adding new keyword arguments without breaking existing implementations that don't use them. + +### Tool Name References + +When comparing against tool names in hooks or plugins, use the tool instance's `tool_name` property instead of hardcoding strings. Tool specs can be modified at runtime via the `AgentTool.tool_spec` setter, so hardcoded names may not match the actual registered name. + +```python +# Good +if event.tool_use.get("name") == self.my_tool.tool_name: + ... + +# Bad — fragile if tool name is changed at runtime +if event.tool_use.get("name") == "my_tool": + ... \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2c2a6b260..cdc09fe45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries :: Python Modules", ] @@ -31,9 +32,10 @@ dependencies = [ "botocore>=1.29.0,<2.0.0", "docstring_parser>=0.15,<1.0", "jsonschema>=4.0.0,<5.0.0", - "mcp>=1.11.0,<2.0.0", + "mcp>=1.23.0,<2.0.0", "pydantic>=2.4.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", + "pyyaml>=6.0.0,<7.0.0", "watchdog>=6.0.0,<7.0.0", "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", @@ -44,20 +46,20 @@ dependencies = [ [project.optional-dependencies] anthropic = ["anthropic>=0.21.0,<1.0.0"] gemini = ["google-genai>=1.32.0,<2.0.0"] -litellm = ["litellm>=1.75.9,<2.0.0", "openai>=1.68.0,<1.110.0"] +litellm = ["litellm>=1.75.9,<=1.83.13", "openai>=1.68.0,<3.0.0"] llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] -mistral = ["mistralai>=1.8.2"] +mistral = ["mistralai>=1.8.2,<2.0.0"] ollama = ["ollama>=0.4.8,<1.0.0"] -openai = ["openai>=1.68.0,<2.0.0"] +openai = ["openai>=1.68.0,<3.0.0", "aws-bedrock-token-generator>=1.1.0,<2.0.0"] writer = ["writer-sdk>=2.2.0,<3.0.0"] sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", - "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface + "openai>=1.68.0,<3.0.0", # SageMaker uses OpenAI-compatible interface ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ - "sphinx>=5.0.0,<9.0.0", - "sphinx-rtd-theme>=1.0.0,<2.0.0", + "sphinx>=5.0.0,<10.0.0", + "sphinx-rtd-theme>=1.0.0,<4.0.0", "sphinx-autodoc-typehints>=1.12.0,<4.0.0", ] @@ -66,33 +68,37 @@ a2a = [ "a2a-sdk[sql]>=0.3.0,<0.4.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", - "fastapi>=0.115.12,<1.0.0", - "starlette>=0.46.2,<1.0.0", + "fastapi>=0.133.0,<1.0.0", + "starlette>=1.0.0,<2.0.0", ] bidi = [ - "aws_sdk_bedrock_runtime; python_version>='3.12'", + "aws_sdk_bedrock_runtime>=0.4.0,<1.0.0; python_version>='3.12'", + "smithy-aws-core>=0.4.0,<1.0.0; python_version>='3.12'", +] +bidi-io = [ "prompt_toolkit>=3.0.0,<4.0.0", "pyaudio>=0.2.13,<1.0.0", - "smithy-aws-core>=0.0.1; python_version>='3.12'", ] bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] -bidi-openai = ["websockets>=15.0.0,<16.0.0"] +bidi-openai = ["websockets>=15.0.0,<17.0.0"] all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] -bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"] +bidi-all = ["strands-agents[a2a,bidi,bidi-io,bidi-gemini,bidi-openai,docs,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", "mypy>=1.15.0,<2.0.0", - "pre-commit>=3.2.0,<4.4.0", - "pytest>=8.0.0,<9.0.0", + "pre-commit>=3.2.0,<4.7.0", + "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", - "pytest-asyncio>=1.0.0,<1.3.0", + "pytest-asyncio>=1.0.0,<1.4.0", + "pytest-timeout>=2.0.0,<3.0.0", "pytest-xdist>=3.0.0,<4.0.0", - "ruff>=0.13.0,<0.14.0", + "ruff>=0.13.0,<0.15.0", + "tenacity>=9.0.0,<10.0.0", ] [project.urls] @@ -114,7 +120,7 @@ installer = "uv" features = ["all"] dependencies = [ "mypy>=1.15.0,<2.0.0", - "ruff>=0.13.0,<0.14.0", + "ruff>=0.13.0,<0.15.0", # Include required package dependencies for mypy "strands-agents @ {root:uri}", ] @@ -142,15 +148,17 @@ installer = "uv" features = ["all"] extra-args = ["-n", "auto", "-vv"] dependencies = [ - "pytest>=8.0.0,<9.0.0", + "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", - "pytest-asyncio>=1.0.0,<1.3.0", + "pytest-asyncio>=1.0.0,<1.4.0", + "pytest-timeout>=2.0.0,<3.0.0", "pytest-xdist>=3.0.0,<4.0.0", + "pytest-timeout>=2.0.0,<3.0.0", "moto>=5.1.0,<6.0.0", ] [[tool.hatch.envs.hatch-test.matrix]] -python = ["3.13", "3.12", "3.11", "3.10"] +python = ["3.14", "3.13", "3.12", "3.11", "3.10"] [tool.hatch.envs.hatch-test.scripts] run = "pytest{env:HATCH_TEST_ARGS:} {args}" # Run with: hatch test @@ -166,7 +174,7 @@ features = ["all"] dependencies = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", - "pre-commit>=3.2.0,<4.4.0", + "pre-commit>=3.2.0,<4.7.0", ] @@ -224,6 +232,7 @@ select = [ "G", # logging format "I", # isort "LOG", # logging + "UP" # pyupgrade ] [tool.ruff.lint.per-file-ignores] @@ -236,7 +245,8 @@ convention = "google" [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" -addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi" +addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi --junit-xml=build/test-results.xml" +timeout = 90 [tool.coverage.run] @@ -293,7 +303,7 @@ prepare = [ "hatch run bidi-test:test-cov", ] -[tools.hatch.envs.bidi-lint] +[tool.hatch.envs.bidi-lint] template = "bidi" [tool.hatch.envs.bidi-lint.scripts] diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 3718a29c5..00e32ead3 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -2,13 +2,25 @@ from . import agent, models, telemetry, types from .agent.agent import Agent +from .agent.base import AgentBase +from .event_loop._retry import ModelRetryStrategy +from .plugins import MultiAgentPlugin, Plugin from .tools.decorator import tool +from .types._snapshot import Snapshot from .types.tools import ToolContext +from .vended_plugins.skills import AgentSkills, Skill __all__ = [ "Agent", + "AgentBase", + "AgentSkills", "agent", "models", + "ModelRetryStrategy", + "MultiAgentPlugin", + "Plugin", + "Skill", + "Snapshot", "tool", "ToolContext", "types", diff --git a/src/strands/_async.py b/src/strands/_async.py index 141ca71b7..0ceb038f3 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -2,8 +2,9 @@ import asyncio import contextvars +from collections.abc import Awaitable, Callable from concurrent.futures import ThreadPoolExecutor -from typing import Awaitable, Callable, TypeVar +from typing import TypeVar T = TypeVar("T") diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 6618d3328..c901e800f 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -4,10 +4,15 @@ - Agent: The main interface for interacting with AI models and tools - ConversationManager: Classes for managing conversation history and context windows +- Retry Strategies: Configurable retry behavior for model calls """ +from typing import Any + +from ..event_loop._retry import ModelRetryStrategy from .agent import Agent from .agent_result import AgentResult +from .base import AgentBase from .conversation_manager import ( ConversationManager, NullConversationManager, @@ -17,9 +22,20 @@ __all__ = [ "Agent", + "AgentBase", "AgentResult", "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", "SummarizingConversationManager", + "ModelRetryStrategy", ] + + +def __getattr__(name: str) -> Any: + """Lazy load A2AAgent to defer import of optional a2a dependency.""" + if name == "A2AAgent": + from .a2a_agent import A2AAgent + + return A2AAgent + raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})") diff --git a/src/strands/agent/_agent_as_tool.py b/src/strands/agent/_agent_as_tool.py new file mode 100644 index 000000000..11b536789 --- /dev/null +++ b/src/strands/agent/_agent_as_tool.py @@ -0,0 +1,296 @@ +"""Agent-as-tool adapter. + +This module provides the _AgentAsTool class that wraps an Agent as a tool +so it can be passed to another agent's tool list. +""" + +from __future__ import annotations + +import copy +import logging +import threading +from typing import TYPE_CHECKING, Any + +from typing_extensions import override + +from ..agent.state import AgentState +from ..types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent +from ..types.content import Messages +from ..types.interrupt import InterruptResponseContent +from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse + +if TYPE_CHECKING: + from .agent import Agent + +logger = logging.getLogger(__name__) + + +class _AgentAsTool(AgentTool): + """Adapter that exposes an Agent as a tool for use by other agents. + + The tool accepts a single ``input`` string parameter, invokes the wrapped + agent, and returns the text response. + + Example: + ```python + from strands import Agent + + researcher = Agent(name="researcher", description="Finds information") + + # Use via convenience method (default: fresh conversation each call) + tool = researcher.as_tool() + + # Preserve context across invocations + tool = researcher.as_tool(preserve_context=True) + + writer = Agent(name="writer", tools=[tool]) + writer("Write about AI agents") + ``` + """ + + def __init__( + self, + agent: Agent, + *, + name: str, + description: str | None = None, + preserve_context: bool = False, + ) -> None: + r"""Initialize the agent-as-tool adapter. + + Args: + agent: The agent to wrap as a tool. + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + description: Tool description. Defaults to the agent's description, or a + generic description if the agent has no description set. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to False. + """ + super().__init__() + self._agent = agent + self._tool_name = name + self._description = ( + description or agent.description or f"Use the {name} agent as a tool by providing a natural language input" + ) + self._preserve_context = preserve_context + + # When preserve_context=False, we snapshot the agent's initial state so we can + # restore it before each invocation. This mirrors GraphNode.reset_executor_state(). + self._initial_messages: Messages = [] + self._initial_state: AgentState = AgentState() + # Serialize access so _reset_agent_state + stream_async are atomic. + # threading.Lock (not asyncio.Lock) because run_async() may create + # separate event loops in different threads. + self._lock = threading.Lock() + + if not preserve_context: + if getattr(agent, "_session_manager", None) is not None: + raise ValueError( + "preserve_context=False cannot be used with an agent that has a session manager. " + "The session manager persists conversation history externally, which conflicts with " + "resetting the agent's state between invocations." + ) + self._initial_messages = copy.deepcopy(agent.messages) + self._initial_state = AgentState(agent.state.get()) + + @property + def agent(self) -> Agent: + """The wrapped agent instance.""" + return self._agent + + @property + def tool_name(self) -> str: + """Get the tool name.""" + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the tool specification.""" + return { + "name": self._tool_name, + "description": self._description, + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "The input to send to the agent tool.", + }, + }, + "required": ["input"], + } + }, + } + + @property + def tool_type(self) -> str: + """Get the tool type.""" + return "agent" + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + """Invoke the wrapped agent via streaming and yield events. + + Intermediate agent events are wrapped in AgentAsToolStreamEvent so the caller + can distinguish sub-agent progress from regular tool events. The final + AgentResult is yielded as a ToolResultEvent. + + When the sub-agent encounters a hook interrupt (e.g. from BeforeToolCallEvent), + the interrupts are propagated to the parent agent via ToolInterruptEvent. On + resume, interrupt responses are forwarded to the sub-agent automatically. + + Args: + tool_use: The tool use request containing the input parameter. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments. + + Yields: + AgentAsToolStreamEvent for intermediate events, ToolInterruptEvent if the + sub-agent is interrupted, or ToolResultEvent with the final response. + """ + tool_input = tool_use["input"] + if isinstance(tool_input, dict): + prompt = tool_input.get("input", "") + elif isinstance(tool_input, str): + prompt = tool_input + else: + logger.warning("tool_name=<%s> | unexpected input type: %s", self._tool_name, type(tool_input)) + prompt = str(tool_input) + + tool_use_id = tool_use["toolUseId"] + + # Serialize access to the underlying agent. _reset_agent_state() mutates + # the agent before stream_async acquires its own lock, so a concurrent + # call would corrupt an in-flight invocation. + if not self._lock.acquire(blocking=False): + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | agent is already processing a request", + self._tool_name, + tool_use_id, + ) + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Agent '{self._tool_name}' is already processing a request"}], + } + ) + return + + try: + # Determine if we are resuming the sub-agent from an interrupt. + if self._is_sub_agent_interrupted(): + prompt = self._build_interrupt_responses() + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resuming sub-agent from interrupt", + self._tool_name, + tool_use_id, + ) + elif not self._preserve_context: + self._reset_agent_state(tool_use_id) + + logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) + + result = None + async for event in self._agent.stream_async(prompt): + if "result" in event: + result = event["result"] + else: + yield AgentAsToolStreamEvent(tool_use, event, self) + + if result is None: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": "Agent did not produce a result"}], + } + ) + return + + # Propagate sub-agent interrupts to the parent agent. + if result.stop_reason == "interrupt" and result.interrupts: + yield ToolInterruptEvent(tool_use, list(result.interrupts)) + return + + if result.structured_output: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"json": result.structured_output.model_dump()}], + } + ) + else: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": str(result)}], + } + ) + + except Exception as e: + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | agent invocation failed: %s", + self._tool_name, + tool_use_id, + e, + ) + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Agent error: {e}"}], + } + ) + finally: + self._lock.release() + + def _reset_agent_state(self, tool_use_id: str) -> None: + """Reset the wrapped agent to its initial state. + + Restores messages and state to the values captured at construction time. + This mirrors the pattern used by ``GraphNode.reset_executor_state()``. + + Args: + tool_use_id: Tool use ID for logging context. + """ + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resetting agent to initial state", + self._tool_name, + tool_use_id, + ) + self._agent.messages = copy.deepcopy(self._initial_messages) + self._agent.state = AgentState(self._initial_state.get()) + + def _is_sub_agent_interrupted(self) -> bool: + """Check whether the wrapped agent is in an activated interrupt state.""" + return self._agent._interrupt_state.activated + + def _build_interrupt_responses(self) -> list[InterruptResponseContent]: + """Build interrupt response payloads from the sub-agent's interrupt state. + + The parent agent's ``_interrupt_state.resume()`` sets ``.response`` on the shared + ``Interrupt`` objects (registered by the executor), so we re-package them in the + format expected by ``Agent.stream_async``. + + Returns: + List of interrupt response content blocks for resuming the sub-agent. + """ + return [ + {"interruptResponse": {"interruptId": interrupt.id, "response": interrupt.response}} + for interrupt in self._agent._interrupt_state.interrupts.values() + if interrupt.response is not None + ] + + @override + def get_display_properties(self) -> dict[str, str]: + """Get properties for UI display.""" + properties = super().get_display_properties() + properties["Agent"] = getattr(self._agent, "name", "unknown") + return properties diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py new file mode 100644 index 000000000..eeb96f7a2 --- /dev/null +++ b/src/strands/agent/a2a_agent.py @@ -0,0 +1,312 @@ +"""A2A Agent client for Strands Agents. + +This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents, +allowing them to be used standalone or as part of multi-agent patterns. + +A2AAgent can be used to get the Agent Card and interact with the agent. +""" + +import dataclasses +import logging +import warnings +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import httpx +from a2a.client import A2ACardResolver, ClientConfig, ClientFactory +from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskStatusUpdateEvent + +from .._async import run_async +from ..multiagent.a2a._converters import ( + _STATE_TO_STOP_REASON, + convert_input_to_message, + convert_response_to_agent_result, +) +from ..types._events import AgentResultEvent +from ..types.a2a import A2AResponse, A2AStreamEvent +from ..types.agent import AgentInput +from .agent_result import AgentResult +from .base import AgentBase + +logger = logging.getLogger(__name__) + +_DEFAULT_TIMEOUT = 300 + +# A2A task states that indicate the response stream is complete. +# Derived from the canonical _STATE_TO_STOP_REASON mapping in _converters. +# Terminal states (end_turn) mean no more events; input states (interrupt) mean execution is paused. +_TERMINAL_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "end_turn"} +_INPUT_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "interrupt"} +_COMPLETE_STATES = _TERMINAL_STATES | _INPUT_STATES + + +class A2AAgent(AgentBase): + """Client wrapper for remote A2A agents.""" + + def __init__( + self, + endpoint: str, + *, + name: str | None = None, + description: str | None = None, + timeout: int = _DEFAULT_TIMEOUT, + client_config: ClientConfig | None = None, + a2a_client_factory: ClientFactory | None = None, + ): + """Initialize A2A agent. + + Args: + endpoint: The base URL of the remote A2A agent. + name: Agent name. If not provided, will be populated from agent card. + description: Agent description. If not provided, will be populated from agent card. + timeout: Timeout for HTTP operations in seconds (defaults to 300). + client_config: A2A ``ClientConfig`` for authentication and transport settings. + The ``httpx_client`` configured here is used for both card discovery and + message sending, enabling authenticated endpoints (SigV4, OAuth, bearer tokens). + When providing an ``httpx_client``, you are responsible for configuring its timeout. + a2a_client_factory: Deprecated. Use ``client_config`` instead. + + Raises: + ValueError: If both ``client_config`` and ``a2a_client_factory`` are provided. + """ + if client_config is not None and a2a_client_factory is not None: + raise ValueError( + "Cannot provide both client_config and a2a_client_factory. " + "Use client_config (recommended) or a2a_client_factory (deprecated), not both." + ) + + if a2a_client_factory is not None: + warnings.warn( + "a2a_client_factory is deprecated. Use client_config instead. " + "a2a_client_factory will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + + self.endpoint = endpoint + self.name = name + self.description = description + self.timeout = timeout + self._client_config: ClientConfig | None = client_config + self._agent_card: AgentCard | None = None + self._a2a_client_factory: ClientFactory | None = a2a_client_factory + + def __call__( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Synchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. + """ + return run_async(lambda: self.invoke_async(prompt, **kwargs)) + + async def invoke_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Asynchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. + """ + result: AgentResult | None = None + async for event in self.stream_async(prompt, **kwargs): + if "result" in event: + result = event["result"] + + if result is None: + raise RuntimeError("No response received from A2A agent") + + return result + + async def stream_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + """Stream remote agent execution asynchronously. + + This method provides an asynchronous interface for streaming A2A protocol events. + Unlike Agent.stream_async() which yields text deltas and tool events, this method + yields raw A2A protocol events wrapped in A2AStreamEvent dictionaries. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Yields: + An async iterator that yields events. Each event is a dictionary: + - A2AStreamEvent: {"type": "a2a_stream", "event": } + where the A2A object can be a Message, or a tuple of + (Task, TaskStatusUpdateEvent) or (Task, TaskArtifactUpdateEvent). + - AgentResultEvent: {"result": AgentResult} - always emitted last. + + Raises: + ValueError: If prompt is None. + + Example: + ```python + async for event in a2a_agent.stream_async("Hello"): + if event.get("type") == "a2a_stream": + print(f"A2A event: {event['event']}") + elif "result" in event: + print(f"Final result: {event['result'].message}") + ``` + """ + last_event = None + last_complete_event = None + + async for event in self._send_message(prompt): + last_event = event + if self._is_complete_event(event): + last_complete_event = event + yield A2AStreamEvent(event) + + # Use the last complete event if available, otherwise fall back to last event + final_event = last_complete_event or last_event + + if final_event is not None: + result = convert_response_to_agent_result(final_event) + yield AgentResultEvent(result) + + async def get_agent_card(self) -> AgentCard: + """Fetch and return the remote agent's card. + + Eagerly fetches the agent card from the remote endpoint, populating name and description + if not already set. The card is cached after the first fetch. + + When ``client_config`` is provided with an ``httpx_client``, that client is used for + card resolution, enabling authenticated card discovery (e.g., SigV4, OAuth, bearer tokens). + + Returns: + The remote agent's AgentCard containing name, description, capabilities, skills, etc. + """ + if self._agent_card is not None: + return self._agent_card + + if self._client_config is not None and self._client_config.httpx_client is not None: + resolver = A2ACardResolver(httpx_client=self._client_config.httpx_client, base_url=self.endpoint) + self._agent_card = await resolver.get_agent_card() + else: + async with httpx.AsyncClient(timeout=self.timeout) as client: + resolver = A2ACardResolver(httpx_client=client, base_url=self.endpoint) + self._agent_card = await resolver.get_agent_card() + + # Populate name from card if not set + if self.name is None and self._agent_card.name is not None: + self.name = self._agent_card.name + + # Populate description from card if not set + if self.description is None and self._agent_card.description is not None: + self.description = self._agent_card.description + + logger.debug("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint) + return self._agent_card + + @asynccontextmanager + async def _get_a2a_client(self) -> AsyncIterator[Any]: + """Get A2A client for sending messages. + + If a deprecated factory was provided, delegates to it for client creation. + If client_config was provided, uses it directly — ClientFactory handles defaults. + Otherwise creates a managed httpx client with the agent's timeout. + + Yields: + Configured A2A client instance. + """ + agent_card = await self.get_agent_card() + + if self._a2a_client_factory is not None: + yield self._a2a_client_factory.create(agent_card) + return + + if self._client_config is not None: + config = dataclasses.replace(self._client_config, streaming=True) + yield ClientFactory(config).create(agent_card) + return + + # No client_config — create a managed httpx client, consistent with get_agent_card() path + async with httpx.AsyncClient(timeout=self.timeout) as httpx_client: + config = ClientConfig(httpx_client=httpx_client, streaming=True) + yield ClientFactory(config).create(agent_card) + + async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: + """Send message to A2A agent. + + Args: + prompt: Input to send to the agent. + + Yields: + A2A response events. + + Raises: + ValueError: If prompt is None. + """ + if prompt is None: + raise ValueError("prompt is required for A2AAgent") + + message = convert_input_to_message(prompt) + logger.debug("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint) + + async with self._get_a2a_client() as client: + async for event in client.send_message(message): + yield event + + def _is_complete_event(self, event: A2AResponse) -> bool: + """Check if an A2A event represents a complete response. + + Recognizes all terminal states (completed, failed, canceled, rejected) + and pausing states (input_required, auth_required) as complete events. + + Args: + event: A2A event. + + Returns: + True if the event represents a complete response. + """ + # Direct Message is always complete + if isinstance(event, Message): + return True + + # Handle tuple responses (Task, UpdateEvent | None) + if isinstance(event, tuple) and len(event) == 2: + task, update_event = event + + # Initial task response (no update event) + if update_event is None: + return True + + # Artifact update with last_chunk flag + if isinstance(update_event, TaskArtifactUpdateEvent): + if hasattr(update_event, "last_chunk") and update_event.last_chunk is not None: + return update_event.last_chunk + return False + + # Status update - check for terminal or pausing states + if isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "state"): + state = update_event.status.state + return state in _COMPLETE_STATES + + return False diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index d6b08eff0..965969961 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,17 +9,14 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ +import copy import logging +import threading import warnings +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, - AsyncIterator, - Callable, - Mapping, - Optional, - Type, TypeVar, Union, cast, @@ -30,23 +27,35 @@ from .. import _identifier from .._async import run_async -from ..event_loop.event_loop import event_loop_cycle +from ..event_loop._retry import ModelRetryStrategy +from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle from ..tools._tool_helpers import generate_missing_tool_result_content +from ..types._snapshot import ( + SNAPSHOT_SCHEMA_VERSION, + Snapshot, + SnapshotField, + SnapshotPreset, + resolve_snapshot_fields, +) if TYPE_CHECKING: - from ..experimental.tools import ToolProvider + from ..tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, + HookCallback, HookProvider, HookRegistry, MessageAddedEvent, ) +from ..hooks.registry import TEvent from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel -from ..models.model import Model +from ..models.model import Model, _ModelPlugin +from ..plugins import Plugin +from ..plugins.registry import _PluginRegistry from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer, serialize @@ -56,14 +65,18 @@ from ..tools.registry import ToolRegistry from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..tools.watcher import ToolWatcher -from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent -from ..types.agent import AgentInput +from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent +from ..types.agent import AgentInput, ConcurrentInvocationMode from ..types.content import ContentBlock, Message, Messages, SystemContentBlock -from ..types.exceptions import ContextWindowOverflowException +from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException +from ..types.tools import AgentTool from ..types.traces import AttributeValue +from ._agent_as_tool import _AgentAsTool from .agent_result import AgentResult +from .base import AgentBase from .conversation_manager import ( ConversationManager, + NullConversationManager, SlidingWindowConversationManager, ) from .state import AgentState @@ -81,13 +94,20 @@ class _DefaultCallbackHandlerSentinel: pass +class _DefaultRetryStrategySentinel: + """Sentinel class to distinguish between explicit None and default parameter value for retry_strategy.""" + + pass + + _DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() +_DEFAULT_RETRY_STRATEGY = _DefaultRetryStrategySentinel() _DEFAULT_AGENT_NAME = "Strands Agents" _DEFAULT_AGENT_ID = "default" -class Agent: - """Core Agent interface. +class Agent(AgentBase): + """Core Agent implementation. An agent orchestrates the following workflow: @@ -104,26 +124,28 @@ class Agent: def __init__( self, - model: Union[Model, str, None] = None, - messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], "ToolProvider", Any]]] = None, - system_prompt: Optional[str | list[SystemContentBlock]] = None, - structured_output_model: Optional[Type[BaseModel]] = None, - callback_handler: Optional[ - Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] - ] = _DEFAULT_CALLBACK_HANDLER, - conversation_manager: Optional[ConversationManager] = None, + model: Model | str | None = None, + messages: Messages | None = None, + tools: list[Union[str, dict[str, str], "ToolProvider", Any]] | None = None, + system_prompt: str | list[SystemContentBlock] | None = None, + structured_output_model: type[BaseModel] | None = None, + callback_handler: Callable[..., Any] | _DefaultCallbackHandlerSentinel | None = _DEFAULT_CALLBACK_HANDLER, + conversation_manager: ConversationManager | None = None, record_direct_tool_call: bool = True, load_tools_from_directory: bool = False, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, *, - agent_id: Optional[str] = None, - name: Optional[str] = None, - description: Optional[str] = None, - state: Optional[Union[AgentState, dict]] = None, - hooks: Optional[list[HookProvider]] = None, - session_manager: Optional[SessionManager] = None, - tool_executor: Optional[ToolExecutor] = None, + agent_id: str | None = None, + name: str | None = None, + description: str | None = None, + state: AgentState | dict | None = None, + plugins: list[Plugin] | None = None, + hooks: list[HookProvider | HookCallback] | None = None, + session_manager: SessionManager | None = None, + structured_output_prompt: str | None = None, + tool_executor: ToolExecutor | None = None, + retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY, + concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW, ): """Initialize the Agent with the specified configuration. @@ -140,7 +162,8 @@ def __init__( - Imported Python modules (e.g., from strands_tools import current_time) - Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"}) - ToolProvider instances for managed tool collections - - Functions decorated with `@strands.tool` decorator. + - Functions decorated with `@strands.tool` decorator + - Agent instances (auto-wrapped via `agent.as_tool()` with defaults) If provided, only these tools will be available. If None, all tools will be available. system_prompt: System prompt to guide model behavior. @@ -168,11 +191,29 @@ def __init__( Defaults to None. state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. - hooks: hooks to be added to the agent hook registry + plugins: List of Plugin instances to extend agent functionality. + Plugins are initialized with the agent instance after construction and can register hooks, + modify agent attributes, or perform other setup tasks. + Defaults to None. + hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances + or plain callable hook callbacks (functions with typed event parameters). Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. + structured_output_prompt: Custom prompt message used when forcing structured output. + When using structured output, if the model doesn't automatically use the output tool, + the agent sends a follow-up message to request structured formatting. This parameter + allows customizing that message. + Defaults to "You must format the previous response as structured output." tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + retry_strategy: Strategy for retrying model calls on throttling or other transient errors. + Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s. + Implement a custom HookProvider for custom retry logic, or pass None to disable retries. + concurrent_invocation_mode: Mode controlling concurrent invocation behavior. + Defaults to "throw" which raises ConcurrencyException if concurrent invocation is attempted. + Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations. + Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided + only for advanced use cases where the caller understands the risks. Raises: ValueError: If agent id contains path separators. @@ -182,6 +223,7 @@ def __init__( # initializing self._system_prompt for backwards compatibility self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt) self._default_structured_output_model = structured_output_model + self._structured_output_prompt = structured_output_prompt self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME self.description = description @@ -189,7 +231,7 @@ def __init__( # If not provided, create a new PrintingCallbackHandler instance # If explicitly set to None, use null_callback_handler # Otherwise use the passed callback_handler - self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler] + self.callback_handler: Callable[..., Any] | PrintingCallbackHandler if isinstance(callback_handler, _DefaultCallbackHandlerSentinel): self.callback_handler = PrintingCallbackHandler() elif callback_handler is None: @@ -197,7 +239,19 @@ def __init__( else: self.callback_handler = callback_handler - self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager() + if self.model.stateful and conversation_manager is not None: + raise ValueError( + "conversation_manager cannot be used with a stateful model. " + "The model manages conversation state server-side." + ) + + self.conversation_manager: ConversationManager + if self.model.stateful: + self.conversation_manager = NullConversationManager() + elif conversation_manager: + self.conversation_manager = conversation_manager + else: + self.conversation_manager = SlidingWindowConversationManager() # Process trace attributes to ensure they're of compatible types self.trace_attributes: dict[str, AttributeValue] = {} @@ -211,6 +265,9 @@ def __init__( self.record_direct_tool_call = record_direct_tool_call self.load_tools_from_directory = load_tools_from_directory + # Create internal cancel signal for graceful cancellation using threading.Event + self._cancel_signal = threading.Event() + self.tool_registry = ToolRegistry() # Process tool list if provided @@ -226,7 +283,7 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() - self.trace_span: Optional[trace_api.Span] = None + self.trace_span: trace_api.Span | None = None # Initialize agent state management if state is not None: @@ -243,20 +300,105 @@ def __init__( self.hooks = HookRegistry() + self._plugin_registry = _PluginRegistry(self) + self._interrupt_state = _InterruptState() + # Runtime state for model providers (e.g., server-side response ids) + self._model_state: dict[str, Any] = {} + + # Initialize lock for guarding concurrent invocations + # Using threading.Lock instead of asyncio.Lock because run_async() creates + # separate event loops in different threads, so asyncio.Lock wouldn't work + self._invocation_lock = threading.Lock() + self._concurrent_invocation_mode = concurrent_invocation_mode + + # In the future, we'll have a RetryStrategy base class but until + # that API is determined we only allow ModelRetryStrategy + if ( + retry_strategy is not None + and not isinstance(retry_strategy, _DefaultRetryStrategySentinel) + and type(retry_strategy) is not ModelRetryStrategy + ): + raise ValueError("retry_strategy must be an instance of ModelRetryStrategy") + + # If not provided (using the default), create a new ModelRetryStrategy instance + # If explicitly set to None, disable retries (max_attempts=1 means no retries) + # Otherwise use the passed retry_strategy + if isinstance(retry_strategy, _DefaultRetryStrategySentinel): + self._retry_strategy = ModelRetryStrategy( + max_attempts=MAX_ATTEMPTS, max_delay=MAX_DELAY, initial_delay=INITIAL_DELAY + ) + elif retry_strategy is None: + # If no retry strategy is passed in, then we turn retries off + self._retry_strategy = ModelRetryStrategy(max_attempts=1) + else: + self._retry_strategy = retry_strategy + # Initialize session management functionality self._session_manager = session_manager if self._session_manager: self.hooks.add_hook(self._session_manager) + # Allow conversation_managers to subscribe to hooks + self.hooks.add_hook(self.conversation_manager) + + # Register retry strategy as a hook + self.hooks.add_hook(self._retry_strategy) + self.tool_executor = tool_executor or ConcurrentToolExecutor() if hooks: for hook in hooks: - self.hooks.add_hook(hook) + if isinstance(hook, HookProvider): + self.hooks.add_hook(hook) + elif callable(hook): + self.hooks.add_callback(None, hook) + else: + raise ValueError( + f"Invalid hook: {hook!r}. Must be a HookProvider instance or a callable hook callback." + ) + + # Register built-in plugins + self._plugin_registry.add_and_init(_ModelPlugin()) + + if plugins: + for plugin in plugins: + self._plugin_registry.add_and_init(plugin) + self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + def cancel(self) -> None: + """Cancel the currently running agent invocation. + + This method is thread-safe and can be called from any context + (e.g., another thread, web request handler, background task). + + The agent will stop gracefully at the next checkpoint: + - During model response streaming + - Before tool execution + + The agent will return a result with stop_reason="cancelled". + + Example: + ```python + agent = Agent(model=model) + + # Start agent in background + task = asyncio.create_task(agent.invoke_async("Hello")) + + # Cancel from another context + agent.cancel() + + result = await task + assert result.stop_reason == "cancelled" + ``` + + Note: + Multiple calls to cancel() are safe and idempotent. + """ + self._cancel_signal.set() + @property def system_prompt(self) -> str | None: """Get the system prompt as a string for backwards compatibility. @@ -286,6 +428,18 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: """ self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) + @property + def system_prompt_content(self) -> list[SystemContentBlock] | None: + """Get the system prompt as a list of content blocks. + + Returns the structured content block representation, preserving cache points + and other non-text blocks. Returns None if no system prompt is set. + + Returns: + The system prompt as a list of content blocks, or None if no system prompt is set. + """ + return list(self._system_prompt_content) if self._system_prompt_content is not None else None + @property def tool(self) -> _ToolCaller: """Call tool as a function. @@ -316,7 +470,8 @@ def __call__( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -335,6 +490,7 @@ def __call__( - None: Use existing conversation history invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). + structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -348,7 +504,11 @@ def __call__( """ return run_async( lambda: self.invoke_async( - prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs + prompt, + invocation_state=invocation_state, + structured_output_model=structured_output_model, + structured_output_prompt=structured_output_prompt, + **kwargs, ) ) @@ -357,7 +517,8 @@ async def invoke_async( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -376,6 +537,7 @@ async def invoke_async( - None: Use existing conversation history invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). + structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -387,14 +549,18 @@ async def invoke_async( - state: The final state of the event loop """ events = self.stream_async( - prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs + prompt, + invocation_state=invocation_state, + structured_output_model=structured_output_model, + structured_output_prompt=structured_output_prompt, + **kwargs, ) async for event in events: _ = event return cast(AgentResult, event["result"]) - def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: + def structured_output(self, output_model: type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -425,7 +591,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> return run_async(lambda: self.structured_output_async(output_model, prompt)) - async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: + async def structured_output_async(self, output_model: type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -453,7 +619,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu category=DeprecationWarning, stacklevel=2, ) - await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self, invocation_state={})) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT ) as structured_output_span: @@ -494,7 +660,41 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu return event["output"] finally: - await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={})) + + def as_tool( + self, + *, + name: str | None = None, + description: str | None = None, + preserve_context: bool = False, + ) -> AgentTool: + r"""Convert this agent into a tool for use by another agent. + + Args: + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + Defaults to the agent's name. + description: Tool description. Defaults to the agent's description, or a + generic description if the agent has no description set. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to False. + + Returns: + A tool wrapping this agent. + + Example: + ```python + researcher = Agent(name="researcher", description="Finds information") + writer = Agent(name="writer", tools=[researcher.as_tool()]) + writer("Write about AI agents") + ``` + """ + if not name: + name = self.name + return _AgentAsTool(self, name=name, description=description, preserve_context=preserve_context) def cleanup(self) -> None: """Clean up resources used by the agent. @@ -508,6 +708,60 @@ def cleanup(self) -> None: """ self.tool_registry.cleanup() + def add_hook( + self, callback: HookCallback[TEvent], event_type: type[TEvent] | list[type[TEvent]] | None = None + ) -> None: + """Register a callback function for a specific event type. + + This method supports multiple call patterns: + 1. ``add_hook(callback)`` - Event type inferred from callback's type hint + 2. ``add_hook(callback, event_type)`` - Event type specified explicitly + 3. ``add_hook(callback, [TypeA, TypeB])`` - Register for multiple event types + + When the callback's type hint is a union type (``A | B`` or ``Union[A, B]``), + the callback is automatically registered for each event type in the union. + + Callbacks can be either synchronous or asynchronous functions. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. If a list is provided, + the callback is registered for each type in the list. + + Raises: + ValueError: If event_type is not provided and cannot be inferred from + the callback's type hints, or if the event_type list is empty. + + Example: + ```python + def log_model_call(event: BeforeModelCallEvent) -> None: + print(f"Calling model for agent: {event.agent.name}") + + agent = Agent() + + # With event type inferred from type hint + agent.add_hook(log_model_call) + + # With explicit event type + agent.add_hook(log_model_call, BeforeModelCallEvent) + + # With union type hint (registers for all types) + def log_event(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + print(f"Event: {type(event).__name__}") + agent.add_hook(log_event) + + # With list of event types + def multi_handler(event) -> None: + print(f"Event: {type(event).__name__}") + agent.add_hook(multi_handler, [BeforeModelCallEvent, AfterModelCallEvent]) + ``` + Docs: + https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ + """ + self.hooks.add_callback(event_type, callback) + def __del__(self) -> None: """Clean up resources when agent is garbage collected.""" # __del__ is called even when an exception is thrown in the constructor, @@ -520,7 +774,8 @@ async def stream_async( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -539,6 +794,7 @@ async def stream_async( - None: Use existing conversation history invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). + structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: @@ -551,6 +807,7 @@ async def stream_async( - And other event data provided by the callback handler Raises: + ConcurrencyException: If another invocation is already in progress on this agent instance. Exception: Any exceptions from the agent invocation will be propagated to the caller. Example: @@ -560,54 +817,75 @@ async def stream_async( yield event["data"] ``` """ - self._interrupt_state.resume(prompt) - - merged_state = {} - if kwargs: - warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) - merged_state.update(kwargs) - if invocation_state is not None: - merged_state["invocation_state"] = invocation_state - else: - if invocation_state is not None: - merged_state = invocation_state + # Conditionally acquire lock based on concurrent_invocation_mode + # Using threading.Lock instead of asyncio.Lock because run_async() creates + # separate event loops in different threads + if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW: + lock_acquired = self._invocation_lock.acquire(blocking=False) + if not lock_acquired: + raise ConcurrencyException( + "Agent is already processing a request. Concurrent invocations are not supported." + ) + + try: + self._interrupt_state.resume(prompt) - callback_handler = self.callback_handler - if kwargs: - callback_handler = kwargs.get("callback_handler", self.callback_handler) + self.event_loop_metrics.reset_usage_metrics() - # Process input and get message to add (if any) - messages = await self._convert_prompt_to_messages(prompt) + merged_state = {} + if kwargs: + warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) + merged_state.update(kwargs) + if invocation_state is not None: + merged_state["invocation_state"] = invocation_state + else: + if invocation_state is not None: + merged_state = invocation_state - self.trace_span = self._start_agent_trace_span(messages) + callback_handler = self.callback_handler + if kwargs: + callback_handler = kwargs.get("callback_handler", self.callback_handler) - with trace_api.use_span(self.trace_span): - try: - events = self._run_loop(messages, merged_state, structured_output_model) + # Process input and get message to add (if any) + messages = await self._convert_prompt_to_messages(prompt) - async for event in events: - event.prepare(invocation_state=merged_state) + self.trace_span = self._start_agent_trace_span(messages) + + with trace_api.use_span(self.trace_span): + try: + events = self._run_loop(messages, merged_state, structured_output_model, structured_output_prompt) - if event.is_callback_event: - as_dict = event.as_dict() - callback_handler(**as_dict) - yield as_dict + async for event in events: + event.prepare(invocation_state=merged_state) - result = AgentResult(*event["stop"]) - callback_handler(result=result) - yield AgentResultEvent(result=result).as_dict() + if event.is_callback_event: + as_dict = event.as_dict() + callback_handler(**as_dict) + yield as_dict + + result = AgentResult(*event["stop"]) + callback_handler(result=result) + yield AgentResultEvent(result=result).as_dict() - self._end_agent_trace_span(response=result) + self._end_agent_trace_span(response=result) - except Exception as e: - self._end_agent_trace_span(error=e) - raise + except Exception as e: + self._end_agent_trace_span(error=e) + raise + + finally: + # Clear cancel signal to allow agent reuse after cancellation + self._cancel_signal.clear() + + if self._invocation_lock.locked(): + self._invocation_lock.release() async def _run_loop( self, messages: Messages, invocation_state: dict[str, Any], - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. @@ -615,42 +893,71 @@ async def _run_loop( messages: The input messages to add to the conversation. invocation_state: Additional parameters to pass to the event loop. structured_output_model: Optional Pydantic model type for structured output. + structured_output_prompt: Optional custom prompt for forcing structured output. Yields: Events from the event loop cycle. """ - await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) + current_messages: Messages | None = messages - try: - yield InitEventLoopEvent() + while current_messages is not None: + before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=current_messages) + ) + current_messages = ( + before_invocation_event.messages if before_invocation_event.messages is not None else current_messages + ) - await self._append_messages(*messages) + agent_result: AgentResult | None = None + try: + yield InitEventLoopEvent() - structured_output_context = StructuredOutputContext( - structured_output_model or self._default_structured_output_model - ) + await self._append_messages(*current_messages) - # Execute the event loop cycle with retry logic for context limits - events = self._execute_event_loop_cycle(invocation_state, structured_output_context) - async for event in events: - # Signal from the model provider that the message sent by the user should be redacted, - # likely due to a guardrail. - if ( - isinstance(event, ModelStreamChunkEvent) - and event.chunk - and event.chunk.get("redactContent") - and event.chunk["redactContent"].get("redactUserContentMessage") - ): - self.messages[-1]["content"] = self._redact_user_content( - self.messages[-1]["content"], str(event.chunk["redactContent"]["redactUserContentMessage"]) - ) - if self._session_manager: - self._session_manager.redact_latest_message(self.messages[-1], self) - yield event + structured_output_context = StructuredOutputContext( + structured_output_model or self._default_structured_output_model, + structured_output_prompt=structured_output_prompt or self._structured_output_prompt, + ) - finally: - self.conversation_manager.apply_management(self) - await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) + # Execute the event loop cycle with retry logic for context limits + events = self._execute_event_loop_cycle(invocation_state, structured_output_context) + async for event in events: + # Signal from the model provider that the message sent by the user should be redacted, + # likely due to a guardrail. + if ( + isinstance(event, ModelStreamChunkEvent) + and event.chunk + and event.chunk.get("redactContent") + and event.chunk["redactContent"].get("redactUserContentMessage") + ): + self.messages[-1]["content"] = self._redact_user_content( + self.messages[-1]["content"], + str(event.chunk["redactContent"]["redactUserContentMessage"]), + ) + if self._session_manager: + self._session_manager.redact_latest_message(self.messages[-1], self) + yield event + + # Capture the result from the final event if available + if isinstance(event, EventLoopStopEvent): + agent_result = AgentResult(*event["stop"]) + + finally: + self.conversation_manager.apply_management(self) + after_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result) + ) + + # Convert resume input to messages for next iteration, or None to stop + if after_invocation_event.resume is not None: + logger.debug("resume= | hook requested agent resume with new input") + # If in interrupt state, process interrupt responses before continuing. + # This mirrors the _interrupt_state.resume() call in stream_async and will + # raise TypeError if the resume input is not valid interrupt responses. + self._interrupt_state.resume(after_invocation_event.resume) + current_messages = await self._convert_prompt_to_messages(after_invocation_event.resume) + else: + current_messages = None async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None @@ -730,7 +1037,7 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: # Check if all item in input list are dictionaries elif all(isinstance(item, dict) for item in prompt): # Check if all items are messages - if all(all(key in item for key in Message.__annotations__.keys()) for item in prompt): + if all(all(key in item for key in Message.__required_keys__) for item in prompt): # Messages input - add all messages to conversation messages = cast(Messages, prompt) @@ -764,8 +1071,8 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: def _end_agent_trace_span( self, - response: Optional[AgentResult] = None, - error: Optional[Exception] = None, + response: AgentResult | None = None, + error: Exception | None = None, ) -> None: """Ends a trace span for the agent. @@ -816,6 +1123,78 @@ async def _append_messages(self, *messages: Message) -> None: self.messages.append(message) await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message)) + def take_snapshot( + self, + *, + preset: SnapshotPreset | None = None, + include: list[SnapshotField] | None = None, + exclude: list[SnapshotField] | None = None, + app_data: dict[str, Any] | None = None, + ) -> Snapshot: + """Capture current agent state as an in-memory snapshot. + + Args: + preset: Named preset of fields to capture. Currently only "session" is supported, + which captures messages, state, conversation_manager_state, and interrupt_state. + include: Additional fields to capture on top of the preset. + exclude: Fields to remove after applying preset and include. + app_data: Application-owned arbitrary JSON stored verbatim in the snapshot. + + Returns: + A Snapshot containing the captured agent state. + + Raises: + SnapshotException: If no fields are resolved or an invalid field name is provided. + """ + fields = resolve_snapshot_fields(preset=preset, include=include, exclude=exclude) + + data: dict[str, Any] = {} + if "messages" in fields: + data["messages"] = copy.deepcopy(self.messages) + if "state" in fields: + data["state"] = self.state.get() + if "conversation_manager_state" in fields: + data["conversation_manager_state"] = self.conversation_manager.get_state() + if "interrupt_state" in fields: + data["interrupt_state"] = self._interrupt_state.to_dict() + if "system_prompt" in fields: + # Store the content-block representation so round-trips preserve caching hints and + # other block-level metadata. + data["system_prompt"] = copy.deepcopy(self._system_prompt_content) + + return Snapshot( + scope="agent", + schema_version=SNAPSHOT_SCHEMA_VERSION, + data=data, + app_data=copy.deepcopy(app_data) if app_data else {}, + ) + + def load_snapshot(self, snapshot: Snapshot) -> None: + """Restore agent state from a previously captured snapshot. + + Only fields present in snapshot.data are restored; absent fields are left unchanged. + + Args: + snapshot: The snapshot to restore from. + + Raises: + SnapshotException: If snapshot.schema_version is not "1.0". + """ + snapshot.validate() + + data = snapshot.data + + if "messages" in data: + self.messages = copy.deepcopy(data["messages"]) + if "state" in data: + self.state = AgentState(data["state"]) + if "conversation_manager_state" in data: + self.conversation_manager.restore_from_session(data["conversation_manager_state"]) + if "interrupt_state" in data: + self._interrupt_state = _InterruptState.from_dict(data["interrupt_state"]) + if "system_prompt" in data: + self.system_prompt = copy.deepcopy(data["system_prompt"]) + def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]: """Redact user content preserving toolResult blocks. diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index ef8a11029..80e483088 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -3,8 +3,9 @@ This module defines the AgentResult class which encapsulates the complete response from an agent's processing cycle. """ +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Sequence, cast +from typing import Any, cast from pydantic import BaseModel @@ -34,25 +35,53 @@ class AgentResult: interrupts: Sequence[Interrupt] | None = None structured_output: BaseModel | None = None + @property + def context_size(self) -> int | None: + """Most recent context size in tokens from the last LLM call. + + Returns: + The input token count from the most recent cycle, or None if no data is available. + """ + return self.metrics.latest_context_size + + @property + def projected_context_size(self) -> int | None: + """Projected context size for the next model call. + + Returns: + The projected token count (inputTokens + outputTokens), or None if no data is available. + """ + return self.metrics.projected_context_size + def __str__(self) -> str: - """Get the agent's last message as a string. + """Return a string representation of the agent result. - This method extracts and concatenates all text content from the final message, ignoring any non-text content - like images or structured data. If there's no text content but structured output is present, it serializes - the structured output instead. + Priority order: + 1. Interrupts (if present) → stringified list of interrupt dicts + 2. Structured output (if present) → JSON string + 3. Text content from message → concatenated text blocks Returns: - The agent's last message as a string. + String representation based on the priority order above. """ - content_array = self.message.get("content", []) + if self.interrupts: + return str([interrupt.to_dict() for interrupt in self.interrupts]) + if self.structured_output: + return self.structured_output.model_dump_json() + + content_array = self.message.get("content", []) result = "" for item in content_array: - if isinstance(item, dict) and "text" in item: - result += item.get("text", "") + "\n" - - if not result and self.structured_output: - result = self.structured_output.model_dump_json() + if isinstance(item, dict): + if "text" in item: + result += item.get("text", "") + "\n" + elif "citationsContent" in item: + citations_block = item["citationsContent"] + if "content" in citations_block: + for content in citations_block["content"]: + if isinstance(content, dict) and "text" in content: + result += content.get("text", "") + "\n" return result diff --git a/src/strands/agent/base.py b/src/strands/agent/base.py new file mode 100644 index 000000000..ae8a14e75 --- /dev/null +++ b/src/strands/agent/base.py @@ -0,0 +1,67 @@ +"""Agent Interface. + +Defines the minimal interface that all agent types must implement. +""" + +from collections.abc import AsyncIterator +from typing import Any, Protocol, runtime_checkable + +from ..types.agent import AgentInput +from .agent_result import AgentResult + + +@runtime_checkable +class AgentBase(Protocol): + """Protocol defining the interface for all agent types in Strands. + + This protocol defines the minimal contract that all agent implementations + must satisfy. + """ + + async def invoke_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Asynchronously invoke the agent with the given prompt. + + Args: + prompt: Input to the agent. + **kwargs: Additional arguments. + + Returns: + AgentResult containing the agent's response. + """ + ... + + def __call__( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Synchronously invoke the agent with the given prompt. + + Args: + prompt: Input to the agent. + **kwargs: Additional arguments. + + Returns: + AgentResult containing the agent's response. + """ + ... + + def stream_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + """Stream agent execution asynchronously. + + Args: + prompt: Input to the agent. + **kwargs: Additional arguments. + + Yields: + Events representing the streaming execution. + """ + ... diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index c59623215..9f6d54ff9 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -3,6 +3,7 @@ It includes: - ConversationManager: Abstract base class defining the conversation management interface +- ProactiveCompressionConfig: Configuration type for proactive compression settings - NullConversationManager: A no-op implementation that does not modify conversation history - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context size while preserving conversation coherence @@ -13,7 +14,7 @@ is critical for effective agent interactions. """ -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ProactiveCompressionConfig from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager from .summarizing_conversation_manager import SummarizingConversationManager @@ -21,6 +22,7 @@ __all__ = [ "ConversationManager", "NullConversationManager", + "ProactiveCompressionConfig", "SlidingWindowConversationManager", "SummarizingConversationManager", ] diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 2c1ee7847..7e2283883 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,15 +1,35 @@ """Abstract interface for conversation history management.""" +import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, TypedDict, Union +from ...hooks.events import BeforeModelCallEvent +from ...hooks.registry import HookProvider, HookRegistry from ...types.content import Message if TYPE_CHECKING: from ...agent.agent import Agent +logger = logging.getLogger(__name__) -class ConversationManager(ABC): +DEFAULT_COMPRESSION_THRESHOLD = 0.7 +DEFAULT_CONTEXT_WINDOW_LIMIT = 200_000 + + +class ProactiveCompressionConfig(TypedDict, total=False): + """Configuration for proactive compression when passed as an object. + + Attributes: + compression_threshold: Ratio of context window usage that triggers proactive compression. + Value between 0 (exclusive) and 1 (inclusive). + Defaults to 0.7 (compress when 70% of the context window is used). + """ + + compression_threshold: float + + +class ConversationManager(ABC, HookProvider): """Abstract base class for managing conversation history. This class provides an interface for implementing conversation management strategies to control the size of message @@ -18,19 +38,127 @@ class ConversationManager(ABC): - Manage memory usage - Control context length - Maintain relevant conversation state + + ConversationManager implements the HookProvider protocol, allowing derived classes to register hooks for agent + lifecycle events. Derived classes that override register_hooks must call the base implementation to ensure proper + hook registration chain. + + The primary responsibility of a ConversationManager is overflow recovery: when the model encounters a context + window overflow, :meth:`reduce_context` is called with ``e`` set and MUST reduce the history enough for the next + model call to succeed. + + Subclasses can enable proactive compression by passing ``proactive_compression`` in the constructor. + When enabled, the base class registers a ``BeforeModelCallEvent`` hook that checks projected input tokens + against the model's context window limit and calls :meth:`reduce_context` (without ``e``) when the + threshold is exceeded. This is a best-effort operation — errors are swallowed so the model call can + still proceed. + + Example: + ```python + # Enable proactive compression with default threshold (0.7) + SlidingWindowConversationManager(window_size=50, proactive_compression=True) + + # Enable proactive compression with custom threshold + SummarizingConversationManager(proactive_compression={"compression_threshold": 0.8}) + ``` """ - def __init__(self) -> None: + def __init__(self, *, proactive_compression: Union[bool, "ProactiveCompressionConfig", None] = None) -> None: """Initialize the ConversationManager. + Args: + proactive_compression: Enable proactive context compression before the model call. + - ``True``: compress when 70% of the context window is used (default threshold). + - ``{"compression_threshold": float}``: compress at the specified ratio (0, 1]. + - ``False`` or ``None``: disabled, only reactive overflow recovery is used. + + Raises: + ValueError: If compression_threshold is not in the valid range (0, 1]. + Attributes: removed_message_count: The messages that have been removed from the agents messages array. These represent messages provided by the user or LLM that have been removed, not messages included by the conversation manager through something like summarization. """ + # Resolve the threshold from proactive_compression parameter + if proactive_compression is True: + threshold: float | None = DEFAULT_COMPRESSION_THRESHOLD + elif isinstance(proactive_compression, dict): + threshold = proactive_compression.get("compression_threshold", DEFAULT_COMPRESSION_THRESHOLD) + else: + threshold = None + + if threshold is not None and (threshold <= 0 or threshold > 1): + raise ValueError( + f"compression_threshold must be between 0 (exclusive) and 1 (inclusive), got {threshold}" + ) + self.removed_message_count = 0 + self._compression_threshold = threshold + self._context_window_limit_warned = False + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register hooks for agent lifecycle events. + + Always registers a ``BeforeModelCallEvent`` hook for proactive compression. + When ``proactive_compression`` is not configured, the handler is a no-op (early return). + + Derived classes that override this method must call the base implementation to ensure proper hook + registration chain. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + # Always subscribe — the threshold check happens inside the handler + registry.add_callback(BeforeModelCallEvent, self._on_before_model_call_threshold) + + def _on_before_model_call_threshold(self, event: BeforeModelCallEvent) -> None: + """Handle BeforeModelCallEvent for proactive compression. + + When proactive compression is not configured, this is a no-op. + When configured, checks projected input tokens against the context window limit + and calls reduce_context() without error (best-effort) when threshold is exceeded. - def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: + Args: + event: The before model call event. + """ + # Early return if proactive compression is not enabled + if self._compression_threshold is None: + return + + context_window_limit = event.agent.model.context_window_limit + if context_window_limit is None: + context_window_limit = DEFAULT_CONTEXT_WINDOW_LIMIT + if not self._context_window_limit_warned: + self._context_window_limit_warned = True + logger.warning( + "context_window_limit=<%s> | context_window_limit not set on model, using default." + " Set context_window_limit in your model config for accurate proactive compression", + DEFAULT_CONTEXT_WINDOW_LIMIT, + ) + + if event.projected_input_tokens is None: + logger.debug("projected_input_tokens= | skipping proactive compression") + return + + ratio = event.projected_input_tokens / context_window_limit + if ratio >= self._compression_threshold: + logger.debug( + "projected_tokens=<%s>, limit=<%s>, ratio=<%.2f>, compression_threshold=<%s>" + " | compression threshold exceeded, reducing context", + event.projected_input_tokens, + context_window_limit, + ratio, + self._compression_threshold, + ) + # Proactive compression is best-effort: swallow errors so the model call can still proceed. + try: + self.reduce_context(agent=event.agent) + except Exception: + logger.debug("proactive compression failed, will proceed with model call", exc_info=True) + + def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restore the Conversation Manager's state from a session. Args: @@ -66,23 +194,25 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: pass @abstractmethod - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: - """Called when the model's context window is exceeded. - - This method should implement the specific strategy for reducing the window size when a context overflow occurs. - It is typically called after a ContextWindowOverflowException is caught. + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: + """Reduce the conversation history. - Implementations might use strategies such as: + Called in two scenarios: + 1. **Reactive** (e is set): A context window overflow occurred. The implementation + MUST remove enough history for the next model call to succeed, or re-raise the error. + 2. **Proactive** (e is None): The compression threshold was exceeded. This is best-effort — + returning without reduction or raising is acceptable; the model call proceeds regardless. - - Removing the N oldest messages - - Summarizing older context - - Applying importance-based filtering - - Maintaining critical conversation markers + Implementations should modify ``agent.messages`` in-place. Args: agent: The agent whose conversation history will be reduced. This list is modified in-place. e: The exception that triggered the context reduction, if any. + When set, this is a reactive overflow recovery call — the implementation MUST + reduce enough history for the next model call to succeed. + When None, this is a proactive compression call — best-effort reduction to avoid + hitting the context window limit. **kwargs: Additional keyword arguments for future extensibility. """ pass diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 5ff6874e5..4077cb08b 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -1,11 +1,10 @@ """Null implementation of conversation management.""" -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...agent.agent import Agent -from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -28,8 +27,11 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """ pass - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: - """Does not reduce context and raises an exception. + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: + """Does not reduce context. + + When called reactively (e is not None), re-raises the overflow exception since this + manager cannot reduce context. When called proactively (e is None), returns silently. Args: agent: The agent whose conversation history will remain unmodified. @@ -37,10 +39,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs **kwargs: Additional keyword arguments for future extensibility. Raises: - e: If provided. - ContextWindowOverflowException: If e is None. + e: If provided (reactive overflow). """ if e: raise e - else: - raise ContextWindowOverflowException("Context window overflowed!") diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index e082abe8e..1ad8edc24 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,36 +1,144 @@ """Sliding window conversation history management.""" import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...agent.agent import Agent -from ...types.content import Messages +from ...hooks import BeforeModelCallEvent, HookRegistry +from ...types.content import ContentBlock, Messages from ...types.exceptions import ContextWindowOverflowException -from .conversation_manager import ConversationManager +from ...types.tools import ToolResultContent +from .conversation_manager import ConversationManager, ProactiveCompressionConfig logger = logging.getLogger(__name__) +_PRESERVE_CHARS = 200 + class SlidingWindowConversationManager(ConversationManager): """Implements a sliding window strategy for managing conversation history. This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids invalid window states. + + When truncation is enabled (the default), large tool results are partially truncated, preserving the first + and last 200 characters, and image blocks inside tool results are replaced with descriptive text placeholders. + Truncation targets the oldest tool results first so the most relevant recent context is preserved as long + as possible. + + Supports proactive management during agent loop execution via the per_turn parameter. """ - def __init__(self, window_size: int = 40, should_truncate_results: bool = True): + def __init__( + self, + window_size: int = 40, + should_truncate_results: bool = True, + *, + per_turn: bool | int = False, + proactive_compression: bool | ProactiveCompressionConfig | None = None, + ): """Initialize the sliding window conversation manager. Args: window_size: Maximum number of messages to keep in the agent's history. - Defaults to 40 messages. + Use 0 to clear all messages on every reduction. Defaults to 40 messages. should_truncate_results: Truncate tool results when a message is too large for the model's context window + per_turn: Controls when to apply message management during agent execution. + - False (default): Only apply management at the end (default behavior) + - True: Apply management before every model call + - int (e.g., 3): Apply management before every N model calls + + When to use per_turn: If your agent performs many tool operations in loops + (e.g., web browsing with frequent screenshots), enable per_turn to proactively + manage message history and prevent the agent loop from slowing down. Start with + per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed + for performance tuning. + proactive_compression: Enable proactive context compression before the model call. + - ``True``: compress when 70% of the context window is used (default threshold). + - ``{"compression_threshold": float}``: compress at the specified ratio (0, 1]. + - ``False`` or ``None``: disabled, only reactive overflow recovery is used. + + Raises: + ValueError: If window_size is negative, or if per_turn is 0 or a negative integer. """ - super().__init__() + if not isinstance(window_size, bool) and window_size < 0: + raise ValueError(f"window_size must be a non-negative integer, got {window_size}") + if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0: + raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}") + + super().__init__(proactive_compression=proactive_compression) + self.window_size = window_size self.should_truncate_results = should_truncate_results + self.per_turn = per_turn + self._model_call_count = 0 + + def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: + """Register hook callbacks for per-turn conversation management. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + super().register_hooks(registry, **kwargs) + + # Always register the callback - per_turn check happens in the callback + registry.add_callback(BeforeModelCallEvent, self._on_before_model_call) + + def _on_before_model_call(self, event: BeforeModelCallEvent) -> None: + """Handle before model call event for per-turn management. + + This callback is invoked before each model call. It tracks the model call count and applies message management + based on the per_turn configuration. + + Args: + event: The before model call event containing the agent and model execution details. + """ + # Check if per_turn is enabled + if self.per_turn is False: + return + + self._model_call_count += 1 + + # Determine if we should apply management + should_apply = False + if self.per_turn is True: + should_apply = True + elif isinstance(self.per_turn, int) and self.per_turn > 0: + should_apply = self._model_call_count % self.per_turn == 0 + + if should_apply: + logger.debug( + "model_call_count=<%d>, per_turn=<%s> | applying per-turn conversation management", + self._model_call_count, + self.per_turn, + ) + self.apply_management(event.agent) + + def get_state(self) -> dict[str, Any]: + """Get the current state of the conversation manager. + + Returns: + Dictionary containing the manager's state, including model call count for per-turn tracking. + """ + state = super().get_state() + state["model_call_count"] = self._model_call_count + return state + + def restore_from_session(self, state: dict[str, Any]) -> list | None: + """Restore the conversation manager's state from a session. + + Args: + state: Previous state of the conversation manager + + Returns: + Optional list of messages to prepend to the agent's messages. + """ + result = super().restore_from_session(state) + self._model_call_count = state.get("model_call_count", 0) + return result def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """Apply the sliding window to the agent's messages array to maintain a manageable history size. @@ -52,9 +160,15 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: return self.reduce_context(agent) - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Trim the oldest messages to reduce the conversation context size. + When ``e`` is set (reactive overflow recovery), attempts to truncate large tool results + first before falling back to message trimming. + + When ``e`` is None (proactive compression or routine management), only trims messages + without attempting tool result truncation. + The method handles special cases where trimming the messages leads to: - toolResult with no corresponding toolUse - toolUse with no corresponding toolResult @@ -63,48 +177,101 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs agent: The agent whose messages will be reduce. This list is modified in-place. e: The exception that triggered the context reduction, if any. + When set, this is a reactive overflow recovery call. + When None, this is a proactive or routine management call. **kwargs: Additional keyword arguments for future extensibility. Raises: - ContextWindowOverflowException: If the context cannot be reduced further. - Such as when the conversation is already minimal or when tool result messages cannot be properly - converted. + ContextWindowOverflowException: If the context cannot be reduced further and a context overflow + error was provided (e is not None). When called during routine window management or + proactive compression (e is None), logs a warning and returns without modification. """ messages = agent.messages - # Try to truncate the tool result first - last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages) - if last_message_idx_with_tool_results is not None and self.should_truncate_results: - logger.debug( - "message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results - ) - results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results) - if results_truncated: - logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) - return + # window_size=0 means "remove all messages" (matches TypeScript SDK behaviour) + if self.window_size == 0: + self.removed_message_count += len(messages) + messages[:] = [] + return + + # Try to truncate the tool result first (only for reactive overflow, not proactive compression) + if e is not None: + oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages) + if oldest_message_idx_with_tool_results is not None and self.should_truncate_results: + logger.debug( + "message_index=<%s> | found message with tool results at index", + oldest_message_idx_with_tool_results, + ) + results_truncated = self._truncate_tool_results(messages, oldest_message_idx_with_tool_results) + if results_truncated: + logger.debug("message_index=<%s> | tool results truncated", oldest_message_idx_with_tool_results) + return # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size - # Find the next valid trim_index + # Find the next valid trim point that: + # 1. Starts with a user message (required by most model providers) + # 2. Does not start with an orphaned toolResult + # 3. Does not start with a toolUse unless its toolResult immediately follows + # Falls back to an assistant(toolUse) + user(toolResult) boundary if no plain user message exists. + # This is acceptable because providers treat a complete toolUse/toolResult pair as a valid + # conversation continuation, and without this fallback tool-heavy conversations cannot be trimmed. + fallback_trim_index = None + while trim_index < len(messages): + # Prefer starting with a user message + if messages[trim_index]["role"] != "user": + # Track first valid assistant(toolUse) + user(toolResult) pair as fallback + if ( + fallback_trim_index is None + and any("toolUse" in content for content in messages[trim_index]["content"]) + and trim_index + 1 < len(messages) + and messages[trim_index + 1]["role"] == "user" + and any("toolResult" in content for content in messages[trim_index + 1]["content"]) + ): + fallback_trim_index = trim_index + + trim_index += 1 + continue + if ( # Oldest message cannot be a toolResult because it needs a toolUse preceding it any("toolResult" in content for content in messages[trim_index]["content"]) or ( # Oldest message can be a toolUse only if a toolResult immediately follows it. + # Note: toolUse content normally appears only in assistant messages, but this + # check is kept as a defensive safeguard for non-standard message formats. any("toolUse" in content for content in messages[trim_index]["content"]) - and trim_index + 1 < len(messages) - and not any("toolResult" in content for content in messages[trim_index + 1]["content"]) + and not ( + trim_index + 1 < len(messages) + and any("toolResult" in content for content in messages[trim_index + 1]["content"]) + ) ) ): trim_index += 1 else: break else: - # If we didn't find a valid trim_index, then we throw - raise ContextWindowOverflowException("Unable to trim conversation context!") from e + # No plain user message found — use assistant+toolResult fallback if available + if fallback_trim_index is not None: + logger.debug( + "trim_index=<%s> | no plain user message trim point found, " + "falling back to assistant(toolUse) + user(toolResult) boundary", + fallback_trim_index, + ) + trim_index = fallback_trim_index + elif e is not None: + raise ContextWindowOverflowException("Unable to trim conversation context!") from e + else: + logger.warning( + "window_size=<%s>, message_count=<%s> | unable to trim conversation context, " + "no valid trim point found", + self.window_size, + len(messages), + ) + return # trim_index represents the number of messages being removed from the agents messages array self.removed_message_count += trim_index @@ -113,10 +280,14 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs messages[:] = messages[trim_index:] def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: - """Truncate tool results in a message to reduce context size. + """Truncate tool results and replace image blocks in a message to reduce context size. - When a message contains tool results that are too large for the model's context window, this function - replaces the content of those tool results with a simple error message. + For text blocks within tool results, all blocks are partially truncated unless they + have already been truncated. The first and last _PRESERVE_CHARS characters are kept, + and the removed middle is replaced with a notice indicating how many characters were + removed. The tool result status is not changed. + + Image blocks nested inside tool result content are replaced with a short descriptive placeholder. Args: messages: The conversation message history. @@ -128,52 +299,82 @@ def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: if msg_idx >= len(messages) or msg_idx < 0: return False + def _image_placeholder(image_block: Any) -> str: + source: Any = image_block.get("source", {}) + media_type = image_block.get("format", "unknown") + data = source.get("bytes", b"") + return f"[image: {media_type}, {len(data) if data else 0} bytes]" + message = messages[msg_idx] changes_made = False - tool_result_too_large_message = "The tool result was too large!" - for i, content in enumerate(message.get("content", [])): - if isinstance(content, dict) and "toolResult" in content: - tool_result_content_text = next( - (item["text"] for item in content["toolResult"]["content"] if "text" in item), - "", - ) - # make the overwriting logic togglable - if ( - message["content"][i]["toolResult"]["status"] == "error" - and tool_result_content_text == tool_result_too_large_message - ): - logger.info("ToolResult has already been updated, skipping overwrite") - return False - # Update status to error with informative message - message["content"][i]["toolResult"]["status"] = "error" - message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}] - changes_made = True + new_content: list[ContentBlock] = [] + + for content in message.get("content", []): + if "toolResult" in content: + tool_result: Any = content["toolResult"] + tool_result_items = tool_result.get("content", []) + new_items: list[ToolResultContent] = [] + item_changed = False + + for item in tool_result_items: + # Replace image items nested inside toolResult content + if "image" in item: + new_items.append({"text": _image_placeholder(item["image"])}) + item_changed = True + continue + + # Partially truncate text items that have not already been truncated + if "text" in item: + text = item["text"] + truncation_marker = "... [truncated:" + if truncation_marker not in text and len(text) > 2 * _PRESERVE_CHARS: + prefix = text[:_PRESERVE_CHARS] + suffix = text[-_PRESERVE_CHARS:] + removed = len(text) - 2 * _PRESERVE_CHARS + truncated_text = ( + f"{prefix}...\n\n... [truncated: {removed} chars removed] ...\n\n...{suffix}" + ) + new_items.append({"text": truncated_text}) + item_changed = True + continue + + new_items.append(item) + + if item_changed: + updated_tool_result: Any = { + **{k: v for k, v in tool_result.items() if k != "content"}, + "content": new_items, + } + new_content.append({"toolResult": updated_tool_result}) + changes_made = True + else: + new_content.append(content) + continue + + new_content.append(content) + + if changes_made: + message["content"] = new_content return changes_made - def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]: - """Find the index of the last message containing tool results. + def _find_oldest_message_with_tool_results(self, messages: Messages) -> int | None: + """Find the index of the oldest message containing tool results. - This is useful for identifying messages that might need to be truncated to reduce context size. + Iterates from oldest to newest so that truncation targets the least-recent + (and therefore least relevant) tool results first. Args: messages: The conversation message history. Returns: - Index of the last message with tool results, or None if no such message exists. + Index of the oldest message with tool results, or None if no such message exists. """ - # Iterate backwards through all messages (from newest to oldest) - for idx in range(len(messages) - 1, -1, -1): - # Check if this message has any content with toolResult + # Iterate from oldest to newest + for idx in range(len(messages)): current_message = messages[idx] - has_tool_result = False - for content in current_message.get("content", []): if isinstance(content, dict) and "toolResult" in content: - has_tool_result = True - break - - if has_tool_result: - return idx + return idx return None diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 12185c286..2030e1d3b 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -1,16 +1,18 @@ """Summarizing conversation history management with configurable options.""" import logging -from typing import TYPE_CHECKING, Any, List, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from typing_extensions import override +from ..._async import run_async +from ...event_loop.streaming import process_stream from ...tools._tool_helpers import noop_tool from ...tools.registry import ToolRegistry from ...types.content import Message from ...types.exceptions import ContextWindowOverflowException from ...types.tools import AgentTool -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ProactiveCompressionConfig if TYPE_CHECKING: from ..agent import Agent @@ -62,7 +64,9 @@ def __init__( summary_ratio: float = 0.3, preserve_recent_messages: int = 10, summarization_agent: Optional["Agent"] = None, - summarization_system_prompt: Optional[str] = None, + summarization_system_prompt: str | None = None, + *, + proactive_compression: bool | ProactiveCompressionConfig | None = None, ): """Initialize the summarizing conversation manager. @@ -75,8 +79,12 @@ def __init__( If provided, this agent can use tools as part of the summarization process. summarization_system_prompt: Optional system prompt override for summarization. If None, uses the default summarization prompt. + proactive_compression: Enable proactive context compression before the model call. + - ``True``: compress when 70% of the context window is used (default threshold). + - ``{"compression_threshold": float}``: compress at the specified ratio (0, 1]. + - ``False`` or ``None``: disabled, only reactive overflow recovery is used. """ - super().__init__() + super().__init__(proactive_compression=proactive_compression) if summarization_agent is not None and summarization_system_prompt is not None: raise ValueError( "Cannot provide both summarization_agent and summarization_system_prompt. " @@ -87,10 +95,10 @@ def __init__( self.preserve_recent_messages = preserve_recent_messages self.summarization_agent = summarization_agent self.summarization_system_prompt = summarization_system_prompt - self._summary_message: Optional[Message] = None + self._summary_message: Message | None = None @override - def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: + def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restores the Summarizing Conversation manager from its previous state in a session. Args: @@ -121,64 +129,94 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: # No proactive management - summarization only happens on context overflow pass - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Reduce context using summarization. + When ``e`` is set (reactive overflow recovery), summarization failure is re-raised — + the agent loop must not proceed with an overflow. + + When ``e`` is None (proactive compression), summarization failure is logged and + returns silently — the model call proceeds regardless. + Args: agent: The agent whose conversation history will be reduced. The agent's messages list is modified in-place. e: The exception that triggered the context reduction, if any. + When set, this is a reactive overflow recovery call. + When None, this is a proactive compression call (best-effort). **kwargs: Additional keyword arguments for future extensibility. Raises: - ContextWindowOverflowException: If the context cannot be summarized. + Exception: If summarization fails during reactive overflow recovery (e is set). """ try: - # Calculate how many messages to summarize - messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + self._summarize_oldest(agent) + except Exception as summarization_error: + if e is not None: + # Reactive: rethrow so the ContextWindowOverflowException propagates + logger.error("Summarization failed: %s", summarization_error) + raise summarization_error from e + # Proactive: best-effort, swallow errors so the model call can still proceed. + logger.warning("Proactive summarization failed, continuing: %s", summarization_error) - # Ensure we don't summarize recent messages - messages_to_summarize_count = min( - messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages - ) + def _summarize_oldest(self, agent: "Agent") -> None: + """Summarize the oldest messages and replace them with a summary. + + Args: + agent: The agent instance. - if messages_to_summarize_count <= 0: - raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + Raises: + ContextWindowOverflowException: If there are insufficient messages for summarization. + """ + # Calculate how many messages to summarize + messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) - # Adjust split point to avoid breaking ToolUse/ToolResult pairs - messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( - agent.messages, messages_to_summarize_count - ) + # Ensure we don't summarize recent messages + messages_to_summarize_count = min( + messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages + ) - if messages_to_summarize_count <= 0: - raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") - # Extract messages to summarize - messages_to_summarize = agent.messages[:messages_to_summarize_count] - remaining_messages = agent.messages[messages_to_summarize_count:] + # Adjust split point to avoid breaking ToolUse/ToolResult pairs + messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( + agent.messages, messages_to_summarize_count + ) - # Keep track of the number of messages that have been summarized thus far. - self.removed_message_count += len(messages_to_summarize) - # If there is a summary message, don't count it in the removed_message_count. - if self._summary_message: - self.removed_message_count -= 1 + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") - # Generate summary - self._summary_message = self._generate_summary(messages_to_summarize, agent) + # Extract messages to summarize + messages_to_summarize = agent.messages[:messages_to_summarize_count] + remaining_messages = agent.messages[messages_to_summarize_count:] - # Replace the summarized messages with the summary - agent.messages[:] = [self._summary_message] + remaining_messages + # Keep track of the number of messages that have been summarized thus far. + self.removed_message_count += len(messages_to_summarize) + # If there is a summary message, don't count it in the removed_message_count. + if self._summary_message: + self.removed_message_count -= 1 - except Exception as summarization_error: - logger.error("Summarization failed: %s", summarization_error) - raise summarization_error from e + # Generate summary + self._summary_message = self._generate_summary(messages_to_summarize, agent) - def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: + # Replace the summarized messages with the summary + agent.messages[:] = [self._summary_message] + remaining_messages + + def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. + When a dedicated summarization_agent was provided at init time, it is invoked as before + (full agent pipeline, tool execution, etc.). + + In the default case (no summarization_agent), the parent agent's *model* is called + directly via ``model.stream()``. This avoids re-entering the agent pipeline which + would deadlock on ``_invocation_lock`` and corrupt metrics / traces / interrupt state. + Args: messages: The messages to summarize. - agent: The agent instance to use for summarization. + agent: The agent instance whose model will be used for summarization when no + dedicated summarization_agent was configured. Returns: A message containing the conversation summary. @@ -186,25 +224,37 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: Raises: Exception: If summary generation fails. """ - # Choose which agent to use for summarization - summarization_agent = self.summarization_agent if self.summarization_agent is not None else agent + if self.summarization_agent is not None: + return self._generate_summary_with_agent(messages) + + return self._generate_summary_with_model(messages, agent) + + # ------------------------------------------------------------------ + # Path 1 – dedicated summarization agent (backward-compatible) + # ------------------------------------------------------------------ + + def _generate_summary_with_agent(self, messages: list[Message]) -> Message: + """Generate a summary using the dedicated summarization agent. + + Args: + messages: The messages to summarize. + + Returns: + A message containing the conversation summary. + """ + summarization_agent = self.summarization_agent + assert summarization_agent is not None # guaranteed by caller - # Save original system prompt, messages, and tool registry to restore later original_system_prompt = summarization_agent.system_prompt original_messages = summarization_agent.messages.copy() original_tool_registry = summarization_agent.tool_registry + original_structured_output_model = getattr(summarization_agent, "_default_structured_output_model", None) try: - # Only override system prompt if no agent was provided during initialization - if self.summarization_agent is None: - # Use custom system prompt if provided, otherwise use default - system_prompt = ( - self.summarization_system_prompt - if self.summarization_system_prompt is not None - else DEFAULT_SUMMARIZATION_PROMPT - ) - # Temporarily set the system prompt for summarization - summarization_agent.system_prompt = system_prompt + # Disable structured output for summarization. Summaries are plain text and + # structured output adds toolUse blocks that are invalid in user messages. + if hasattr(summarization_agent, "_default_structured_output_model"): + summarization_agent._default_structured_output_model = None # Add no-op tool if agent has no tools to satisfy tool spec requirement if not summarization_agent.tool_names: @@ -214,17 +264,64 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: summarization_agent.messages = messages - # Use the agent to generate summary with rich content (can use tools if needed) result = summarization_agent("Please summarize this conversation.") return cast(Message, {**result.message, "role": "user"}) finally: - # Restore original agent state summarization_agent.system_prompt = original_system_prompt summarization_agent.messages = original_messages summarization_agent.tool_registry = original_tool_registry + if hasattr(summarization_agent, "_default_structured_output_model"): + summarization_agent._default_structured_output_model = original_structured_output_model + + # ------------------------------------------------------------------ + # Path 2 – default case: call model.stream() directly + # ------------------------------------------------------------------ + + def _generate_summary_with_model(self, messages: list[Message], agent: "Agent") -> Message: + """Generate a summary by calling the agent's model directly. + + This bypasses the full agent pipeline (lock, metrics, traces, tool loop) and + simply asks the underlying model to summarize the conversation. + + Args: + messages: The messages to summarize. + agent: The parent agent whose model is used. + + Returns: + A message containing the conversation summary. + """ + system_prompt = ( + self.summarization_system_prompt + if self.summarization_system_prompt is not None + else DEFAULT_SUMMARIZATION_PROMPT + ) + + # Build the message list: conversation history + summarization request + summarization_messages = list(messages) + [ + {"role": "user", "content": [{"text": "Please summarize this conversation."}]} + ] + + async def _call_model() -> Message: + chunks = agent.model.stream( + summarization_messages, + tool_specs=None, + system_prompt=system_prompt, + ) + + result_message: Message | None = None + async for event in process_stream(chunks): + if "stop" in event: + _, result_message, _, _ = event["stop"] + + if result_message is None: + raise RuntimeError("Failed to generate summary: no response from model") + return result_message + + message = run_async(_call_model) + return cast(Message, {**message, "role": "user"}) - def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: + def _adjust_split_point_for_tool_pairs(self, messages: list[Message], split_point: int) -> int: """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. Uses the same logic as SlidingWindowConversationManager for consistency. diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py index ab6fb4abe..dc073ba07 100644 --- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -68,4 +68,7 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: } ) - return {"content": valid_content, "role": message["role"]} + recovered: Message = {"content": valid_content, "role": message["role"]} + if "metadata" in message: + recovered["metadata"] = message["metadata"] + return recovered diff --git a/src/strands/event_loop/_retry.py b/src/strands/event_loop/_retry.py new file mode 100644 index 000000000..04a6101b8 --- /dev/null +++ b/src/strands/event_loop/_retry.py @@ -0,0 +1,157 @@ +"""Retry strategy implementations for handling model throttling and other retry scenarios. + +This module provides hook-based retry strategies that can be configured on the Agent +to control retry behavior for model invocations. Retry strategies implement the +HookProvider protocol and register callbacks for AfterModelCallEvent to determine +when and how to retry failed model calls. +""" + +import asyncio +import logging +from typing import Any + +from ..hooks.events import AfterInvocationEvent, AfterModelCallEvent +from ..hooks.registry import HookProvider, HookRegistry +from ..types._events import EventLoopThrottleEvent, TypedEvent +from ..types.exceptions import ModelThrottledException + +logger = logging.getLogger(__name__) + + +class ModelRetryStrategy(HookProvider): + """Default retry strategy for model throttling with exponential backoff. + + Retries model calls on ModelThrottledException using exponential backoff. + Delay doubles after each attempt: initial_delay, initial_delay*2, initial_delay*4, + etc., capped at max_delay. State resets after successful calls. + + With defaults (initial_delay=4, max_delay=240, max_attempts=6), delays are: + 4s → 8s → 16s → 32s → 64s (5 retries before giving up on the 6th attempt). + + Args: + max_attempts: Total model attempts before re-raising the exception. + initial_delay: Base delay in seconds; used for first two retries, then doubles. + max_delay: Upper bound in seconds for the exponential backoff. + """ + + def __init__( + self, + *, + max_attempts: int = 6, + initial_delay: int = 4, + max_delay: int = 240, + ): + """Initialize the retry strategy. + + Args: + max_attempts: Total model attempts before re-raising the exception. Defaults to 6. + initial_delay: Base delay in seconds; used for first two retries, then doubles. + Defaults to 4. + max_delay: Upper bound in seconds for the exponential backoff. Defaults to 240. + """ + self._max_attempts = max_attempts + self._initial_delay = initial_delay + self._max_delay = max_delay + self._current_attempt = 0 + self._backwards_compatible_event_to_yield: TypedEvent | None = None + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register callbacks for AfterModelCallEvent and AfterInvocationEvent. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + registry.add_callback(AfterModelCallEvent, self._handle_after_model_call) + registry.add_callback(AfterInvocationEvent, self._handle_after_invocation) + + def _calculate_delay(self, attempt: int) -> int: + """Calculate retry delay using exponential backoff. + + Args: + attempt: The attempt number (0-indexed) to calculate delay for. + + Returns: + Delay in seconds for the given attempt. + """ + delay: int = self._initial_delay * (2**attempt) + return min(delay, self._max_delay) + + def _reset_retry_state(self) -> None: + """Reset retry state to initial values.""" + self._current_attempt = 0 + + async def _handle_after_invocation(self, event: AfterInvocationEvent) -> None: + """Reset retry state after invocation completes. + + Args: + event: The AfterInvocationEvent signaling invocation completion. + """ + self._reset_retry_state() + + async def _handle_after_model_call(self, event: AfterModelCallEvent) -> None: + """Handle model call completion and determine if retry is needed. + + This callback is invoked after each model call. If the call failed with + a ModelThrottledException and we haven't exceeded max_attempts, it sets + event.retry to True and sleeps for the current delay before returning. + + On successful calls, it resets the retry state to prepare for future calls. + + Args: + event: The AfterModelCallEvent containing call results or exception. + """ + delay = self._calculate_delay(self._current_attempt) + + self._backwards_compatible_event_to_yield = None + + # If already retrying, skip processing (another hook may have triggered retry) + if event.retry: + return + + # If model call succeeded, reset retry state + if event.stop_response is not None: + logger.debug( + "stop_reason=<%s> | model call succeeded, resetting retry state", + event.stop_response.stop_reason, + ) + self._reset_retry_state() + return + + # Check if we have an exception and reset state if no exception + if event.exception is None: + self._reset_retry_state() + return + + # Only retry on ModelThrottledException + if not isinstance(event.exception, ModelThrottledException): + return + + # Increment attempt counter first + self._current_attempt += 1 + + # Check if we've exceeded max attempts + if self._current_attempt >= self._max_attempts: + logger.debug( + "current_attempt=<%d>, max_attempts=<%d> | max retry attempts reached, not retrying", + self._current_attempt, + self._max_attempts, + ) + return + + self._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=delay) + + # Retry the model call + logger.debug( + "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " + "| throttling exception encountered | delaying before next retry", + delay, + self._max_attempts, + self._current_attempt, + ) + + # Sleep for current delay + await asyncio.sleep(delay) + + # Set retry flag and track that this strategy triggered it + event.retry = True diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index f25057e4d..128ef9ca3 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -8,10 +8,10 @@ 4. Manage recursive execution cycles """ -import asyncio import logging import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from opentelemetry import trace as trace_api @@ -22,7 +22,6 @@ from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..types._events import ( EventLoopStopEvent, - EventLoopThrottleEvent, ForceStopEvent, ModelMessageEvent, ModelStopReason, @@ -38,12 +37,12 @@ ContextWindowOverflowException, EventLoopException, MaxTokensReachedException, - ModelThrottledException, StructuredOutputException, ) from ..types.streaming import StopReason from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached +from ._retry import ModelRetryStrategy from .streaming import stream_messages if TYPE_CHECKING: @@ -76,6 +75,48 @@ def _has_tool_use_in_latest_message(messages: "Messages") -> bool: return False +async def _estimate_input_tokens(agent: "Agent") -> int: + """Estimate the input token count for the next model call. + + Reads inputTokens + outputTokens from the last assistant message's metadata as a known + baseline, then estimates only new messages added after it. Falls back to full estimation + when no metadata is available (cold start or first call). On cold start, tool specs are + resolved lazily so that the caller does not need to resolve them before BeforeModelCallEvent. + + Args: + agent: The agent instance with messages and model. + + Returns: + Estimated input token count. + """ + messages = agent.messages + + # Find the last assistant message with usage metadata + last_assistant_idx = -1 + for i, msg in reversed(list(enumerate(messages))): + if msg.get("role") == "assistant" and msg.get("metadata", {}).get("usage"): + last_assistant_idx = i + break + + if last_assistant_idx >= 0: + usage = messages[last_assistant_idx]["metadata"]["usage"] + known_baseline = usage["inputTokens"] + usage["outputTokens"] + new_messages = messages[last_assistant_idx + 1 :] + if not new_messages: + return known_baseline + # System prompt and tool spec tokens are already included in the baseline + return known_baseline + await agent.model.count_tokens(new_messages) + + # Cold start: resolve tool specs lazily for estimation only + tool_specs = agent.tool_registry.get_all_tool_specs() + return await agent.model.count_tokens( + messages, + tool_specs=tool_specs, + system_prompt=agent.system_prompt, + system_prompt_content=agent._system_prompt_content, + ) + + async def event_loop_cycle( agent: "Agent", invocation_state: dict[str, Any], @@ -140,108 +181,105 @@ async def event_loop_cycle( ) invocation_state["event_loop_cycle_span"] = cycle_span - # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. - if agent._interrupt_state.activated: - stop_reason: StopReason = "tool_use" - message = agent._interrupt_state.context["tool_use_message"] - # Skip model invocation if the latest message contains ToolUse - elif _has_tool_use_in_latest_message(agent.messages): - stop_reason = "tool_use" - message = agent.messages[-1] - else: - model_events = _handle_model_execution( - agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context - ) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + with trace_api.use_span(cycle_span, end_on_exit=False): + try: + # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. + if agent._interrupt_state.activated: + stop_reason: StopReason = "tool_use" + message = agent._interrupt_state.context["tool_use_message"] + # Skip model invocation if the latest message contains ToolUse + elif _has_tool_use_in_latest_message(agent.messages): + stop_reason = "tool_use" + message = agent.messages[-1] + else: + model_events = _handle_model_execution( + agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context + ) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event - stop_reason, message, *_ = model_event["stop"] - yield ModelMessageEvent(message=message) + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) + except Exception as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise + + try: + if stop_reason == "max_tokens": + """ + Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with an + MaxTokensReachedException to maintain consistency with other failure types. + """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + ) - try: - if stop_reason == "max_tokens": - """ - Handle max_tokens limit reached by the model. - - When the model reaches its maximum token limit, this represents a potentially unrecoverable - state where the model's response was truncated. By default, Strands fails hard with an - MaxTokensReachedException to maintain consistency with other failure types. - """ - raise MaxTokensReachedException( - message=( - "Agent has reached an unrecoverable state due to max_tokens limit. " - "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + if stop_reason == "tool_use": + # Handle tool execution + tool_events = _handle_tool_execution( + stop_reason, + message, + agent=agent, + cycle_trace=cycle_trace, + cycle_span=cycle_span, + cycle_start_time=cycle_start_time, + invocation_state=invocation_state, + tracer=tracer, + structured_output_context=structured_output_context, ) - ) + async for tool_event in tool_events: + yield tool_event - if stop_reason == "tool_use": - # Handle tool execution - tool_events = _handle_tool_execution( - stop_reason, - message, - agent=agent, - cycle_trace=cycle_trace, - cycle_span=cycle_span, - cycle_start_time=cycle_start_time, - invocation_state=invocation_state, - tracer=tracer, - structured_output_context=structured_output_context, - ) - async for tool_event in tool_events: - yield tool_event + return - return + # End the cycle and return results + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) - # End the cycle and return results - agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) - if cycle_span: - tracer.end_event_loop_cycle_span( - span=cycle_span, - message=message, - ) - except EventLoopException as e: - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) + # Force structured output tool call if LLM didn't use it automatically + if structured_output_context.is_enabled and stop_reason == "end_turn": + if structured_output_context.force_attempted: + raise StructuredOutputException( + "The model failed to invoke the structured output tool even after it was forced." + ) + structured_output_context.set_forced_mode() + logger.debug("Forcing structured output tool") + await agent._append_messages( + {"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]} + ) - # Don't yield or log the exception - we already did it when we - # raised the exception and we don't need that duplication. - raise - except (ContextWindowOverflowException, MaxTokensReachedException) as e: - # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException - if cycle_span: + tracer.end_event_loop_cycle_span(cycle_span, message) + events = recurse_event_loop( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) + async for typed_event in events: + yield typed_event + return + + tracer.end_event_loop_cycle_span(cycle_span, message) + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + except ( + StructuredOutputException, + EventLoopException, + ContextWindowOverflowException, + MaxTokensReachedException, + ) as e: + # These exceptions should bubble up directly rather than get wrapped in an EventLoopException tracer.end_span_with_error(cycle_span, str(e), e) - raise e - except Exception as e: - if cycle_span: + raise + except Exception as e: tracer.end_span_with_error(cycle_span, str(e), e) - - # Handle any other exceptions - yield ForceStopEvent(reason=e) - logger.exception("cycle failed") - raise EventLoopException(e, invocation_state["request_state"]) from e - - # Force structured output tool call if LLM didn't use it automatically - if structured_output_context.is_enabled and stop_reason == "end_turn": - if structured_output_context.force_attempted: - raise StructuredOutputException( - "The model failed to invoke the structured output tool even after it was forced." - ) - structured_output_context.set_forced_mode() - logger.debug("Forcing structured output tool") - await agent._append_messages( - {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} - ) - - events = recurse_event_loop( - agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context - ) - async for typed_event in events: - yield typed_event - return - - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + # Handle any other exceptions + yield ForceStopEvent(reason=e) + logger.exception("cycle failed") + raise EventLoopException(e, invocation_state["request_state"]) from e async def recurse_event_loop( @@ -315,29 +353,41 @@ async def _handle_model_execution( stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) cycle_trace.add_child(stream_trace) - # Retry loop for handling throttling exceptions - current_delay = INITIAL_DELAY - for attempt in range(MAX_ATTEMPTS): + # Retry loop - actual retry logic is handled by retry_strategy hook + # Hooks control when to stop retrying via the event.retry flag + while True: model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None model_invoke_span = tracer.start_model_invoke_span( messages=agent.messages, parent_span=cycle_span, model_id=model_id, custom_trace_attributes=agent.trace_attributes, + system_prompt=agent.system_prompt, + system_prompt_content=agent._system_prompt_content, ) - with trace_api.use_span(model_invoke_span): - await agent.hooks.invoke_callbacks_async( - BeforeModelCallEvent( - agent=agent, + with trace_api.use_span(model_invoke_span, end_on_exit=False): + try: + # Estimate input tokens for the upcoming model call (non-fatal) + projected_input_tokens: int | None = None + try: + projected_input_tokens = await _estimate_input_tokens(agent) + except Exception as e: + logger.debug("error=<%s> | token estimation failed, proceeding without estimate", e) + + await agent.hooks.invoke_callbacks_async( + BeforeModelCallEvent( + agent=agent, + invocation_state=invocation_state, + projected_input_tokens=projected_input_tokens, + ) ) - ) - if structured_output_context.forced_mode: - tool_spec = structured_output_context.get_tool_spec() - tool_specs = [tool_spec] if tool_spec else [] - else: - tool_specs = agent.tool_registry.get_all_tool_specs() - try: + if structured_output_context.forced_mode: + tool_spec = structured_output_context.get_tool_spec() + tool_specs = [tool_spec] if tool_spec else [] + else: + tool_specs = agent.tool_registry.get_all_tool_specs() + async for event in stream_messages( agent.model, agent.system_prompt, @@ -345,59 +395,78 @@ async def _handle_model_execution( tool_specs, system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, + invocation_state=invocation_state, + model_state=agent._model_state, + cancel_signal=agent._cancel_signal, ): yield event stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) - await agent.hooks.invoke_callbacks_async( - AfterModelCallEvent( - agent=agent, - stop_response=AfterModelCallEvent.ModelStopResponse( - stop_reason=stop_reason, - message=message, - ), - ) + # Attach metadata to the assistant message immediately so it's + # available to all downstream consumers (hooks, events, state). + message["metadata"] = { + "usage": usage, + "metrics": metrics, + } + + after_model_call_event = AfterModelCallEvent( + agent=agent, + invocation_state=invocation_state, + stop_response=AfterModelCallEvent.ModelStopResponse( + stop_reason=stop_reason, + message=message, + ), ) + await agent.hooks.invoke_callbacks_async(after_model_call_event) + + # Check if hooks want to retry the model call + if after_model_call_event.retry: + logger.debug( + "stop_reason=<%s>, retry_requested= | hook requested model retry", + stop_reason, + ) + tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) + continue # Retry the model call + if stop_reason == "max_tokens": message = recover_message_on_max_tokens_reached(message) - if model_invoke_span: - tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) + tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) break # Success! Break out of retry loop except Exception as e: - if model_invoke_span: - tracer.end_span_with_error(model_invoke_span, str(e), e) - - await agent.hooks.invoke_callbacks_async( - AfterModelCallEvent( - agent=agent, - exception=e, - ) + tracer.end_span_with_error(model_invoke_span, str(e), e) + after_model_call_event = AfterModelCallEvent( + agent=agent, + invocation_state=invocation_state, + exception=e, ) + await agent.hooks.invoke_callbacks_async(after_model_call_event) - if isinstance(e, ModelThrottledException): - if attempt + 1 == MAX_ATTEMPTS: - yield ForceStopEvent(reason=e) - raise e + # Emit backwards-compatible events if retry strategy supports it + # (prior to making the retry strategy configurable, this is what we emitted) + if ( + isinstance(agent._retry_strategy, ModelRetryStrategy) + and agent._retry_strategy._backwards_compatible_event_to_yield + ): + yield agent._retry_strategy._backwards_compatible_event_to_yield + + # Check if hooks want to retry the model call + if after_model_call_event.retry: logger.debug( - "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " - "| throttling exception encountered " - "| delaying before next retry", - current_delay, - MAX_ATTEMPTS, - attempt + 1, + "exception=<%s>, retry_requested= | hook requested model retry", + type(e).__name__, ) - await asyncio.sleep(current_delay) - current_delay = min(current_delay * 2, MAX_DELAY) - yield EventLoopThrottleEvent(delay=current_delay) - else: - raise e + continue # Retry the model call + + # No retry requested, raise the exception + yield ForceStopEvent(reason=e) + raise e try: # Add message in trace and mark the end of the stream messages trace @@ -413,9 +482,6 @@ async def _handle_model_execution( agent.event_loop_metrics.update_metrics(metrics) except Exception as e: - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) - yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e @@ -468,6 +534,47 @@ async def _handle_tool_execution( tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids] interrupts = [] + + # Check for cancellation before tool execution + # Add tool_result for each tool_use to maintain valid conversation state + if agent._cancel_signal.is_set(): + logger.debug("tool_count=<%d> | cancellation detected before tool execution", len(tool_uses)) + + # Create cancellation tool_result for each tool_use to avoid invalid message state + # (tool_use without tool_result would be rejected on next invocation) + for tool_use in tool_uses: + cancel_result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": "Tool execution cancelled"}], + } + tool_results.append(cancel_result) + + # Add tool results message to conversation if any tools were cancelled + cancelled_tool_result_message: Message | None = None + if tool_results: + _cancelled_msg: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + cancelled_tool_result_message = _cancelled_msg + agent.messages.append(_cancelled_msg) + await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=_cancelled_msg)) + yield ToolResultMessageEvent(message=_cancelled_msg) + + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + yield EventLoopStopEvent( + "cancelled", + message, + agent.event_loop_metrics, + invocation_state["request_state"], + ) + if cycle_span: + tracer.end_event_loop_cycle_span( + span=cycle_span, message=message, tool_result_message=cancelled_tool_result_message + ) + return + tool_events = agent.tool_executor._execute( agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) @@ -499,6 +606,7 @@ async def _handle_tool_execution( interrupts, structured_output=structured_output_result, ) + # End the cycle span before yielding the recursive cycle. if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message) @@ -516,11 +624,13 @@ async def _handle_tool_execution( yield ToolResultMessageEvent(message=tool_result_message) + # End the cycle span before yielding the recursive cycle. if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + if invocation_state["request_state"].get("stop_event_loop", False) or structured_output_context.stop_loop: - agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) yield EventLoopStopEvent( stop_reason, message, diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 43836fe34..76eda48bf 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,9 +2,11 @@ import json import logging +import threading import time import warnings -from typing import Any, AsyncGenerator, AsyncIterable, Optional +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any from ..models.model import Model from ..tools import InvalidToolUseNameException @@ -185,6 +187,8 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: current_tool_use["toolUseId"] = tool_use_data["toolUseId"] current_tool_use["name"] = tool_use_data["name"] current_tool_use["input"] = "" + if "reasoningSignature" in tool_use_data: + current_tool_use["reasoningSignature"] = tool_use_data["reasoningSignature"] return current_tool_use @@ -285,16 +289,19 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: name=tool_use_name, input=current_tool_use["input"], ) + if "reasoningSignature" in current_tool_use: + tool_use["reasoningSignature"] = current_tool_use["reasoningSignature"] content.append({"toolUse": tool_use}) state["current_tool_use"] = {} elif text: - content.append({"text": text}) - state["text"] = "" if citations_content: - citations_block: CitationsContentBlock = {"citations": citations_content} + citations_block: CitationsContentBlock = {"citations": citations_content, "content": [{"text": text}]} content.append({"citationsContent": citations_block}) state["citationsContent"] = [] + else: + content.append({"text": text}) + state["text"] = "" elif reasoning_text: content_block: ContentBlock = { @@ -317,16 +324,31 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: return state -def handle_message_stop(event: MessageStopEvent) -> StopReason: +def handle_message_stop(event: MessageStopEvent, content: list[dict[str, Any]]) -> StopReason: """Handles the end of a message by returning the stop reason. + Some models return "end_turn" even when tool calls are present, which prevents the event loop from processing + those tool calls. This function overrides to "tool_use" so tool execution proceeds correctly. + Args: event: Stop event. + content: The message content blocks accumulated during streaming. Returns: The reason for stopping the stream. """ - return event["stopReason"] + stop_reason = event["stopReason"] + + if stop_reason == "end_turn" and any("toolUse" in item for item in content): + logger.warning( + "original_stop_reason=<%s>, new_stop_reason=<%s> | " + "overriding stop reason due to toolUse blocks in response", + "end_turn", + "tool_use", + ) + stop_reason = "tool_use" + + return stop_reason def handle_redact_content(event: RedactContentEvent, state: dict[str, Any]) -> None: @@ -362,13 +384,16 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non async def process_stream( - chunks: AsyncIterable[StreamEvent], start_time: float | None = None + chunks: AsyncIterable[StreamEvent], + start_time: float | None = None, + cancel_signal: threading.Event | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. start_time: Time when the model request is initiated + cancel_signal: Optional threading.Event to check for cancellation during streaming. Yields: The reason for stopping, the constructed message, and the usage metrics. @@ -389,6 +414,19 @@ async def process_stream( metrics: Metrics = Metrics(latencyMs=0, timeToFirstByteMs=0) async for chunk in chunks: + # Check for cancellation during stream processing + if cancel_signal and cancel_signal.is_set(): + logger.debug("cancellation detected during stream processing") + # Return cancelled stop reason with cancellation message + # The incomplete message in state["message"] is discarded and never added to agent.messages + yield ModelStopReason( + stop_reason="cancelled", + message={"role": "assistant", "content": [{"text": "Cancelled by user"}]}, + usage=usage, + metrics=metrics, + ) + return + # Track first byte time when we get first content if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk): first_byte_time = time.time() @@ -404,7 +442,7 @@ async def process_stream( elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: - stop_reason = handle_message_stop(chunk["messageStop"]) + stop_reason = handle_message_stop(chunk["messageStop"], state["message"].get("content", [])) elif "metadata" in chunk: time_to_first_byte_ms = ( int(1000 * (first_byte_time - start_time)) if (start_time and first_byte_time) else None @@ -418,12 +456,15 @@ async def process_stream( async def stream_messages( model: Model, - system_prompt: Optional[str], + system_prompt: str | None, messages: Messages, tool_specs: list[ToolSpec], *, - tool_choice: Optional[Any] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_choice: Any | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, + model_state: dict[str, Any] | None = None, + cancel_signal: threading.Event | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -436,6 +477,9 @@ async def stream_messages( tool_choice: Optional tool choice constraint for forcing specific tool usage. system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. + model_state: Runtime state for model providers (e.g., server-side response ids). + cancel_signal: Optional threading.Event to check for cancellation during streaming. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -444,6 +488,9 @@ async def stream_messages( logger.debug("model=<%s> | streaming messages", model) messages = _normalize_messages(messages) + # Whitelist only role and content before sending to the model provider. + # This ensures metadata (and any future non-model fields) never leak to providers. + messages = [Message(role=msg["role"], content=msg["content"]) for msg in messages] start_time = time.time() chunks = model.stream( @@ -452,7 +499,9 @@ async def stream_messages( system_prompt, tool_choice=tool_choice, system_prompt_content=system_prompt_content, + invocation_state=invocation_state, + model_state=model_state, ) - async for event in process_stream(chunks, start_time): + async for event in process_stream(chunks, start_time, cancel_signal): yield event diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index 3c1d0ee46..cbd9a713e 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -3,7 +3,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ -from . import steering, tools +from . import checkpoint, steering, tools from .agent_config import config_to_agent -__all__ = ["config_to_agent", "tools", "steering"] +__all__ = ["checkpoint", "config_to_agent", "tools", "steering"] diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index f65afb57d..e6fb94118 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -98,7 +98,7 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A if not config_path.exists(): raise FileNotFoundError(f"Configuration file not found: {file_path}") - with open(config_path, "r") as f: + with open(config_path) as f: config_dict = json.load(f) elif isinstance(config, dict): config_dict = config.copy() diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 57986062e..8ce3ebc68 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -1,9 +1,6 @@ """Bidirectional streaming package.""" -import sys - -if sys.version_info < (3, 12): - raise ImportError("bidi only supported for >= Python 3.12") +from typing import Any # Main components - Primary user interface # Re-export standard agent events for tool handling @@ -14,14 +11,10 @@ ) from .agent.agent import BidiAgent -# IO channels - Hardware abstraction -from .io.audio import BidiAudioIO - # Model interface (for custom implementations) from .models.model import BidiModel -from .models.nova_sonic import BidiNovaSonicModel -# Built-in tools +# Built-in tools (deprecated - use strands_tools.stop instead) from .tools import stop_conversation # Event types - For type hints and event handling @@ -46,12 +39,6 @@ __all__ = [ # Main interface "BidiAgent", - # IO channels - "BidiAudioIO", - # Model providers - "BidiNovaSonicModel", - # Built-in tools - "stop_conversation", # Input Event types "BidiTextInputEvent", "BidiAudioInputEvent", @@ -75,4 +62,22 @@ "ToolStreamEvent", # Model interface "BidiModel", + # Built-in tools (deprecated) + "stop_conversation", ] + + +def __getattr__(name: str) -> Any: + """Lazy load IO implementations only when accessed. + + This defers the import of optional dependencies until actually needed. + """ + if name == "BidiAudioIO": + from .io.audio import BidiAudioIO + + return BidiAudioIO + if name == "BidiTextIO": + from .io.text import BidiTextIO + + return BidiTextIO + raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})") diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 6cee3264d..47473115c 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -2,9 +2,10 @@ from typing import Awaitable, Callable +from ._task_group import _TaskGroup from ._task_pool import _TaskPool -__all__ = ["_TaskPool"] +__all__ = ["_TaskGroup", "_TaskPool"] async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: @@ -16,14 +17,14 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: funcs: Stop functions to call in sequence. Raises: - ExceptionGroup: If any stop function raises an exception. + RuntimeError: If any stop function raises an exception. """ exceptions = [] for func in funcs: try: await func() except Exception as exception: - exceptions.append(exception) + exceptions.append({"func_name": func.__name__, "exception": repr(exception)}) if exceptions: - raise ExceptionGroup("failed stop sequence", exceptions) + raise RuntimeError(f"exceptions={exceptions} | failed stop sequence") diff --git a/src/strands/experimental/bidi/_async/_task_group.py b/src/strands/experimental/bidi/_async/_task_group.py new file mode 100644 index 000000000..33cf63dca --- /dev/null +++ b/src/strands/experimental/bidi/_async/_task_group.py @@ -0,0 +1,69 @@ +"""Manage a group of async tasks. + +This is intended to mimic the behaviors of asyncio.TaskGroup released in Python 3.11. + +- Docs: https://docs.python.org/3/library/asyncio-task.html#task-groups +""" + +import asyncio +from typing import Any, Coroutine, cast + + +class _TaskGroup: + """Shim of asyncio.TaskGroup for use in Python 3.10. + + Attributes: + _tasks: Set of tasks in group. + """ + + _tasks: set[asyncio.Task] + + def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: + """Create an async task and add to group. + + Returns: + The created task. + """ + task = asyncio.create_task(coro) + self._tasks.add(task) + return task + + async def __aenter__(self) -> "_TaskGroup": + """Setup self managed task group context.""" + self._tasks = set() + return self + + async def __aexit__(self, *_: Any) -> None: + """Execute tasks in group. + + The following execution rules are enforced: + - The context stops executing all tasks if at least one task raises an Exception or the context is cancelled. + - The context re-raises Exceptions to the caller. + - The context re-raises CancelledErrors to the caller only if the context itself was cancelled. + """ + try: + pending_tasks = self._tasks + while pending_tasks: + done_tasks, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_EXCEPTION) + + if any(exception := done_task.exception() for done_task in done_tasks if not done_task.cancelled()): + break + + else: # all tasks completed/cancelled successfully + return + + for pending_task in pending_tasks: + pending_task.cancel() + + await asyncio.gather(*pending_tasks, return_exceptions=True) + raise cast(BaseException, exception) + + except asyncio.CancelledError: # context itself was cancelled + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) + raise + + finally: + self._tasks = set() diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 4012d5e2d..8c68e780e 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -25,14 +25,13 @@ from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry +from ....tools.tool_provider import ToolProvider from ....tools.watcher import ToolWatcher from ....types.content import Message, Messages from ....types.tools import AgentTool from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent -from ...tools import ToolProvider -from .._async import stop_all +from .._async import _TaskGroup, stop_all from ..models.model import BidiModel -from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput from ..types.events import ( BidiAudioInputEvent, @@ -100,13 +99,13 @@ def __init__( ValueError: If model configuration is invalid or state is invalid type. TypeError: If model type is unsupported. """ - self.model = ( - BidiNovaSonicModel() - if not model - else BidiNovaSonicModel(model_id=model) - if isinstance(model, str) - else model - ) + if isinstance(model, BidiModel): + self.model = model + else: + from ..models.nova_sonic import BidiNovaSonicModel + + self.model = BidiNovaSonicModel(model_id=model) if isinstance(model, str) else BidiNovaSonicModel() + self.system_prompt = system_prompt self.messages = messages or [] @@ -390,7 +389,7 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: for start in [*input_starts, *output_starts]: await start(self) - async with asyncio.TaskGroup() as task_group: + async with _TaskGroup() as task_group: inputs_task = task_group.create_task(run_inputs()) task_group.create_task(run_outputs(inputs_task)) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 2b883cf73..79818ae7c 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -5,6 +5,7 @@ import asyncio import logging +import warnings from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from ....types._events import ToolInterruptEvent, ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent @@ -248,6 +249,10 @@ async def _run_tool(self, tool_use: ToolUse) -> None: tool_results: list[ToolResult] = [] + # Ensure request_state exists for tools like strands_tools.stop + if "request_state" not in self._invocation_state: + self._invocation_state["request_state"] = {} + invocation_state: dict[str, Any] = { **self._invocation_state, "agent": self._agent, @@ -282,16 +287,29 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._event_queue.put(ToolResultMessageEvent(tool_result_message)) - # Check for stop_conversation before sending to model - if tool_use["name"] == "stop_conversation": - logger.info("tool_name=<%s> | conversation stop requested, skipping model send", tool_use["name"]) + # Check for stop_event_loop flag (set by strands_tools.stop, stop_conversation, or any custom tool) + request_state = invocation_state.get("request_state", {}) + should_stop = request_state.get("stop_event_loop", False) + + # Backward compatibility: also check for stop_conversation by name (deprecated) + if not should_stop and tool_use["name"] == "stop_conversation": + warnings.warn( + "Stopping the event loop by tool name 'stop_conversation' is deprecated. " + "Use request_state['stop_event_loop'] = True instead.", + DeprecationWarning, + stacklevel=2, + ) + should_stop = True + + if should_stop: + logger.info("stop_event_loop= | stopping conversation") connection_id = getattr(self._agent.model, "_connection_id", "unknown") await self._event_queue.put( BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") ) - return # Skip the model send + return # Skip sending result to model - # Send result to model (all tools except stop_conversation) + # Send result to model await self.send(tool_result_event) except Exception as error: diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index f575c5606..00d999818 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -42,7 +42,7 @@ async def __call__(self, event: BidiOutputEvent) -> None: elif isinstance(event, BidiConnectionCloseEvent): if event.reason == "user_request": - print("user requested connection close using the stop_conversation tool.") + print("user requested connection close using the stop tool.") logger.debug("connection_id=<%s> | user requested connection close", event.connection_id) elif isinstance(event, BidiTranscriptStreamEvent): text = event["text"] diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index cc62c9987..7b87e09fe 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,10 +1,30 @@ """Bidirectional model interfaces and implementations.""" +from typing import Any + from .model import BidiModel, BidiModelTimeoutError -from .nova_sonic import BidiNovaSonicModel __all__ = [ "BidiModel", "BidiModelTimeoutError", - "BidiNovaSonicModel", ] + + +def __getattr__(name: str) -> Any: + """Lazy load bidi model implementations only when accessed. + + This defers the import of optional dependencies until actually needed. + """ + if name == "BidiGeminiLiveModel": + from .gemini_live import BidiGeminiLiveModel + + return BidiGeminiLiveModel + if name == "BidiNovaSonicModel": + from .nova_sonic import BidiNovaSonicModel + + return BidiNovaSonicModel + if name == "BidiOpenAIRealtimeModel": + from .openai_realtime import BidiOpenAIRealtimeModel + + return BidiOpenAIRealtimeModel + raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})") diff --git a/src/strands/experimental/bidi/models/model.py b/src/strands/experimental/bidi/models/model.py index f5e34aa50..5941d7e41 100644 --- a/src/strands/experimental/bidi/models/model.py +++ b/src/strands/experimental/bidi/models/model.py @@ -14,7 +14,7 @@ """ import logging -from typing import Any, AsyncIterable, Protocol +from typing import Any, AsyncIterable, Protocol, runtime_checkable from ....types._events import ToolResultEvent from ....types.content import Messages @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) +@runtime_checkable class BidiModel(Protocol): """Protocol for bidirectional streaming models. diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 6a2477e22..8ad5b2a83 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -11,8 +11,15 @@ - Tool execution with content containers and identifier tracking - 8-minute connection limits with proper cleanup sequences - Interruption detection through stopReason events + +Note, BidiNovaSonicModel is only supported for Python 3.12+ """ +import sys + +if sys.version_info < (3, 12): + raise ImportError("BidiNovaSonicModel is only supported for Python 3.12+") + import asyncio import base64 import json @@ -57,6 +64,10 @@ logger = logging.getLogger(__name__) +# Nova Sonic model identifiers +NOVA_SONIC_V1_MODEL_ID = "amazon.nova-sonic-v1:0" +NOVA_SONIC_V2_MODEL_ID = "amazon.nova-2-sonic-v1:0" + _NOVA_INFERENCE_CONFIG_KEYS = { "max_tokens": "maxTokens", "temperature": "temperature", @@ -85,6 +96,9 @@ NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} NOVA_TOOL_CONFIG = {"mediaType": "application/json"} +_MAX_HISTORY_MESSAGE_BYTES = 50 * 1024 # 50KB per message +_MAX_HISTORY_TOTAL_BYTES = 200 * 1024 # 200KB total history + class BidiNovaSonicModel(BidiModel): """Nova Sonic implementation for bidirectional streaming. @@ -93,6 +107,8 @@ class BidiNovaSonicModel(BidiModel): Manages Nova Sonic's complex event sequencing, audio format conversion, and tool execution patterns while providing the standard BidiModel interface. + Note, BidiNovaSonicModel is only supported for Python 3.12+. + Attributes: _stream: open bedrock stream to nova sonic. """ @@ -101,7 +117,7 @@ class BidiNovaSonicModel(BidiModel): def __init__( self, - model_id: str = "amazon.nova-sonic-v1:0", + model_id: str = NOVA_SONIC_V2_MODEL_ID, provider_config: dict[str, Any] | None = None, client_config: dict[str, Any] | None = None, **kwargs: Any, @@ -109,19 +125,41 @@ def __init__( """Initialize Nova Sonic bidirectional model. Args: - model_id: Model identifier (default: amazon.nova-sonic-v1:0) - provider_config: Model behavior (audio, inference settings) + model_id: Model identifier (default: amazon.nova-2-sonic-v1:0) + provider_config: Model behavior configuration including: + - audio: Audio input/output settings (sample rate, voice, etc.) + - inference: Model inference settings (max_tokens, temperature, top_p) + - turn_detection: Turn detection configuration (v2 only feature) + - endpointingSensitivity: "HIGH" | "MEDIUM" | "LOW" (optional) client_config: AWS authentication (boto_session OR region, not both) **kwargs: Reserved for future parameters. + + Raises: + ValueError: If turn_detection is used with v1 model. + ValueError: If endpointingSensitivity is not HIGH, MEDIUM, or LOW. """ # Store model ID self.model_id = model_id + # Validate turn_detection configuration + provider_config = provider_config or {} + if "turn_detection" in provider_config and provider_config["turn_detection"]: + if model_id == NOVA_SONIC_V1_MODEL_ID: + raise ValueError( + f"turn_detection is only supported in Nova Sonic v2. " + f"Current model_id: {model_id}. Use {NOVA_SONIC_V2_MODEL_ID} instead." + ) + + # Validate endpointingSensitivity value if provided + sensitivity = provider_config["turn_detection"].get("endpointingSensitivity") + if sensitivity and sensitivity not in ["HIGH", "MEDIUM", "LOW"]: + raise ValueError(f"Invalid endpointingSensitivity: {sensitivity}. Must be HIGH, MEDIUM, or LOW") + # Resolve client config with defaults self._client_config = self._resolve_client_config(client_config or {}) # Resolve provider config with defaults - self.config = self._resolve_provider_config(provider_config or {}) + self.config = self._resolve_provider_config(provider_config) # Store session and region for later use self._session = self._client_config["boto_session"] @@ -173,6 +211,7 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: **config.get("audio", {}), }, "inference": config.get("inference", {}), + "turn_detection": config.get("turn_detection", {}), } return resolved @@ -260,21 +299,57 @@ def _build_initialization_events( def _log_event_type(self, nova_event: dict[str, Any]) -> None: """Log specific Nova Sonic event types for debugging.""" + # Log the full event structure for detailed debugging + event_keys = list(nova_event.keys()) + logger.debug("event_keys=<%s> | nova sonic event received", event_keys) + if "usageEvent" in nova_event: - logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"]) + usage = nova_event["usageEvent"] + logger.debug( + "input_tokens=<%s>, output_tokens=<%s>, usage_details=<%s> | nova usage event", + usage.get("totalInputTokens", 0), + usage.get("totalOutputTokens", 0), + json.dumps(usage, indent=2), + ) elif "textOutput" in nova_event: - logger.debug("nova text output received") + text_content = nova_event["textOutput"].get("content", "") + logger.debug( + "text_length=<%d>, text_preview=<%s>, text_output_details=<%s> | nova text output", + len(text_content), + text_content[:100], + json.dumps(nova_event["textOutput"], indent=2)[:500], + ) elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | nova tool use received", + "tool_name=<%s>, tool_use_id=<%s>, tool_use_details=<%s> | nova tool use received", tool_use["toolName"], tool_use["toolUseId"], + json.dumps(tool_use, indent=2)[:500], ) elif "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) + elif "completionStart" in nova_event: + completion_id = nova_event["completionStart"].get("completionId", "unknown") + logger.debug("completion_id=<%s> | nova completion started", completion_id) + elif "completionEnd" in nova_event: + completion_data = nova_event["completionEnd"] + logger.debug( + "completion_id=<%s>, stop_reason=<%s> | nova completion ended", + completion_data.get("completionId", "unknown"), + completion_data.get("stopReason", "unknown"), + ) + elif "stopReason" in nova_event: + logger.debug("stop_reason=<%s> | nova stop reason event", nova_event["stopReason"]) + else: + # Log any other event types + audio_metadata = self._get_audio_metadata_for_logging({"event": nova_event}) + if audio_metadata: + logger.debug("audio_byte_count=<%d> | nova sonic event with audio", audio_metadata["audio_byte_count"]) + else: + logger.debug("event_payload=<%s> | nova sonic event details", json.dumps(nova_event, indent=2)[:500]) async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive Nova Sonic events and convert to provider-agnostic format. @@ -303,14 +378,25 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: raise BidiModelTimeoutError(error.message) from error if not event_data: + logger.debug("received empty event data, continuing") continue - nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] + # Decode and parse the event + raw_bytes = event_data.value.bytes_.decode("utf-8") + logger.debug("raw_event_size=<%d> | received nova sonic event", len(raw_bytes)) + + nova_event = json.loads(raw_bytes)["event"] self._log_event_type(nova_event) model_event = self._convert_nova_event(nova_event) if model_event: + event_type = ( + model_event.get("type", "unknown") if isinstance(model_event, dict) else type(model_event).__name__ + ) + logger.debug("converted_event_type=<%s> | yielding converted event", event_type) yield model_event + else: + logger.debug("event_not_converted | nova event did not produce output event") async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: """Unified send method for all content types. Sends the given content to Nova Sonic. @@ -327,14 +413,24 @@ async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: raise RuntimeError("model not started | call start before sending") if isinstance(content, BidiTextInputEvent): + text_preview = content.text[:100] if len(content.text) > 100 else content.text + logger.debug("text_length=<%d>, text_preview=<%s> | sending text content", len(content.text), text_preview) await self._send_text_content(content.text) elif isinstance(content, BidiAudioInputEvent): + audio_size = len(base64.b64decode(content.audio)) if content.audio else 0 + logger.debug("audio_bytes=<%d>, format=<%s> | sending audio content", audio_size, content.format) await self._send_audio_content(content) elif isinstance(content, ToolResultEvent): tool_result = content.get("tool_result") if tool_result: + logger.debug( + "tool_use_id=<%s>, content_blocks=<%d> | sending tool result", + tool_result.get("toolUseId", "unknown"), + len(tool_result.get("content", [])), + ) await self._send_tool_result(tool_result) else: + logger.error("content_type=<%s> | unsupported content type", type(content)) raise ValueError(f"content_type={type(content)} | content not supported") async def _start_audio_connection(self) -> None: @@ -574,7 +670,15 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" inference_config = {_NOVA_INFERENCE_CONFIG_KEYS[key]: value for key, value in self.config["inference"].items()} - return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": inference_config}}}) + + session_start_event: dict[str, Any] = {"event": {"sessionStart": {"inferenceConfiguration": inference_config}}} + + # Add turn detection configuration if provided (v2 feature) + turn_detection_config = self.config.get("turn_detection", {}) + if turn_detection_config: + session_start_event["event"]["sessionStart"]["turnDetectionConfiguration"] = turn_detection_config + + return json.dumps(session_start_event) def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" @@ -625,7 +729,7 @@ def _get_system_prompt_events(self, system_prompt: str | None) -> list[str]: """Generate system prompt events.""" content_name = str(uuid.uuid4()) return [ - self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_content_start_event(content_name, "SYSTEM", interactive=False), self._get_text_input_event(content_name, system_prompt or ""), self._get_content_end_event(content_name), ] @@ -636,42 +740,98 @@ def _get_message_history_events(self, messages: Messages) -> list[str]: Converts agent message history to Nova Sonic format following the contentStart/textInput/contentEnd pattern for each message. + History messages are sent as non-interactive (interactive=False) so Nova Sonic + treats them as prior context rather than new inputs requiring a response. + + Individual messages are truncated to 50KB and total history is capped + at 200KB. When the limit is reached, the oldest messages are dropped + to prioritize recent conversation context. + Args: messages: List of conversation messages with role and content. Returns: List of JSON event strings for Nova Sonic. """ - events = [] + max_message_bytes = _MAX_HISTORY_MESSAGE_BYTES + max_total_bytes = _MAX_HISTORY_TOTAL_BYTES - for message in messages: - role = message["role"].upper() # Convert to ASSISTANT or USER + # First pass: extract and truncate text from each message, walking backwards + # to prioritize recent messages when the total size limit is hit + prepared: list[tuple[str, str]] = [] # (role, text) + total_bytes = 0 + + for message in reversed(messages): + role = message["role"].upper() content_blocks = message.get("content", []) - # Extract text content from content blocks text_parts = [] for block in content_blocks: if "text" in block: text_parts.append(block["text"]) - # Combine all text parts - if text_parts: - combined_text = "\n".join(text_parts) - content_name = str(uuid.uuid4()) - - # Add contentStart, textInput, and contentEnd events - events.extend( - [ - self._get_text_content_start_event(content_name, role), - self._get_text_input_event(content_name, combined_text), - self._get_content_end_event(content_name), - ] + if not text_parts: + continue + + combined_text = "\n".join(text_parts) + + # Truncate individual message + encoded = combined_text.encode("utf-8") + if len(encoded) > max_message_bytes: + encoded = encoded[:max_message_bytes] + combined_text = encoded.decode("utf-8", errors="ignore") + encoded = combined_text.encode("utf-8") + + msg_bytes = len(encoded) + + if total_bytes + msg_bytes > max_total_bytes: + logger.debug( + "total_bytes=<%d>, msg_bytes=<%d>, max_total_bytes=<%d> | dropping older messages to fit limit", + total_bytes, + msg_bytes, + max_total_bytes, ) + break + + total_bytes += msg_bytes + prepared.append((role, combined_text)) + + # Reverse back to chronological order + prepared.reverse() + + # Ensure the first message is from the user role — drop leading assistant messages + while prepared and prepared[0][0] != "USER": + dropped_role, dropped_text = prepared.pop(0) + logger.debug( + "role=<%s>, text_preview=<%s> | dropping leading non-user message from history", + dropped_role, + dropped_text[:100], + ) + + logger.debug("prepared_count=<%d>, total_bytes=<%d> | final history after trimming", len(prepared), total_bytes) + + # Second pass: build events + events: list[str] = [] + for role, text in prepared: + content_name = str(uuid.uuid4()) + events.extend( + [ + self._get_text_content_start_event(content_name, role, interactive=False), + self._get_text_input_event(content_name, text), + self._get_content_end_event(content_name), + ] + ) return events - def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: - """Generate text content start event.""" + def _get_text_content_start_event(self, content_name: str, role: str = "USER", interactive: bool = True) -> str: + """Generate text content start event. + + Args: + content_name: Unique identifier for this content block. + role: Message role (USER, ASSISTANT, SYSTEM). + interactive: Whether this is a live input (True) or history context (False). + """ return json.dumps( { "event": { @@ -680,7 +840,7 @@ def _get_text_content_start_event(self, content_name: str, role: str = "USER") - "contentName": content_name, "type": "TEXT", "role": role, - "interactive": True, + "interactive": interactive, "textInputConfiguration": NOVA_TEXT_CONFIG, } } @@ -740,6 +900,37 @@ def _get_connection_end_event(self) -> str: """Generate connection end event.""" return json.dumps({"event": {"connectionEnd": {}}}) + def _get_audio_metadata_for_logging(self, event_dict: dict[str, Any]) -> dict[str, Any]: + """Extract audio metadata from event dict for logging. + + Instead of logging large base64-encoded audio data, this extracts metadata + like byte count to verify audio presence without bloating logs. + + Args: + event_dict: The event dictionary to process. + + Returns: + A dict with audio metadata (byte_count) if audio is present, empty dict otherwise. + """ + metadata: dict[str, Any] = {} + + if "event" in event_dict: + event_data = event_dict["event"] + + # Handle contentStart events with audio + if "contentStart" in event_data and "content" in event_data["contentStart"]: + content = event_data["contentStart"]["content"] + if "audio" in content and "bytes" in content["audio"]: + metadata["audio_byte_count"] = len(content["audio"]["bytes"]) + + # Handle content events with audio + if "content" in event_data and "content" in event_data["content"]: + content = event_data["content"]["content"] + if "audio" in content and "bytes" in content["audio"]: + metadata["audio_byte_count"] = len(content["audio"]["bytes"]) + + return metadata + async def _send_nova_events(self, events: list[str]) -> None: """Send event JSON string to Nova Sonic stream. @@ -755,4 +946,3 @@ async def _send_nova_events(self, events: list[str]) -> None: value=BidirectionalInputPayloadPart(bytes_=bytes_data) ) await self._stream.input_stream.send(chunk) - logger.debug("nova sonic event sent successfully") diff --git a/src/strands/experimental/bidi/tools/__init__.py b/src/strands/experimental/bidi/tools/__init__.py index c665dc65a..de67040de 100644 --- a/src/strands/experimental/bidi/tools/__init__.py +++ b/src/strands/experimental/bidi/tools/__init__.py @@ -1,4 +1,17 @@ -"""Built-in tools for bidirectional agents.""" +"""Built-in tools for bidirectional agents. + +.. deprecated:: + The built-in ``stop_conversation`` tool is deprecated. Use ``strands_tools.stop`` or set + ``request_state["stop_event_loop"] = True`` in any custom tool instead. + +To stop a bidirectional conversation, use the standard ``stop`` tool from strands_tools:: + + from strands_tools import stop + agent = BidiAgent(tools=[stop, ...]) + +The stop tool sets ``request_state["stop_event_loop"] = True``, which signals the +BidiAgent to gracefully close the connection. +""" from .stop_conversation import stop_conversation diff --git a/src/strands/experimental/bidi/tools/stop_conversation.py b/src/strands/experimental/bidi/tools/stop_conversation.py index 9c7e1c6cd..21b530552 100644 --- a/src/strands/experimental/bidi/tools/stop_conversation.py +++ b/src/strands/experimental/bidi/tools/stop_conversation.py @@ -1,4 +1,11 @@ -"""Tool to gracefully stop a bidirectional connection.""" +"""Tool to gracefully stop a bidirectional connection. + +.. deprecated:: + The ``stop_conversation`` tool is deprecated and will be removed in a future version. + Use ``strands_tools.stop`` or set ``request_state["stop_event_loop"] = True`` in any custom tool instead. +""" + +import warnings from ....tools.decorator import tool @@ -7,10 +14,19 @@ def stop_conversation() -> str: """Stop the bidirectional conversation gracefully. + .. deprecated:: + Use ``strands_tools.stop`` or set ``request_state["stop_event_loop"] = True`` in a custom tool instead. + Use ONLY when user says "stop conversation" exactly. Do NOT use for: "stop", "goodbye", "bye", "exit", "quit", "end" or other farewells or phrases. Returns: - Success message confirming the conversation will end + Success message confirming the conversation will end. """ + warnings.warn( + "stop_conversation is deprecated and will be removed in a future version. " + "Use strands_tools.stop or set request_state['stop_event_loop'] = True in any custom tool instead.", + DeprecationWarning, + stacklevel=2, + ) return "Ending conversation" diff --git a/src/strands/experimental/checkpoint/__init__.py b/src/strands/experimental/checkpoint/__init__.py new file mode 100644 index 000000000..848cda6d6 --- /dev/null +++ b/src/strands/experimental/checkpoint/__init__.py @@ -0,0 +1,12 @@ +"""Experimental checkpoint types for durable agent execution. + +This module is experimental and subject to change in future revisions without notice. + +Checkpoints enable crash-resilient agent workflows by capturing agent state at +cycle boundaries in the agent loop. A durability provider (e.g. Temporal) can +persist checkpoints and resume from them after failures. +""" + +from .checkpoint import CHECKPOINT_SCHEMA_VERSION, Checkpoint, CheckpointPosition + +__all__ = ["CHECKPOINT_SCHEMA_VERSION", "Checkpoint", "CheckpointPosition"] diff --git a/src/strands/experimental/checkpoint/checkpoint.py b/src/strands/experimental/checkpoint/checkpoint.py new file mode 100644 index 000000000..f37e403c9 --- /dev/null +++ b/src/strands/experimental/checkpoint/checkpoint.py @@ -0,0 +1,94 @@ +"""Checkpoint system for durable agent execution. + +Checkpoints enable crash-resilient agent workflows by capturing agent state at +cycle boundaries in the agent loop. A durability provider (e.g. Temporal) can +persist checkpoints and resume from them after failures. + +Two checkpoint positions per ReAct cycle: +- after_model: model call completed, tools not yet executed. +- after_tools: all tools executed, next model call pending. + +Per-tool granularity is handled by the ToolExecutor abstraction (e.g. +TemporalToolExecutor routes each tool to a separate Temporal activity). +The SDK checkpoint operates at cycle boundaries. + +User-facing pattern (same as interrupts): +- Pause via stop_reason="checkpoint" on AgentResult +- State via AgentResult.checkpoint field +- Resume via checkpointResume content block in next agent() call + +V0 Known Limitations: +- Metrics reset on each resume call. The caller is responsible for aggregating + metrics across a durable run. EventLoopMetrics reflects only the current call. +- OpenAIResponsesModel(stateful=True) is not supported. The server-side + response_id (_model_state) is not captured in the snapshot. +- When position is "after_tools", AgentResult.message is the assistant message + that requested the tools; tool results are in the snapshot messages. +- BeforeInvocationEvent and AfterInvocationEvent fire on every resume call, + same as interrupts. Hooks counting invocations will see each resume as a + separate invocation. +- Per-tool granularity within a cycle requires a custom ToolExecutor + (e.g. TemporalToolExecutor). +""" + +import logging +from dataclasses import asdict, dataclass, field +from typing import Any, Literal + +logger = logging.getLogger(__name__) + +CHECKPOINT_SCHEMA_VERSION = "1.0" + +CheckpointPosition = Literal["after_model", "after_tools"] + + +@dataclass +class Checkpoint: + """Pause point in the agent loop. Treat as opaque — pass back to resume. + + Attributes: + position: What just completed (after_model or after_tools). + cycle_index: Which ReAct loop cycle (0-based). + snapshot: Serialized agent state as a dict, produced by ``Snapshot.to_dict()``. + Stored as ``dict[str, Any]`` (not a ``Snapshot`` object) because checkpoints + must be JSON-serializable for cross-process persistence. The consumer + reconstructs via ``Snapshot.from_dict()`` on resume. + app_data: Application-level internal state data. The SDK does not read + or modify this. Applications can store arbitrary data needed across + checkpoint boundaries (e.g. session context, workflow metadata). + Separate from ``Snapshot.app_data`` which captures agent-state-level + data managed by the SDK. + schema_version: Rejects mismatches on resume across schema versions. + """ + + position: CheckpointPosition + cycle_index: int = 0 + snapshot: dict[str, Any] = field(default_factory=dict) + app_data: dict[str, Any] = field(default_factory=dict) + schema_version: str = field(init=False, default=CHECKPOINT_SCHEMA_VERSION) + + def to_dict(self) -> dict[str, Any]: + """Serialize for persistence.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "Checkpoint": + """Reconstruct from a dict produced by to_dict(). + + Args: + data: Serialized checkpoint data. + + Raises: + ValueError: If schema_version doesn't match the current version. + """ + version = data.get("schema_version", "") + if version != CHECKPOINT_SCHEMA_VERSION: + raise ValueError( + f"Checkpoints with schema version {version!r} are not compatible " + f"with current version {CHECKPOINT_SCHEMA_VERSION}." + ) + known_keys = {k for k in cls.__dataclass_fields__ if k != "schema_version"} + unknown_keys = set(data.keys()) - known_keys - {"schema_version"} + if unknown_keys: + logger.warning("unknown_keys=<%s> | ignoring unknown fields in checkpoint data", unknown_keys) + return cls(**{k: v for k, v in data.items() if k in known_keys}) diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index c76b57ea4..f2219bf7b 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -1,19 +1,28 @@ """Experimental hook functionality that has not yet reached stability.""" +from typing import Any + from .events import ( - AfterModelInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - BeforeToolInvocationEvent, + BidiAfterConnectionRestartEvent, BidiAfterInvocationEvent, BidiAfterToolCallEvent, BidiAgentInitializedEvent, + BidiBeforeConnectionRestartEvent, BidiBeforeInvocationEvent, BidiBeforeToolCallEvent, BidiInterruptionEvent, BidiMessageAddedEvent, ) +# Deprecated aliases are accessed via __getattr__ to emit warnings only on use + + +def __getattr__(name: str) -> Any: + from . import events + + return getattr(events, name) + + __all__ = [ "BeforeToolInvocationEvent", "AfterToolInvocationEvent", @@ -27,4 +36,6 @@ "BidiBeforeToolCallEvent", "BidiAfterToolCallEvent", "BidiInterruptionEvent", + "BidiBeforeConnectionRestartEvent", + "BidiAfterConnectionRestartEvent", ] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index 8a8d80629..081190af3 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -5,7 +5,7 @@ import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, TypeAlias +from typing import TYPE_CHECKING, Any, Literal from ...hooks.events import AfterModelCallEvent, AfterToolCallEvent, BeforeModelCallEvent, BeforeToolCallEvent from ...hooks.registry import BaseHookEvent @@ -16,17 +16,25 @@ from ..bidi.agent.agent import BidiAgent from ..bidi.models import BidiModelTimeoutError -warnings.warn( - "BeforeModelCallEvent, AfterModelCallEvent, BeforeToolCallEvent, and AfterToolCallEvent are no longer experimental." - "Import from strands.hooks instead.", - DeprecationWarning, - stacklevel=2, -) - -BeforeToolInvocationEvent: TypeAlias = BeforeToolCallEvent -AfterToolInvocationEvent: TypeAlias = AfterToolCallEvent -BeforeModelInvocationEvent: TypeAlias = BeforeModelCallEvent -AfterModelInvocationEvent: TypeAlias = AfterModelCallEvent +# Deprecated aliases - warning emitted on access via __getattr__ +_DEPRECATED_ALIASES = { + "BeforeToolInvocationEvent": BeforeToolCallEvent, + "AfterToolInvocationEvent": AfterToolCallEvent, + "BeforeModelInvocationEvent": BeforeModelCallEvent, + "AfterModelInvocationEvent": AfterModelCallEvent, +} + + +def __getattr__(name: str) -> Any: + if name in _DEPRECATED_ALIASES: + warnings.warn( + f"{name} has been moved to production with an updated name. " + f"Use {_DEPRECATED_ALIASES[name].__name__} from strands.hooks instead.", + DeprecationWarning, + stacklevel=2, + ) + return _DEPRECATED_ALIASES[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") # BidiAgent Hook Events diff --git a/src/strands/experimental/hooks/multiagent/__init__.py b/src/strands/experimental/hooks/multiagent/__init__.py index d059d0da5..6755db7e4 100644 --- a/src/strands/experimental/hooks/multiagent/__init__.py +++ b/src/strands/experimental/hooks/multiagent/__init__.py @@ -1,6 +1,6 @@ -"""Multi-agent hook events and utilities. +"""Multi-agent hook events. -Provides event classes for hooking into multi-agent orchestrator lifecycle. +Deprecated: Use strands.hooks.multiagent instead. """ from .events import ( diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py index fa881bf32..2c65c53e3 100644 --- a/src/strands/experimental/hooks/multiagent/events.py +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -1,118 +1,28 @@ """Multi-agent execution lifecycle events for hook system integration. -These events are fired by orchestrators (Graph/Swarm) at key points so -hooks can persist, monitor, or debug execution. No intermediate state model -is used—hooks read from the orchestrator directly. +Deprecated: Use strands.hooks.multiagent instead. """ -import uuid -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any - -from typing_extensions import override - -from ....hooks import BaseHookEvent -from ....types.interrupt import _Interruptible - -if TYPE_CHECKING: - from ....multiagent.base import MultiAgentBase - - -@dataclass -class MultiAgentInitializedEvent(BaseHookEvent): - """Event triggered when multi-agent orchestrator initialized. - - Attributes: - source: The multi-agent orchestrator instance - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - invocation_state: dict[str, Any] | None = None - - -@dataclass -class BeforeNodeCallEvent(BaseHookEvent, _Interruptible): - """Event triggered before individual node execution starts. - - Attributes: - source: The multi-agent orchestrator instance - node_id: ID of the node about to execute - invocation_state: Configuration that user passes in - cancel_node: A user defined message that when set, will cancel the node execution with status FAILED. - The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the - node using a default cancel message. - """ - - source: "MultiAgentBase" - node_id: str - invocation_state: dict[str, Any] | None = None - cancel_node: bool | str = False - - def _can_write(self, name: str) -> bool: - return name in ["cancel_node"] - - @override - def _interrupt_id(self, name: str) -> str: - """Unique id for the interrupt. - - Args: - name: User defined name for the interrupt. - - Returns: - Interrupt id. - """ - node_id = uuid.uuid5(uuid.NAMESPACE_OID, self.node_id) - call_id = uuid.uuid5(uuid.NAMESPACE_OID, name) - return f"v1:before_node_call:{node_id}:{call_id}" - - -@dataclass -class AfterNodeCallEvent(BaseHookEvent): - """Event triggered after individual node execution completes. - - Attributes: - source: The multi-agent orchestrator instance - node_id: ID of the node that just completed execution - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - node_id: str - invocation_state: dict[str, Any] | None = None - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True - - -@dataclass -class BeforeMultiAgentInvocationEvent(BaseHookEvent): - """Event triggered before orchestrator execution starts. - - Attributes: - source: The multi-agent orchestrator instance - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - invocation_state: dict[str, Any] | None = None - - -@dataclass -class AfterMultiAgentInvocationEvent(BaseHookEvent): - """Event triggered after orchestrator execution completes. - - Attributes: - source: The multi-agent orchestrator instance - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - invocation_state: dict[str, Any] | None = None - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True +import warnings + +from ....hooks import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) + +warnings.warn( + "strands.experimental.hooks.multiagent is deprecated. Use strands.hooks instead.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", +] diff --git a/src/strands/experimental/steering/__init__.py b/src/strands/experimental/steering/__init__.py index 4d0775873..1db07c90f 100644 --- a/src/strands/experimental/steering/__init__.py +++ b/src/strands/experimental/steering/__init__.py @@ -1,37 +1,14 @@ -"""Steering system for Strands agents. +"""Deprecated: Steering has moved to strands.vended_plugins.steering. -Provides contextual guidance for agents through modular prompting with progressive disclosure. -Instead of front-loading all instructions, steering handlers provide just-in-time feedback -based on local context data populated by context callbacks. - -Core components: - -- SteeringHandler: Base class for guidance logic with local context -- SteeringContextCallback: Protocol for context update functions -- SteeringContextProvider: Protocol for multi-event context providers -- SteeringAction: Proceed/Guide/Interrupt decisions - -Usage: - handler = LLMSteeringHandler(system_prompt="...") - agent = Agent(tools=[...], hooks=[handler]) +This module provides backwards-compatible aliases that emit deprecation warnings. """ -# Core primitives -# Context providers -from .context_providers.ledger_provider import ( - LedgerAfterToolCall, - LedgerBeforeToolCall, - LedgerProvider, -) -from .core.action import Guide, Interrupt, Proceed, SteeringAction -from .core.context import SteeringContextCallback, SteeringContextProvider -from .core.handler import SteeringHandler - -# Handler implementations -from .handlers.llm import LLMPromptMapper, LLMSteeringHandler - -__all__ = [ - "SteeringAction", +import warnings +from typing import Any + +_DEPRECATED_NAMES = { + "ToolSteeringAction", + "ModelSteeringAction", "Proceed", "Guide", "Interrupt", @@ -43,4 +20,20 @@ "LedgerProvider", "LLMSteeringHandler", "LLMPromptMapper", -] +} + + +def __getattr__(name: str) -> Any: + if name in _DEPRECATED_NAMES: + from strands.vended_plugins import steering + + warnings.warn( + f"{name} has been moved to production. Use {name} from strands.vended_plugins.steering instead.", + DeprecationWarning, + stacklevel=2, + ) + return getattr(steering, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/context_providers/__init__.py b/src/strands/experimental/steering/context_providers/__init__.py index 242ed9cf1..81a0fa709 100644 --- a/src/strands/experimental/steering/context_providers/__init__.py +++ b/src/strands/experimental/steering/context_providers/__init__.py @@ -1,13 +1,23 @@ -"""Context providers for steering evaluation.""" - -from .ledger_provider import ( - LedgerAfterToolCall, - LedgerBeforeToolCall, - LedgerProvider, -) - -__all__ = [ - "LedgerAfterToolCall", - "LedgerBeforeToolCall", - "LedgerProvider", -] +"""Deprecated: Use strands.vended_plugins.steering.context_providers instead.""" + +import warnings +from typing import Any + +_TARGET_MODULE = "strands.vended_plugins.steering.context_providers" + + +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering import context_providers + + obj = getattr(context_providers, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/context_providers/ledger_provider.py b/src/strands/experimental/steering/context_providers/ledger_provider.py index da8504bd0..3cc21774e 100644 --- a/src/strands/experimental/steering/context_providers/ledger_provider.py +++ b/src/strands/experimental/steering/context_providers/ledger_provider.py @@ -1,85 +1,23 @@ -"""Ledger context provider for comprehensive agent activity tracking. +"""Deprecated: Use strands.vended_plugins.steering.context_providers.ledger_provider instead.""" -Tracks complete agent activity ledger including tool calls, conversation history, -and timing information. This comprehensive audit trail enables steering handlers -to make informed guidance decisions based on agent behavior patterns and history. - -Data captured: - - - Tool call history with inputs, outputs, timing, success/failure - - Conversation messages and agent responses - - Session metadata and timing information - - Error patterns and recovery attempts - -Usage: - Use as context provider functions or mix into steering handlers. -""" - -import logging -from datetime import datetime +import warnings from typing import Any -from ....hooks.events import AfterToolCallEvent, BeforeToolCallEvent -from ..core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider - -logger = logging.getLogger(__name__) - - -class LedgerBeforeToolCall(SteeringContextCallback[BeforeToolCallEvent]): - """Context provider for ledger tracking before tool calls.""" - - def __init__(self) -> None: - """Initialize the ledger provider.""" - self.session_start = datetime.now().isoformat() - - def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None: - """Update ledger before tool call.""" - ledger = steering_context.data.get("ledger") or {} - - if not ledger: - ledger = { - "session_start": self.session_start, - "tool_calls": [], - "conversation_history": [], - "session_metadata": {}, - } - - tool_call_entry = { - "timestamp": datetime.now().isoformat(), - "tool_name": event.tool_use.get("name"), - "tool_args": event.tool_use.get("arguments", {}), - "status": "pending", - } - ledger["tool_calls"].append(tool_call_entry) - steering_context.data.set("ledger", ledger) - - -class LedgerAfterToolCall(SteeringContextCallback[AfterToolCallEvent]): - """Context provider for ledger tracking after tool calls.""" +_TARGET_MODULE = "strands.vended_plugins.steering.context_providers.ledger_provider" - def __call__(self, event: AfterToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None: - """Update ledger after tool call.""" - ledger = steering_context.data.get("ledger") or {} - if ledger.get("tool_calls"): - last_call = ledger["tool_calls"][-1] - last_call.update( - { - "completion_timestamp": datetime.now().isoformat(), - "status": event.result["status"], - "result": event.result["content"], - "error": str(event.exception) if event.exception else None, - } - ) - steering_context.data.set("ledger", ledger) +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.context_providers import ledger_provider + obj = getattr(ledger_provider, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -class LedgerProvider(SteeringContextProvider): - """Combined ledger context provider for both before and after tool calls.""" - def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: - """Return ledger context providers with shared state.""" - return [ - LedgerBeforeToolCall(), - LedgerAfterToolCall(), - ] +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/core/__init__.py b/src/strands/experimental/steering/core/__init__.py index a3efe0dbc..e7c79f66d 100644 --- a/src/strands/experimental/steering/core/__init__.py +++ b/src/strands/experimental/steering/core/__init__.py @@ -1,6 +1,23 @@ -"""Core steering system interfaces and base classes.""" +"""Deprecated: Use strands.vended_plugins.steering.core instead.""" -from .action import Guide, Interrupt, Proceed, SteeringAction -from .handler import SteeringHandler +import warnings +from typing import Any -__all__ = ["SteeringAction", "Proceed", "Guide", "Interrupt", "SteeringHandler"] +_TARGET_MODULE = "strands.vended_plugins.steering.core" + + +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering import core + + obj = getattr(core, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/core/action.py b/src/strands/experimental/steering/core/action.py index 8b4ec141d..9e60aa704 100644 --- a/src/strands/experimental/steering/core/action.py +++ b/src/strands/experimental/steering/core/action.py @@ -1,65 +1,23 @@ -"""SteeringAction types for steering evaluation results. +"""Deprecated: Use strands.vended_plugins.steering.core.action instead.""" -Defines structured outcomes from steering handlers that determine how tool calls -should be handled. SteeringActions enable modular prompting by providing just-in-time -feedback rather than front-loading all instructions in monolithic prompts. +import warnings +from typing import Any -Flow: - SteeringHandler.steer() → SteeringAction → BeforeToolCallEvent handling - ↓ ↓ ↓ - Evaluate context Action type Tool execution modified +_TARGET_MODULE = "strands.vended_plugins.steering.core.action" -SteeringAction types: - Proceed: Tool executes immediately (no intervention needed) - Guide: Tool cancelled, agent receives contextual feedback to explore alternatives - Interrupt: Tool execution paused for human input via interrupt system -Extensibility: - New action types can be added to the union. Always handle the default - case in pattern matching to maintain backward compatibility. -""" +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.core import action -from typing import Annotated, Literal + obj = getattr(action, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -from pydantic import BaseModel, Field - -class Proceed(BaseModel): - """Allow tool to execute immediately without intervention. - - The tool call proceeds as planned. The reason provides context - for logging and debugging purposes. - """ - - type: Literal["proceed"] = "proceed" - reason: str - - -class Guide(BaseModel): - """Cancel tool and provide contextual feedback for agent to explore alternatives. - - The tool call is cancelled and the agent receives the reason as contextual - feedback to help them consider alternative approaches while maintaining - adaptive reasoning capabilities. - """ - - type: Literal["guide"] = "guide" - reason: str - - -class Interrupt(BaseModel): - """Pause tool execution for human input via interrupt system. - - The tool call is paused and human input is requested through Strands' - interrupt system. The human can approve or deny the operation, and their - decision determines whether the tool executes or is cancelled. - """ - - type: Literal["interrupt"] = "interrupt" - reason: str - - -# SteeringAction union - extensible for future action types -# IMPORTANT: Always handle the default case when pattern matching -# to maintain backward compatibility as new action types are added -SteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/core/context.py b/src/strands/experimental/steering/core/context.py index 446c4c9f9..15014118f 100644 --- a/src/strands/experimental/steering/core/context.py +++ b/src/strands/experimental/steering/core/context.py @@ -1,77 +1,23 @@ -"""Steering context protocols for contextual guidance. +"""Deprecated: Use strands.vended_plugins.steering.core.context instead.""" -Defines protocols for context callbacks and providers that populate -steering context data used by handlers to make guidance decisions. +import warnings +from typing import Any -Architecture: - SteeringContextCallback → Handler.steering_context → SteeringHandler.steer() - ↓ ↓ ↓ - Update local context Store in handler Access via self.steering_context +_TARGET_MODULE = "strands.vended_plugins.steering.core.context" -Context lifecycle: - 1. Handler registers context callbacks for hook events - 2. Callbacks update handler's local steering_context on events - 3. Handler accesses self.steering_context in steer() method - 4. Context persists across calls within handler instance -Implementation: - Each handler maintains its own JSONSerializableDict context. - Callbacks are registered per handler instance for isolation. - Providers can supply multiple callbacks for different events. -""" +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.core import context -import logging -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Generic, TypeVar, cast, get_args, get_origin + obj = getattr(context, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -from ....hooks.registry import HookEvent -from ....types.json_dict import JSONSerializableDict -logger = logging.getLogger(__name__) - - -@dataclass -class SteeringContext: - """Container for steering context data.""" - - """Container for steering context data. - - This class should not be instantiated directly - it is intended for internal use only. - """ - - data: JSONSerializableDict = field(default_factory=JSONSerializableDict) - - -EventType = TypeVar("EventType", bound=HookEvent, contravariant=True) - - -class SteeringContextCallback(ABC, Generic[EventType]): - """Abstract base class for steering context update callbacks.""" - - @property - def event_type(self) -> type[HookEvent]: - """Return the event type this callback handles.""" - for base in getattr(self.__class__, "__orig_bases__", ()): - if get_origin(base) is SteeringContextCallback: - return cast(type[HookEvent], get_args(base)[0]) - raise ValueError("Could not determine event type from generic parameter") - - def __call__(self, event: EventType, steering_context: "SteeringContext", **kwargs: Any) -> None: - """Update steering context based on hook event. - - Args: - event: The hook event that triggered the callback - steering_context: The steering context to update - **kwargs: Additional keyword arguments for context updates - """ - ... - - -class SteeringContextProvider(ABC): - """Abstract base class for context providers that handle multiple event types.""" - - @abstractmethod - def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: - """Return list of context callbacks with event types extracted from generics.""" - ... +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 4a0bcaa6a..5892fb026 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -1,134 +1,23 @@ -"""Steering handler base class for providing contextual guidance to agents. +"""Deprecated: Use strands.vended_plugins.steering.core.handler instead.""" -Provides modular prompting through contextual guidance that appears when relevant, -rather than front-loading all instructions. Handlers integrate with the Strands hook -system to intercept tool calls and provide just-in-time feedback based on local context. +import warnings +from typing import Any -Architecture: - BeforeToolCallEvent → Context Callbacks → Update steering_context → steer() → SteeringAction - ↓ ↓ ↓ ↓ ↓ - Hook triggered Populate context Handler evaluates Handler decides Action taken +_TARGET_MODULE = "strands.vended_plugins.steering.core.handler" -Lifecycle: - 1. Context callbacks update handler's steering_context on hook events - 2. BeforeToolCallEvent triggers steering evaluation via steer() method - 3. Handler accesses self.steering_context for guidance decisions - 4. SteeringAction determines tool execution: Proceed/Guide/Interrupt -Implementation: - Subclass SteeringHandler and implement steer() method. - Pass context_callbacks in constructor to register context update functions. - Each handler maintains isolated steering_context that persists across calls. +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.core import handler -SteeringAction handling: - Proceed: Tool executes immediately - Guide: Tool cancelled, agent receives contextual feedback to explore alternatives - Interrupt: Tool execution paused for human input via interrupt system -""" + obj = getattr(handler, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -import logging -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any -from ....hooks.events import BeforeToolCallEvent -from ....hooks.registry import HookProvider, HookRegistry -from ....types.tools import ToolUse -from .action import Guide, Interrupt, Proceed, SteeringAction -from .context import SteeringContext, SteeringContextProvider - -if TYPE_CHECKING: - from ....agent import Agent - -logger = logging.getLogger(__name__) - - -class SteeringHandler(HookProvider, ABC): - """Base class for steering handlers that provide contextual guidance to agents. - - Steering handlers maintain local context and register hook callbacks - to populate context data as needed for guidance decisions. - """ - - def __init__(self, context_providers: list[SteeringContextProvider] | None = None): - """Initialize the steering handler. - - Args: - context_providers: List of context providers for context updates - """ - super().__init__() - self.steering_context = SteeringContext() - self._context_callbacks = [] - - # Collect callbacks from all providers - for provider in context_providers or []: - self._context_callbacks.extend(provider.context_providers()) - - logger.debug("handler_class=<%s> | initialized", self.__class__.__name__) - - def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: - """Register hooks for steering guidance and context updates.""" - # Register context update callbacks - for callback in self._context_callbacks: - registry.add_callback( - callback.event_type, lambda event, callback=callback: callback(event, self.steering_context) - ) - - # Register steering guidance - registry.add_callback(BeforeToolCallEvent, self._provide_steering_guidance) - - async def _provide_steering_guidance(self, event: BeforeToolCallEvent) -> None: - """Provide steering guidance for tool call.""" - tool_name = event.tool_use["name"] - logger.debug("tool_name=<%s> | providing steering guidance", tool_name) - - try: - action = await self.steer(event.agent, event.tool_use) - except Exception as e: - logger.debug("tool_name=<%s>, error=<%s> | steering handler guidance failed", tool_name, e) - return - - self._handle_steering_action(action, event, tool_name) - - def _handle_steering_action(self, action: SteeringAction, event: BeforeToolCallEvent, tool_name: str) -> None: - """Handle the steering action by modifying tool execution flow. - - Proceed: Tool executes normally - Guide: Tool cancelled with contextual feedback for agent to consider alternatives - Interrupt: Tool execution paused for human input via interrupt system - """ - if isinstance(action, Proceed): - logger.debug("tool_name=<%s> | tool call proceeding", tool_name) - elif isinstance(action, Guide): - logger.debug("tool_name=<%s> | tool call guided: %s", tool_name, action.reason) - event.cancel_tool = ( - f"Tool call cancelled given new guidance. {action.reason}. Consider this approach and continue" - ) - elif isinstance(action, Interrupt): - logger.debug("tool_name=<%s> | tool call requires human input: %s", tool_name, action.reason) - can_proceed: bool = event.interrupt(name=f"steering_input_{tool_name}", reason={"message": action.reason}) - logger.debug("tool_name=<%s> | received human input for tool call", tool_name) - - if not can_proceed: - event.cancel_tool = f"Manual approval denied: {action.reason}" - logger.debug("tool_name=<%s> | tool call denied by manual approval", tool_name) - else: - logger.debug("tool_name=<%s> | tool call approved manually", tool_name) - else: - raise ValueError(f"Unknown steering action type: {action}") - - @abstractmethod - async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: - """Provide contextual guidance to help agent navigate complex workflows. - - Args: - agent: The agent instance - tool_use: The tool use object with name and arguments - **kwargs: Additional keyword arguments for guidance evaluation - - Returns: - SteeringAction indicating how to guide the agent's next action - - Note: - Access steering context via self.steering_context - """ - ... +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/handlers/__init__.py b/src/strands/experimental/steering/handlers/__init__.py index ca529530f..128fc946c 100644 --- a/src/strands/experimental/steering/handlers/__init__.py +++ b/src/strands/experimental/steering/handlers/__init__.py @@ -1,3 +1,23 @@ -"""Steering handler implementations.""" +"""Deprecated: Use strands.vended_plugins.steering.handlers instead.""" -__all__ = [] +import warnings +from typing import Any + +_TARGET_MODULE = "strands.vended_plugins.steering.handlers" + + +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering import handlers + + obj = getattr(handlers, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/handlers/llm/__init__.py b/src/strands/experimental/steering/handlers/llm/__init__.py index 4dcccbe80..aef580729 100644 --- a/src/strands/experimental/steering/handlers/llm/__init__.py +++ b/src/strands/experimental/steering/handlers/llm/__init__.py @@ -1,6 +1,23 @@ -"""LLM steering handler with prompt mapping.""" +"""Deprecated: Use strands.vended_plugins.steering.handlers.llm instead.""" -from .llm_handler import LLMSteeringHandler -from .mappers import DefaultPromptMapper, LLMPromptMapper, ToolUse +import warnings +from typing import Any -__all__ = ["LLMSteeringHandler", "LLMPromptMapper", "DefaultPromptMapper", "ToolUse"] +_TARGET_MODULE = "strands.vended_plugins.steering.handlers.llm" + + +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.handlers import llm + + obj = getattr(llm, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index 9d9b34911..8c1b6d200 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -1,95 +1,23 @@ -"""LLM-based steering handler that uses an LLM to provide contextual guidance.""" +"""Deprecated: Use strands.vended_plugins.steering.handlers.llm.llm_handler instead.""" -from __future__ import annotations +import warnings +from typing import Any -import logging -from typing import TYPE_CHECKING, Any, Literal, cast +_TARGET_MODULE = "strands.vended_plugins.steering.handlers.llm.llm_handler" -from pydantic import BaseModel, Field -from .....models import Model -from .....types.tools import ToolUse -from ...context_providers.ledger_provider import LedgerProvider -from ...core.action import Guide, Interrupt, Proceed, SteeringAction -from ...core.context import SteeringContextProvider -from ...core.handler import SteeringHandler -from .mappers import DefaultPromptMapper, LLMPromptMapper +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.handlers.llm import llm_handler -if TYPE_CHECKING: - from .....agent import Agent - -logger = logging.getLogger(__name__) - - -class _LLMSteering(BaseModel): - """Structured output model for LLM steering decisions.""" - - decision: Literal["proceed", "guide", "interrupt"] = Field( - description="Steering decision: 'proceed' to continue, 'guide' to provide feedback, 'interrupt' for human input" - ) - reason: str = Field(description="Clear explanation of the steering decision and any guidance provided") - - -class LLMSteeringHandler(SteeringHandler): - """Steering handler that uses an LLM to provide contextual guidance. - - Uses natural language prompts to evaluate tool calls and provide - contextual steering guidance to help agents navigate complex workflows. - """ - - def __init__( - self, - system_prompt: str, - prompt_mapper: LLMPromptMapper | None = None, - model: Model | None = None, - context_providers: list[SteeringContextProvider] | None = None, - ): - """Initialize the LLMSteeringHandler. - - Args: - system_prompt: System prompt defining steering guidance rules - prompt_mapper: Custom prompt mapper for evaluation prompts - model: Optional model override for steering evaluation - context_providers: List of context providers for populating steering context - """ - providers = context_providers or [LedgerProvider()] - super().__init__(context_providers=providers) - self.system_prompt = system_prompt - self.prompt_mapper = prompt_mapper or DefaultPromptMapper() - self.model = model - - async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: - """Provide contextual guidance for tool usage. - - Args: - agent: The agent instance - tool_use: The tool use object with name and arguments - **kwargs: Additional keyword arguments for steering evaluation - - Returns: - SteeringAction indicating how to guide the agent's next action - """ - # Generate steering prompt - prompt = self.prompt_mapper.create_steering_prompt(self.steering_context, tool_use=tool_use) - - # Create isolated agent for steering evaluation (no shared conversation state) - from .....agent import Agent - - steering_agent = Agent(system_prompt=self.system_prompt, model=self.model or agent.model, callback_handler=None) - - # Get LLM decision - llm_result: _LLMSteering = cast( - _LLMSteering, steering_agent(prompt, structured_output_model=_LLMSteering).structured_output + obj = getattr(llm_handler, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + - # Convert LLM decision to steering action - match llm_result.decision: - case "proceed": - return Proceed(reason=llm_result.reason) - case "guide": - return Guide(reason=llm_result.reason) - case "interrupt": - return Interrupt(reason=llm_result.reason) - case _: - logger.warning("decision=<%s> | uŹknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] - return Proceed(reason="Unknown LLM decision, defaulting to proceed") +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/handlers/llm/mappers.py b/src/strands/experimental/steering/handlers/llm/mappers.py index 9901da7d4..56ea3125f 100644 --- a/src/strands/experimental/steering/handlers/llm/mappers.py +++ b/src/strands/experimental/steering/handlers/llm/mappers.py @@ -1,116 +1,23 @@ -"""LLM steering prompt mappers for generating evaluation prompts.""" +"""Deprecated: Use strands.vended_plugins.steering.handlers.llm.mappers instead.""" -import json -from typing import Any, Protocol +import warnings +from typing import Any -from .....types.tools import ToolUse -from ...core.context import SteeringContext +_TARGET_MODULE = "strands.vended_plugins.steering.handlers.llm.mappers" -# Agent SOP format - see https://github.com/strands-agents/agent-sop -_STEERING_PROMPT_TEMPLATE = """# Steering Evaluation -## Overview +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.handlers.llm import mappers -You are a STEERING AGENT that evaluates a {action_type} that ANOTHER AGENT is attempting to make. -Your job is to provide contextual guidance to help the other agent navigate workflows effectively. -You act as a safety net that can intervene when patterns in the context data suggest the agent -should try a different approach or get human input. - -**YOUR ROLE:** -- Analyze context data for concerning patterns (repeated failures, inappropriate timing, etc.) -- Provide just-in-time guidance when the agent is going down an ineffective path -- Allow normal operations to proceed when context shows no issues - -**CRITICAL CONSTRAINTS:** -- Base decisions ONLY on the context data provided below -- Do NOT use external knowledge about domains, URLs, or tool purposes -- Do NOT make assumptions about what tools "should" or "shouldn't" do -- Focus ONLY on patterns in the context data - -## Context - -{context_str} - -## Event to Evaluate - -{event_description} - -## Steps - -### 1. Analyze the {action_type_title} - -Review ONLY the context data above. Look for patterns in the data that indicate: - -- Previous failures or successes with this tool -- Frequency of attempts -- Any relevant tracking information - -**Constraints:** -- You MUST base analysis ONLY on the provided context data -- You MUST NOT use external knowledge about tool purposes or domains -- You SHOULD identify patterns in the context data -- You MAY reference relevant context data to inform your decision - -### 2. Make Steering Decision - -**Constraints:** -- You MUST respond with exactly one of: "proceed", "guide", or "interrupt" -- You MUST base the decision ONLY on context data patterns -- Your reason will be shown to the AGENT as guidance - -**Decision Options:** -- "proceed" if context data shows no concerning patterns -- "guide" if context data shows patterns requiring intervention -- "interrupt" if context data shows patterns requiring human input -""" - - -class LLMPromptMapper(Protocol): - """Protocol for mapping context and events to LLM evaluation prompts.""" - - def create_steering_prompt( - self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any - ) -> str: - """Create steering prompt for LLM evaluation. - - Args: - steering_context: Steering context with populated data - tool_use: Tool use object for tool call events (None for other events) - **kwargs: Additional event data for other steering events - - Returns: - Formatted prompt string for LLM evaluation - """ - ... - - -class DefaultPromptMapper(LLMPromptMapper): - """Default prompt mapper for steering evaluation.""" - - def create_steering_prompt( - self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any - ) -> str: - """Create default steering prompt using Agent SOP structure. - - Uses Agent SOP format for structured, constraint-based prompts. - See: https://github.com/strands-agents/agent-sop - """ - context_str = ( - json.dumps(steering_context.data.get(), indent=2) if steering_context.data.get() else "No context available" + obj = getattr(mappers, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - if tool_use: - event_description = ( - f"Tool: {tool_use['name']}\nArguments: {json.dumps(tool_use.get('input', {}), indent=2)}" - ) - action_type = "tool call" - else: - event_description = "General evaluation" - action_type = "action" - return _STEERING_PROMPT_TEMPLATE.format( - action_type=action_type, - action_type_title=action_type.title(), - context_str=context_str, - event_description=event_description, - ) +__all__: list[str] = [] diff --git a/src/strands/experimental/tools/__init__.py b/src/strands/experimental/tools/__init__.py index ad693f8ac..a23b7a10c 100644 --- a/src/strands/experimental/tools/__init__.py +++ b/src/strands/experimental/tools/__init__.py @@ -1,5 +1,22 @@ """Experimental tools package.""" -from .tool_provider import ToolProvider +import warnings +from typing import Any -__all__ = ["ToolProvider"] +_DEPRECATED_NAMES = {"ToolProvider"} + + +def __getattr__(name: str) -> Any: + if name in _DEPRECATED_NAMES: + from ...tools import ToolProvider + + warnings.warn( + f"{name} has been moved to production. Use {name} from strands.tools instead.", + DeprecationWarning, + stacklevel=2, + ) + return ToolProvider + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/handlers/callback_handler.py b/src/strands/handlers/callback_handler.py index d449f76da..45b7efda8 100644 --- a/src/strands/handlers/callback_handler.py +++ b/src/strands/handlers/callback_handler.py @@ -14,7 +14,6 @@ def __init__(self, verbose_tool_use: bool = True) -> None: verbose_tool_use: Print out verbose information about tool calls. """ self.tool_count = 0 - self.previous_tool_use = None self._verbose_tool_use = verbose_tool_use def __call__(self, **kwargs: Any) -> None: @@ -25,12 +24,12 @@ def __call__(self, **kwargs: Any) -> None: - reasoningText (Optional[str]): Reasoning text to print if provided. - data (str): Text content to stream. - complete (bool): Whether this is the final chunk of a response. - - current_tool_use (dict): Information about the current tool being used. + - event (dict): ModelStreamChunkEvent. """ reasoningText = kwargs.get("reasoningText", False) data = kwargs.get("data", "") complete = kwargs.get("complete", False) - current_tool_use = kwargs.get("current_tool_use", {}) + tool_use = kwargs.get("event", {}).get("contentBlockStart", {}).get("start", {}).get("toolUse") if reasoningText: print(reasoningText, end="") @@ -38,13 +37,11 @@ def __call__(self, **kwargs: Any) -> None: if data: print(data, end="" if not complete else "\n") - if current_tool_use and current_tool_use.get("name"): - if self.previous_tool_use != current_tool_use: - self.previous_tool_use = current_tool_use - self.tool_count += 1 - if self._verbose_tool_use: - tool_name = current_tool_use.get("name", "Unknown tool") - print(f"\nTool #{self.tool_count}: {tool_name}") + if tool_use: + self.tool_count += 1 + if self._verbose_tool_use: + tool_name = tool_use["name"] + print(f"\nTool #{self.tool_count}: {tool_name}") if complete and data: print("\n") diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 30163f207..96c7f577b 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -32,12 +32,18 @@ def log_end(self, event: AfterInvocationEvent) -> None: from .events import ( AfterInvocationEvent, AfterModelCallEvent, + # Multiagent hook events + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, BeforeToolCallEvent, MessageAddedEvent, + MultiAgentInitializedEvent, ) from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry @@ -56,4 +62,9 @@ def log_end(self, event: AfterInvocationEvent) -> None: "HookRegistry", "HookEvent", "BaseHookEvent", + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", ] diff --git a/src/strands/hooks/_type_inference.py b/src/strands/hooks/_type_inference.py new file mode 100644 index 000000000..fbfb34c04 --- /dev/null +++ b/src/strands/hooks/_type_inference.py @@ -0,0 +1,78 @@ +"""Utility for inferring event types from callback type hints.""" + +import inspect +import logging +import types +from typing import TYPE_CHECKING, Union, cast, get_args, get_origin, get_type_hints + +if TYPE_CHECKING: + from .registry import HookCallback, TEvent + +logger = logging.getLogger(__name__) + + +def infer_event_types(callback: "HookCallback[TEvent]") -> "list[type[TEvent]]": + """Infer the event type(s) from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). + + Args: + callback: The callback function to inspect. + + Returns: + A list of event types inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. + """ + # Import here to avoid circular dependency + from .registry import BaseHookEvent + + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") + + # Skip 'self' and 'cls' parameters for methods + first_param = params[0] + if first_param.name in ("self", "cls") and len(params) > 1: + first_param = params[1] + + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(cast("type[TEvent]", arg)) + return event_types + + # Handle single type + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): + return [cast("type[TEvent]", type_hint)] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 05be255f6..80b50770a 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -4,16 +4,23 @@ """ import uuid -from dataclasses import dataclass -from typing import Any, Optional +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any from typing_extensions import override -from ..types.content import Message +if TYPE_CHECKING: + from ..agent.agent_result import AgentResult + +from ..types.agent import AgentInput +from ..types.content import Message, Messages from ..types.interrupt import _Interruptible from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse -from .registry import HookEvent +from .registry import BaseHookEvent, HookEvent + +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase @dataclass @@ -40,9 +47,20 @@ class BeforeInvocationEvent(HookEvent): - Agent.__call__ - Agent.stream_async - Agent.structured_output + + Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. + messages: The input messages for this invocation. Can be modified by hooks + to redact or transform content before processing. """ - pass + invocation_state: dict[str, Any] = field(default_factory=dict) + messages: Messages | None = None + + def _can_write(self, name: str) -> bool: + return name == "messages" @dataclass @@ -60,8 +78,33 @@ class AfterInvocationEvent(HookEvent): - Agent.__call__ - Agent.stream_async - Agent.structured_output + + Resume: + When ``resume`` is set to a non-None value by a hook callback, the agent will + automatically re-invoke itself with the provided input. This enables hooks to + implement autonomous looping patterns where the agent continues processing + based on its previous result. The resume triggers a full new invocation cycle + including ``BeforeInvocationEvent``. + + Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. + result: The result of the agent invocation, if available. + This will be None when invoked from structured_output methods, as those return typed output directly rather + than AgentResult. + resume: When set to a non-None agent input by a hook callback, the agent will + re-invoke itself with this input. The value can be any valid AgentInput + (str, content blocks, messages, etc.). Defaults to None (no resume). """ + invocation_state: dict[str, Any] = field(default_factory=dict) + result: "AgentResult | None" = None + resume: AgentInput = None + + def _can_write(self, name: str) -> bool: + return name == "resume" + @property def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" @@ -106,7 +149,7 @@ class BeforeToolCallEvent(HookEvent, _Interruptible): the tool call and use a default cancel message. """ - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] cancel_tool: bool | str = False @@ -138,6 +181,18 @@ class AfterToolCallEvent(HookEvent): Note: This event uses reverse callback ordering, meaning callbacks registered later will be invoked first during cleanup. + Tool Retrying: + When ``retry`` is set to True by a hook callback, the tool executor will + discard the current tool result and invoke the tool again. This has important + implications for streaming consumers: + + - ToolStreamEvents (intermediate streaming events) from the discarded tool execution + will have already been emitted to callers before the retry occurs. Agent invokers + consuming streamed events should be prepared to handle this scenario, potentially + by tracking retry state or implementing idempotent event processing + - ToolResultEvent is NOT emitted for discarded attempts - only the final attempt's + result is emitted and added to the conversation history + Attributes: selected_tool: The tool that was invoked. It may be None if tool lookup failed. tool_use: The tool parameters that were passed to the tool invoked. @@ -145,17 +200,21 @@ class AfterToolCallEvent(HookEvent): result: The result of the tool invocation. Either a ToolResult on success or an Exception if the tool execution failed. cancel_message: The cancellation message if the user cancelled the tool call. + retry: Whether to retry the tool invocation. Can be set by hook callbacks + to trigger a retry. When True, the current result is discarded and the + tool is called again. Defaults to False. """ - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] result: ToolResult - exception: Optional[Exception] = None + exception: Exception | None = None cancel_message: str | None = None + retry: bool = False def _can_write(self, name: str) -> bool: - return name == "result" + return name in ["result", "retry"] @property def should_reverse_callbacks(self) -> bool: @@ -172,9 +231,19 @@ class BeforeModelCallEvent(HookEvent): that will be sent to the model. Note: This event is not fired for invocations to structured_output. + + Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. + projected_input_tokens: Projected input token count for the upcoming model call. + Computed by the agent loop from message metadata and token estimation. + Available for hooks and plugins (e.g. conversation managers) to make + proactive decisions about context management. None if estimation failed. """ - pass + invocation_state: dict[str, Any] = field(default_factory=dict) + projected_input_tokens: int | None = None @dataclass @@ -190,9 +259,27 @@ class AfterModelCallEvent(HookEvent): Note: This event is not fired for invocations to structured_output. + Model Retrying: + When ``retry_model`` is set to True by a hook callback, the agent will discard + the current model response and invoke the model again. This has important + implications for streaming consumers: + + - Streaming events from the discarded response will have already been emitted + to callers before the retry occurs. Agent invokers consuming streamed events + should be prepared to handle this scenario, potentially by tracking retry state + or implementing idempotent event processing + - The original model message is thrown away internally and not added to the + conversation history + Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. stop_response: The model response data if invocation was successful, None if failed. exception: Exception if the model invocation failed, None if successful. + retry: Whether to retry the model invocation. Can be set by hook callbacks + to trigger a retry. When True, the current response is discarded and the + model is called again. Defaults to False. """ @dataclass @@ -207,8 +294,114 @@ class ModelStopResponse: message: Message stop_reason: StopReason - stop_response: Optional[ModelStopResponse] = None - exception: Optional[Exception] = None + invocation_state: dict[str, Any] = field(default_factory=dict) + stop_response: ModelStopResponse | None = None + exception: Exception | None = None + retry: bool = False + + def _can_write(self, name: str) -> bool: + return name == "retry" + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +# Multiagent hook events start here +@dataclass +class MultiAgentInitializedEvent(BaseHookEvent): + """Event triggered when multi-agent orchestrator initialized. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class BeforeNodeCallEvent(BaseHookEvent, _Interruptible): + """Event triggered before individual node execution starts. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node about to execute + invocation_state: Configuration that user passes in + cancel_node: A user defined message that when set, will cancel the node execution with status FAILED. + The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the + node using a default cancel message. + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + cancel_node: bool | str = False + + def _can_write(self, name: str) -> bool: + return name in ["cancel_node"] + + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + node_id = uuid.uuid5(uuid.NAMESPACE_OID, self.node_id) + call_id = uuid.uuid5(uuid.NAMESPACE_OID, name) + return f"v1:before_node_call:{node_id}:{call_id}" + + +@dataclass +class AfterNodeCallEvent(BaseHookEvent): + """Event triggered after individual node execution completes. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node that just completed execution + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BeforeMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered before orchestrator execution starts. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered after orchestrator execution completes. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None @property def should_reverse_callbacks(self) -> bool: diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 1efc0bf5b..8b284b0c2 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,10 +9,12 @@ import inspect import logging +from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable from ..interrupt import Interrupt, InterruptException +from ._type_inference import infer_event_types if TYPE_CHECKING: from ..agent import Agent @@ -84,6 +86,7 @@ class HookEvent(BaseHookEvent): """Generic for invoking events - non-contravariant to enable returning events.""" +@runtime_checkable class HookProvider(Protocol): """Protocol for objects that provide hook callbacks to an agent. @@ -153,29 +156,98 @@ class HookRegistry: def __init__(self) -> None: """Initialize an empty hook registry.""" - self._registered_callbacks: dict[Type, list[HookCallback]] = {} + self._registered_callbacks: dict[type, list[HookCallback]] = {} - def add_callback(self, event_type: Type[TEvent], callback: HookCallback[TEvent]) -> None: + def add_callback( + self, + event_type: type[TEvent] | list[type[TEvent]] | None, + callback: HookCallback[TEvent], + ) -> None: """Register a callback function for a specific event type. + If ``event_type`` is None, then this will check the callback handler type hint + for the lifecycle event type. Union types (``A | B`` or ``Union[A, B]``) in + type hints will register the callback for each event type in the union. + + If ``event_type`` is a list, the callback will be registered for each event + type in the list (duplicates are ignored). + Args: - event_type: The class type of events this callback should handle. + event_type: The lifecycle event type(s) this callback should handle. + Can be a single type, a list of types, or None to infer from type hints. callback: The callback function to invoke when events of this type occur. + Raises: + ValueError: If event_type is not provided and cannot be inferred from + the callback's type hints, or if AgentInitializedEvent is registered + with an async callback, or if the event_type list is empty. + Example: ```python def my_handler(event: StartRequestEvent): print("Request started") + # With explicit event type registry.add_callback(StartRequestEvent, my_handler) + + # With event type inferred from type hint + registry.add_callback(None, my_handler) + + # With union type hint (registers for both types) + def union_handler(event: BeforeModelCallEvent | AfterModelCallEvent): + print(f"Event: {type(event).__name__}") + registry.add_callback(None, union_handler) + + # With list of event types + def multi_handler(event): + print(f"Event: {type(event).__name__}") + registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], multi_handler) ``` """ - # Related issue: https://github.com/strands-agents/sdk-python/issues/330 - if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): - raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") + resolved_event_types: list[type[TEvent]] + + # Handle list of event types + if isinstance(event_type, list): + if not event_type: + raise ValueError("event_type list cannot be empty") + resolved_event_types = self._validate_event_type_list(event_type) + elif event_type is None: + # Infer event type(s) from callback type hints + resolved_event_types = infer_event_types(callback) + else: + # Single event type provided explicitly + resolved_event_types = [event_type] + + # Deduplicate event types while preserving order + unique_event_types: set[type[TEvent]] = set(resolved_event_types) - callbacks = self._registered_callbacks.setdefault(event_type, []) - callbacks.append(callback) + # Register callback for each event type + for resolved_event_type in unique_event_types: + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 + if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") + + callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) + callbacks.append(callback) + + def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[type[TEvent]]: + """Validate that all types in a list are valid BaseHookEvent subclasses. + + Args: + event_types: List of event types to validate. + + Returns: + The validated list of event types. + + Raises: + ValueError: If any type is not a valid BaseHookEvent subclass. + """ + validated: list[type[TEvent]] = [] + for et in event_types: + if not (isinstance(et, type) and issubclass(et, BaseHookEvent)): + raise ValueError(f"Invalid event type: {et} | must be a subclass of BaseHookEvent") + validated.append(et) + return validated def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. diff --git a/src/strands/hooks/rules.md b/src/strands/hooks/rules.md deleted file mode 100644 index 4d0f571c6..000000000 --- a/src/strands/hooks/rules.md +++ /dev/null @@ -1,21 +0,0 @@ -# Hook System Rules - -## Terminology - -- **Paired events**: Events that denote the beginning and end of an operation -- **Hook callback**: A function that receives a strongly-typed event argument and performs some action in response - -## Naming Conventions - -- All hook events have a suffix of `Event` -- Paired events follow the naming convention of `Before{Item}Event` and `After{Item}Event` -- Pre actions in the name. i.e. prefer `BeforeToolCallEvent` over `BeforeToolEvent`. - -## Paired Events - -- The final event in a pair returns `True` for `should_reverse_callbacks` -- For every `Before` event there is a corresponding `After` event, even if an exception occurs - -## Writable Properties - -For events with writable properties, those values are re-read after invoking the hook callbacks and used in subsequent processing. For example, `BeforeToolEvent.selected_tool` is writable - after invoking the callback for `BeforeToolEvent`, the `selected_tool` takes effect for the tool call. \ No newline at end of file diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py index 85997c9be..7d02b50ff 100644 --- a/src/strands/interrupt.py +++ b/src/strands/interrupt.py @@ -52,10 +52,12 @@ class _InterruptState: interrupts: dict[str, Interrupt] = field(default_factory=dict) context: dict[str, Any] = field(default_factory=dict) activated: bool = False + _version: int = field(default=0, compare=False, repr=False) def activate(self) -> None: """Activate the interrupt state.""" self.activated = True + self._version += 1 def deactivate(self) -> None: """Deacitvate the interrupt state. @@ -65,6 +67,7 @@ def deactivate(self) -> None: self.interrupts = {} self.context = {} self.activated = False + self._version += 1 def resume(self, prompt: "AgentInput") -> None: """Configure the interrupt state if resuming from an interrupt event. @@ -100,10 +103,27 @@ def resume(self, prompt: "AgentInput") -> None: self.interrupts[interrupt_id].response = interrupt_response self.context["responses"] = contents + self._version += 1 + + def _get_version(self) -> int: + """Get the current version number of the interrupt state. + + The version is incremented each time activate(), deactivate(), or resume() is called. + Consumers can compare versions to detect changes without requiring + explicit dirty flag clearing. + + Returns: + The current version number. + """ + return self._version def to_dict(self) -> dict[str, Any]: """Serialize to dict for session management.""" - return asdict(self) + return { + "interrupts": {k: v.to_dict() for k, v in self.interrupts.items()}, + "context": self.context, + "activated": self.activated, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "_InterruptState": diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index ead290a35..8ae660da0 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -3,8 +3,70 @@ This package includes an abstract base Model class along with concrete implementations for specific providers. """ +from typing import Any + from . import bedrock, model from .bedrock import BedrockModel -from .model import Model +from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model + +__all__ = [ + "bedrock", + "model", + "BaseModelConfig", + "BedrockModel", + "CacheConfig", + "CacheToolsConfig", + "Model", +] + + +def __getattr__(name: str) -> Any: + """Lazy load model implementations only when accessed. + + This defers the import of optional dependencies until actually needed. + """ + if name == "AnthropicModel": + from .anthropic import AnthropicModel + + return AnthropicModel + if name == "GeminiModel": + from .gemini import GeminiModel + + return GeminiModel + if name == "LiteLLMModel": + from .litellm import LiteLLMModel + + return LiteLLMModel + if name == "LlamaAPIModel": + from .llamaapi import LlamaAPIModel + + return LlamaAPIModel + if name == "LlamaCppModel": + from .llamacpp import LlamaCppModel + + return LlamaCppModel + if name == "MistralModel": + from .mistral import MistralModel + + return MistralModel + if name == "OllamaModel": + from .ollama import OllamaModel + + return OllamaModel + if name == "OpenAIModel": + from .openai import OpenAIModel + + return OpenAIModel + if name == "OpenAIResponsesModel": + from .openai_responses import OpenAIResponsesModel + + return OpenAIResponsesModel + if name == "SageMakerAIModel": + from .sagemaker import SageMakerAIModel + + return SageMakerAIModel + if name == "WriterModel": + from .writer import WriterModel -__all__ = ["bedrock", "model", "BedrockModel", "Model"] + return WriterModel + raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})") diff --git a/src/strands/models/_defaults.py b/src/strands/models/_defaults.py new file mode 100644 index 000000000..e463b8ef6 --- /dev/null +++ b/src/strands/models/_defaults.py @@ -0,0 +1,177 @@ +"""Default model metadata lookup tables. + +Provides context window limits for known model IDs across all providers. +Values sourced from provider documentation and +https://github.com/BerriAI/litellm/blob/litellm_internal_staging/model_prices_and_context_window.json + +Applied to providers with well-known, fixed model IDs: Bedrock, Anthropic, OpenAI, +OpenAI Responses, Gemini, and Mistral. Providers that use local/custom model IDs +(Ollama, LlamaCpp, SageMaker) or proxy to other providers with their own prefixed +ID format (LiteLLM) are excluded — their context windows depend on deployment config, +not a static table. +""" + +import logging +from collections.abc import Mapping +from typing import TypeVar + +logger = logging.getLogger(__name__) + +_C = TypeVar("_C", bound=Mapping[str, object]) + +# Context window limits (in tokens) for known model IDs. +# +# Best-effort lookup table — unknown models return None and callers +# fall back gracefully (e.g. proactive compression is disabled). +# Users can always override with an explicit context_window_limit in their model config. +# +# For Bedrock models with cross-region prefixes (e.g. us., eu., global.), +# get_context_window_limit strips the prefix before lookup so only the base model ID is needed here. +_CONTEXT_WINDOW_LIMITS: dict[str, int] = { + # Anthropic (direct API) + "claude-sonnet-4-6": 1_000_000, + "claude-sonnet-4-20250514": 1_000_000, + "claude-sonnet-4-5": 200_000, + "claude-sonnet-4-5-20250929": 200_000, + "claude-opus-4-6": 1_000_000, + "claude-opus-4-6-20260205": 1_000_000, + "claude-opus-4-7": 1_000_000, + "claude-opus-4-7-20260416": 1_000_000, + "claude-opus-4-5": 200_000, + "claude-opus-4-5-20251101": 200_000, + "claude-opus-4-20250514": 200_000, + "claude-opus-4-1": 200_000, + "claude-opus-4-1-20250805": 200_000, + "claude-haiku-4-5": 200_000, + "claude-haiku-4-5-20251001": 200_000, + "claude-3-7-sonnet-20250219": 200_000, + "claude-3-5-sonnet-20241022": 200_000, + "claude-3-5-sonnet-20240620": 200_000, + "claude-3-5-haiku-20241022": 200_000, + "claude-3-opus-20240229": 200_000, + "claude-3-haiku-20240307": 200_000, + # Bedrock Anthropic (base model IDs — cross-region prefixes stripped by get_context_window_limit) + "anthropic.claude-sonnet-4-6": 1_000_000, + "anthropic.claude-sonnet-4-20250514-v1:0": 1_000_000, + "anthropic.claude-sonnet-4-5-20250929-v1:0": 200_000, + "anthropic.claude-opus-4-6-v1": 1_000_000, + "anthropic.claude-opus-4-7": 1_000_000, + "anthropic.claude-opus-4-5-20251101-v1:0": 200_000, + "anthropic.claude-opus-4-20250514-v1:0": 200_000, + "anthropic.claude-opus-4-1-20250805-v1:0": 200_000, + "anthropic.claude-haiku-4-5-20251001-v1:0": 200_000, + "anthropic.claude-haiku-4-5@20251001": 200_000, + "anthropic.claude-3-7-sonnet-20250219-v1:0": 200_000, + "anthropic.claude-3-7-sonnet-20240620-v1:0": 200_000, + "anthropic.claude-3-5-sonnet-20241022-v2:0": 200_000, + "anthropic.claude-3-5-sonnet-20240620-v1:0": 200_000, + "anthropic.claude-3-5-haiku-20241022-v1:0": 200_000, + "anthropic.claude-3-opus-20240229-v1:0": 200_000, + "anthropic.claude-3-haiku-20240307-v1:0": 200_000, + "anthropic.claude-3-sonnet-20240229-v1:0": 200_000, + "anthropic.claude-mythos-preview": 1_000_000, + # Bedrock Amazon Nova + "amazon.nova-pro-v1:0": 300_000, + "amazon.nova-lite-v1:0": 300_000, + "amazon.nova-micro-v1:0": 128_000, + "amazon.nova-premier-v1:0": 1_000_000, + "amazon.nova-2-lite-v1:0": 1_000_000, + "amazon.nova-2-pro-preview-20251202-v1:0": 1_000_000, + # OpenAI + "gpt-5.5": 1_050_000, + "gpt-5.5-pro": 1_050_000, + "gpt-5.4": 1_050_000, + "gpt-5.4-pro": 1_050_000, + "gpt-5.4-mini": 272_000, + "gpt-5.4-nano": 272_000, + "gpt-5.2": 272_000, + "gpt-5.2-pro": 272_000, + "gpt-5.1": 272_000, + "gpt-5": 272_000, + "gpt-5-mini": 272_000, + "gpt-5-nano": 272_000, + "gpt-5-pro": 128_000, + "gpt-4.1": 1_047_576, + "gpt-4.1-mini": 1_047_576, + "gpt-4.1-nano": 1_047_576, + "gpt-4o": 128_000, + "gpt-4o-mini": 128_000, + "gpt-4-turbo": 128_000, + "o3": 200_000, + "o3-mini": 200_000, + "o3-pro": 200_000, + "o4-mini": 200_000, + "o1": 200_000, + # Google Gemini + "gemini-2.5-flash": 1_048_576, + "gemini-2.5-flash-lite": 1_048_576, + "gemini-2.5-pro": 1_048_576, + "gemini-2.0-flash": 1_048_576, + "gemini-2.0-flash-lite": 1_048_576, + "gemini-3-pro-preview": 1_048_576, + "gemini-3-flash-preview": 1_048_576, + "gemini-3.1-pro-preview": 1_048_576, + "gemini-3.1-flash-lite-preview": 1_048_576, + # Mistral + "mistral-large-latest": 262_144, + "mistral-large-2512": 262_144, + "mistral-large-3": 262_144, + "mistral-medium-latest": 131_072, + "mistral-medium-2505": 131_072, + "mistral-small-latest": 131_072, + "mistral-small-3-2-2506": 131_072, +} + + +def get_context_window_limit(model_id: str) -> int | None: + """Look up the context window limit for a model ID. + + For Bedrock cross-region model IDs (e.g. ``us.anthropic.claude-sonnet-4-6``), + the region prefix is stripped as a fallback if the direct lookup fails. + + Args: + model_id: The model ID to look up. + + Returns: + The context window limit in tokens, or None if not found. + """ + direct = _CONTEXT_WINDOW_LIMITS.get(model_id) + if direct is not None: + return direct + + # Fallback: strip prefix before first dot and retry (handles cross-region prefixes) + dot_index = model_id.find(".") + if dot_index != -1: + stripped = model_id[dot_index + 1 :] + result = _CONTEXT_WINDOW_LIMITS.get(stripped) + if result is not None: + logger.debug( + "model_id=<%s>, stripped_id=<%s> | resolved context window limit via prefix strip", model_id, stripped + ) + return result + + return None + + +def resolve_config_metadata(config: _C, model_id: str) -> _C: + """Resolve model metadata fields on a config dict from built-in lookup tables. + + When ``context_window_limit`` is not explicitly set, looks it up from the built-in table. + Explicit values pass through unchanged. Returns a new dict only when resolution adds a field; + otherwise returns the original config to avoid unnecessary allocation. + + Args: + config: The stored model config dict. + model_id: The model ID to look up. + + Returns: + The config with resolved metadata, or the original config if nothing to resolve. + """ + if "context_window_limit" in config: + return config + + limit = get_context_window_limit(model_id) + if limit is None: + return config + + return {**config, "context_window_limit": limit} # type: ignore[return-value] diff --git a/src/strands/models/_openai_bedrock.py b/src/strands/models/_openai_bedrock.py new file mode 100644 index 000000000..149a47ec5 --- /dev/null +++ b/src/strands/models/_openai_bedrock.py @@ -0,0 +1,126 @@ +"""Internal helpers for routing OpenAI-compatible clients to Bedrock Mantle. + +Converts a ``bedrock_mantle_config`` dict into the ``base_url`` and ``api_key`` that the +OpenAI Python SDK consumes. Tokens are minted on demand via +``aws_bedrock_token_generator.provide_token`` so long-running agents survive the +bearer token's maximum lifetime. + +``aws_bedrock_token_generator`` is part of the ``openai`` extras group +(``pip install strands-agents[openai]``) but is *not* included in the ``litellm`` +or ``sagemaker`` extras, which also pull in the ``openai`` package. The import is +therefore lazy — it happens inside :func:`resolve_bedrock_client_args` so that +those other extras never trigger an ``ImportError`` at module load. +""" + +from __future__ import annotations + +from datetime import timedelta +from typing import Any, TypedDict + +import boto3 +from botocore.credentials import CredentialProvider + +_MANTLE_BASE_URL_TEMPLATE = "https://bedrock-mantle.{region}.api.aws/v1" +_MANTLE_DOCS_URL = "https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html" + + +class BedrockMantleConfig(TypedDict, total=False): + """Config for routing an OpenAI-compatible client through Bedrock Mantle. + + Attributes: + region: AWS region hosting the Bedrock Mantle endpoint. If omitted, resolved + from ``boto_session`` (if provided) or the standard boto3 chain + (``AWS_REGION`` / ``AWS_DEFAULT_REGION`` / active profile / EC2 metadata). + A :class:`ValueError` is raised if none resolve. + boto_session: Optional :class:`boto3.Session` used to resolve the region when + ``region`` is not provided. Useful for picking up a non-default profile + without exporting env vars. + credentials_provider: Optional botocore :class:`~botocore.credentials.CredentialProvider` + forwarded to ``provide_token``. Omit to let the token generator use the + standard AWS credential chain. + expiry: Optional ``timedelta`` for the bearer token's lifetime, forwarded to + ``provide_token``. Defaults to the generator's built-in lifetime when + omitted. + """ + + region: str + boto_session: boto3.Session + credentials_provider: CredentialProvider + expiry: timedelta + + +def _resolve_region(config: BedrockMantleConfig) -> str: + """Resolve the AWS region, preferring explicit config then falling back to boto3. + + Raises: + ValueError: If no region can be resolved from the config, an attached session, + or the standard boto3 credential chain. + """ + region = config.get("region") + if region: + return region + + session = config.get("boto_session") + if session is not None and session.region_name: + return str(session.region_name) + + # ``boto3.Session()`` with no args reads ``AWS_REGION`` / ``AWS_DEFAULT_REGION``, + # the active profile, and falls back to EC2 instance metadata — the same chain + # :class:`BedrockModel` uses. + default_region = boto3.Session().region_name + if default_region: + return str(default_region) + + raise ValueError( + "Could not resolve an AWS region for Bedrock Mantle. Pass 'region' in " + "bedrock_mantle_config, attach a boto_session with a configured region, or set " + f"AWS_REGION in the environment. See {_MANTLE_DOCS_URL} for supported regions." + ) + + +def resolve_bedrock_client_args( + config: BedrockMantleConfig, client_args: dict[str, Any] | None = None +) -> dict[str, Any]: + """Resolve a ``BedrockMantleConfig`` (plus optional ``client_args``) into OpenAI client kwargs. + + Mints a fresh bearer token on every call. Callers are expected to validate that + ``client_args`` does not contain ``base_url`` or ``api_key`` before calling this + function (typically at ``__init__`` time for fail-fast behavior). + + Raises: + ValueError: If no region can be resolved. + ImportError: If ``aws-bedrock-token-generator`` is not installed. + RuntimeError: If token minting fails (e.g. missing AWS credentials). + """ + region = _resolve_region(config) + + # ``aws-bedrock-token-generator`` is included in the ``openai`` extras group but not in + # ``litellm`` or ``sagemaker`` (which also depend on the ``openai`` package). The lazy + # import keeps those extras from hitting an ImportError at module load. + try: + from aws_bedrock_token_generator import provide_token + except ImportError as e: + raise ImportError( + "bedrock_mantle_config requires the 'aws-bedrock-token-generator' package. " + "Install it with: pip install strands-agents[openai]" + ) from e + + # Only forward kwargs the user set; provide_token rejects expiry=None. + token_kwargs: dict[str, Any] = {"region": region} + if "credentials_provider" in config: + token_kwargs["aws_credentials_provider"] = config["credentials_provider"] + if "expiry" in config: + token_kwargs["expiry"] = config["expiry"] + + try: + token = provide_token(**token_kwargs) + except Exception as e: + raise RuntimeError( + f"Failed to mint Bedrock Mantle bearer token for region '{region}'. " + "Verify your AWS credentials and network connectivity." + ) from e + + resolved: dict[str, Any] = dict(client_args or {}) + resolved["base_url"] = _MANTLE_BASE_URL_TEMPLATE.format(region=region) + resolved["api_key"] = token + return resolved diff --git a/src/strands/models/_strict_schema.py b/src/strands/models/_strict_schema.py new file mode 100644 index 000000000..e7f13e244 --- /dev/null +++ b/src/strands/models/_strict_schema.py @@ -0,0 +1,144 @@ +"""Strict JSON schema transformation for tool definitions. + +When model providers require `strict: true` on tool definitions, they also require +`"additionalProperties": false` on every `object` type in the input schema. This module +provides a utility to recursively apply that constraint. + +Modeled after OpenAI's `_ensure_strict_json_schema`: +https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py +""" + +import copy +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def ensure_strict_json_schema( + schema: dict[str, Any], + *, + require_all_properties: bool = False, +) -> dict[str, Any]: + """Ensure a JSON schema conforms to strict tool use requirements. + + Creates a deep copy of the schema and recursively: + 1. Adds ``"additionalProperties": false`` to all ``object`` types that do not already define it + 2. Optionally adds all properties to the ``required`` array (needed for OpenAI) + 3. Handles ``$defs``, ``definitions``, ``anyOf``, ``allOf``, ``items``, and ``$ref`` + + Args: + schema: The JSON schema to process. A deep copy is made internally so the original is not mutated. + require_all_properties: If True, set ``required`` to include all property keys. OpenAI strict mode + requires this; Bedrock and Anthropic do not. + + Returns: + A new schema dict with strict-mode constraints applied. + """ + schema_copy = copy.deepcopy(schema) + _apply_strict(schema_copy, root=schema_copy, require_all_properties=require_all_properties) + return schema_copy + + +def _apply_strict( + schema: dict[str, Any], + *, + root: dict[str, Any], + require_all_properties: bool, +) -> None: + """Recursively apply strict-mode constraints to a JSON schema in place. + + Args: + schema: The schema node to process (modified in place). + root: The root schema, used for resolving ``$ref`` pointers. + require_all_properties: If True, add all properties to ``required``. + """ + # Process $defs / definitions blocks + for defs_key in ("$defs", "definitions"): + defs = schema.get(defs_key) + if isinstance(defs, dict): + for def_schema in defs.values(): + if isinstance(def_schema, dict): + _apply_strict(def_schema, root=root, require_all_properties=require_all_properties) + + # Add additionalProperties: false to object types that lack it + if schema.get("type") == "object" and "additionalProperties" not in schema: + schema["additionalProperties"] = False + + # Process properties and optionally enforce required + properties = schema.get("properties") + if isinstance(properties, dict): + if require_all_properties: + schema["required"] = list(properties.keys()) + + for prop_schema in properties.values(): + if isinstance(prop_schema, dict): + _apply_strict(prop_schema, root=root, require_all_properties=require_all_properties) + + # Process array items + items = schema.get("items") + if isinstance(items, dict): + _apply_strict(items, root=root, require_all_properties=require_all_properties) + + # Process anyOf variants + any_of = schema.get("anyOf") + if isinstance(any_of, list): + for variant in any_of: + if isinstance(variant, dict): + _apply_strict(variant, root=root, require_all_properties=require_all_properties) + + # Process allOf variants + all_of = schema.get("allOf") + if isinstance(all_of, list): + for entry in all_of: + if isinstance(entry, dict): + _apply_strict(entry, root=root, require_all_properties=require_all_properties) + + # Process oneOf variants + one_of = schema.get("oneOf") + if isinstance(one_of, list): + for variant in one_of: + if isinstance(variant, dict): + _apply_strict(variant, root=root, require_all_properties=require_all_properties) + + # Resolve $ref combined with other keys by inlining the referenced schema + ref = schema.get("$ref") + if isinstance(ref, str) and len(schema) > 1: + resolved = _resolve_ref(root, ref) + if isinstance(resolved, dict): + # Inline the resolved schema, giving priority to existing keys + merged = {**copy.deepcopy(resolved), **schema} + merged.pop("$ref", None) + schema.clear() + schema.update(merged) + # Re-apply strict to the inlined schema + _apply_strict(schema, root=root, require_all_properties=require_all_properties) + + +def _resolve_ref(root: dict[str, Any], ref: str) -> dict[str, Any] | None: + """Resolve a JSON Schema ``$ref`` pointer against the root schema. + + Args: + root: The root schema containing definitions. + ref: A JSON pointer string (e.g., ``#/$defs/MyModel``). + + Returns: + The resolved schema dict, or None if resolution fails. + """ + if not ref.startswith("#/"): + logger.warning("ref=<%s> | unexpected $ref format, skipping resolution", ref) + return None + + path = ref[2:].split("/") + current: Any = root + for key in path: + if not isinstance(current, dict) or key not in current: + logger.warning("ref=<%s> | failed to resolve $ref path", ref) + return None + current = current[key] + + if not isinstance(current, dict): + logger.warning("ref=<%s> | resolved to non-dict value", ref) + return None + + return current diff --git a/src/strands/models/_validation.py b/src/strands/models/_validation.py index 9eabe28a1..9d4d8b178 100644 --- a/src/strands/models/_validation.py +++ b/src/strands/models/_validation.py @@ -1,14 +1,16 @@ """Configuration validation utilities for model providers.""" import warnings -from typing import Any, Mapping, Type +from collections.abc import Mapping +from typing import Any from typing_extensions import get_type_hints +from ..types.content import ContentBlock from ..types.tools import ToolChoice -def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None: +def validate_config_keys(config_dict: Mapping[str, Any], config_class: type) -> None: """Validate that config keys match the TypedDict fields. Args: @@ -40,3 +42,23 @@ def warn_on_tool_choice_not_supported(tool_choice: ToolChoice | None) -> None: "A ToolChoice was provided to this provider but is not supported and will be ignored", stacklevel=4, ) + + +def _has_location_source(content: ContentBlock) -> bool: + """Check if a content block contains a location source. + + Providers need to explicitly define an implementation to support content locations. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an location source, False otherwise. + """ + if "image" in content: + return "location" in content["image"].get("source", {}) + if "document" in content: + return "location" in content["document"].get("source", {}) + if "video" in content: + return "location" in content["video"].get("source", {}) + return False diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 68b234729..812171a0c 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import anthropic from pydantic import BaseModel @@ -15,12 +16,13 @@ from ..event_loop.streaming import process_stream from ..tools.structured_output.structured_output_utils import convert_pydantic_to_tool_spec -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec -from ._validation import validate_config_keys -from .model import Model +from ._defaults import resolve_config_metadata +from ._validation import _has_location_source, validate_config_keys +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -45,7 +47,7 @@ class AnthropicModel(Model): "input and output tokens exceed your context limit", } - class AnthropicConfig(TypedDict, total=False): + class AnthropicConfig(BaseModelConfig, total=False): """Configuration options for Anthropic models. Attributes: @@ -55,13 +57,17 @@ class AnthropicConfig(TypedDict, total=False): https://docs.anthropic.com/en/docs/about-claude/models/all-models. params: Additional model parameters (e.g., temperature). For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages. + use_native_token_count: Whether to use the native Anthropic count_tokens API. + When True, count_tokens() calls the Anthropic API for accurate counts. + When False (default), skips the API call and uses the local estimator. """ max_tokens: Required[int] model_id: Required[str] - params: Optional[dict[str, Any]] + params: dict[str, Any] | None + use_native_token_count: bool - def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[AnthropicConfig]): + def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]): """Initialize provider instance. Args: @@ -94,7 +100,7 @@ def get_config(self) -> AnthropicConfig: Returns: The Anthropic model configuration. """ - return self.config + return resolve_config_metadata(self.config, self.config["model_id"]) def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: """Format an Anthropic content block. @@ -188,6 +194,11 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: formatted_contents[-1]["cache_control"] = {"type": "ephemeral"} continue + # Check for location sources in image, document, or video content + if _has_location_source(content): + logger.warning("Location sources are not supported by Anthropic | skipping content block") + continue + formatted_contents.append(self._format_request_message_content(content)) if formatted_contents: @@ -198,8 +209,8 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an Anthropic streaming request. @@ -365,12 +376,63 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"event_type=<{event['type']} | unknown type") + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using Anthropic's native count_tokens API. + + Uses the same message format as the Messages API to get accurate token counts + directly from the Anthropic service. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. + """ + if self.config.get("use_native_token_count") is not True: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + + try: + # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, + # matching the behavior of stream(). The caller always provides system_prompt alongside + # system_prompt_content, so the plain string is always available. + request = self.format_request(messages, tool_specs, system_prompt) + # Keep only fields accepted by count_tokens; strip inference params (max_tokens, temperature, etc.) + count_tokens_fields = {"model", "messages", "tools", "tool_choice", "system"} + request = {k: request[k] for k in request.keys() & count_tokens_fields} + + response = await self.client.messages.count_tokens(**request) + total_tokens: int = response.input_tokens + + logger.debug( + "model_id=<%s>, total_tokens=<%d> | native token count", + self.config["model_id"], + total_tokens, + ) + return total_tokens + except Exception as e: + logger.debug( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + self.config["model_id"], + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + @override async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -401,10 +463,24 @@ async def stream( logger.debug("got response from model") async for event in stream: if event.type in AnthropicModel.EVENT_TYPES: - yield self.format_chunk(event.model_dump()) + if event.type == "message_stop": + # Build dict directly to avoid Pydantic serialization warnings + # when the message contains ParsedTextBlock objects (issue #1746) + yield self.format_chunk( + { + "type": "message_stop", + "message": {"stop_reason": event.message.stop_reason}, + } + ) + else: + yield self.format_chunk(event.model_dump()) - usage = event.message.usage # type: ignore - yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()}) + try: + message_snapshot = await stream.get_final_message() + except AssertionError as e: + logger.warning("error=<%s> | failed to retrieve message snapshot, usage metadata unavailable", e) + else: + yield self.format_chunk({"type": "metadata", "usage": message_snapshot.usage.model_dump()}) except anthropic.RateLimitError as error: raise ModelThrottledException(str(error)) from error @@ -419,8 +495,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a7c81672..4cd6f7fbc 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -8,13 +8,16 @@ import logging import os import warnings -from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, ValuesView, cast +from collections.abc import AsyncGenerator, Callable, Iterable, ValuesView +from typing import Any, Literal, TypeVar, cast import boto3 from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError from pydantic import BaseModel -from typing_extensions import TypedDict, Unpack, override +from typing_extensions import Unpack, override + +from strands.types.media import S3Location, SourceLocation from .._exception_notes import add_exception_note from ..event_loop import streaming @@ -24,23 +27,27 @@ from ..types.exceptions import ( ContextWindowOverflowException, ModelThrottledException, + ProviderTokenCountError, ) from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolChoice, ToolSpec +from ._defaults import resolve_config_metadata +from ._strict_schema import ensure_strict_json_schema from ._validation import validate_config_keys -from .model import Model +from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model logger = logging.getLogger(__name__) # See: `BedrockModel._get_default_model_with_warning` for why we need both -DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" -_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0" +DEFAULT_BEDROCK_MODEL_ID = "global.anthropic.claude-sonnet-4-6" +_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-6" DEFAULT_BEDROCK_REGION = "us-west-2" BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ "Input is too long for requested model", "input length and `max_tokens` exceed context limit", "too many total text bytes", + "prompt is too long", ] # Models that should include tool result status (include_tool_result_status = True) @@ -48,6 +55,15 @@ "anthropic.claude", ] +# Cache of model IDs for which CountTokens API calls should be skipped. +_SKIP_COUNT_TOKENS_MODELS: set[str] = set() + + +def _clear_skip_count_tokens_cache() -> None: + """Clear the cache of model IDs for which CountTokens API calls should be skipped.""" + _SKIP_COUNT_TOKENS_MODELS.clear() + + T = TypeVar("T", bound=BaseModel) DEFAULT_READ_TIMEOUT = 120 @@ -65,15 +81,17 @@ class BedrockModel(Model): - Context window overflow detection """ - class BedrockConfig(TypedDict, total=False): + class BedrockConfig(BaseModelConfig, total=False): """Configuration options for Bedrock models. Attributes: additional_args: Any additional arguments to include in the request additional_request_fields: Additional fields to include in the Bedrock request additional_response_field_paths: Additional response field paths to extract - cache_prompt: Cache point type for the system prompt - cache_tools: Cache point type for tools + cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) + cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. + cache_tools: Cache point type for tools. Pass a string (e.g. "default") for the default 5m TTL, + or a CacheToolsConfig instance to set both type and TTL (e.g. "1h"). guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -82,44 +100,62 @@ class BedrockConfig(TypedDict, total=False): guardrail_redact_input_message: If a Bedrock Input guardrail triggers, replace the input with this message. guardrail_redact_output: Flag to redact output if guardrail is triggered. Defaults to False. guardrail_redact_output_message: If a Bedrock Output guardrail triggers, replace output with this message. + guardrail_latest_message: Flag to send only the lastest user message to guardrails. + Defaults to False. max_tokens: Maximum number of tokens to generate in the response - model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0") + model_id: The Bedrock model ID (e.g., "global.anthropic.claude-sonnet-4-6") include_tool_result_status: Flag to include status field in tool results. True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto". + service_tier: Service tier for the request, controlling the trade-off between latency and cost. + Valid values: "default" (standard), "priority" (faster, premium), "flex" (cheaper, slower). + Please check https://docs.aws.amazon.com/bedrock/latest/userguide/service-tiers-inference.html for + supported service tiers, models, and regions stop_sequences: List of sequences that will stop generation when encountered streaming: Flag to enable/disable streaming. Defaults to True. + strict_tools: Flag to enable structured output enforcement on tool definitions. + When True, adds strict: true to each tool spec and automatically injects + "additionalProperties": false into all object types in tool input schemas. + See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) + use_native_token_count: Whether to use the native Bedrock CountTokens API. + When True, count_tokens() calls the Bedrock API for accurate counts. + When False (default), skips the API call and uses the local estimator. """ - additional_args: Optional[dict[str, Any]] - additional_request_fields: Optional[dict[str, Any]] - additional_response_field_paths: Optional[list[str]] - cache_prompt: Optional[str] - cache_tools: Optional[str] - guardrail_id: Optional[str] - guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]] - guardrail_stream_processing_mode: Optional[Literal["sync", "async"]] - guardrail_version: Optional[str] - guardrail_redact_input: Optional[bool] - guardrail_redact_input_message: Optional[str] - guardrail_redact_output: Optional[bool] - guardrail_redact_output_message: Optional[str] - max_tokens: Optional[int] + additional_args: dict[str, Any] | None + additional_request_fields: dict[str, Any] | None + additional_response_field_paths: list[str] | None + cache_prompt: str | None + cache_config: CacheConfig | None + cache_tools: str | CacheToolsConfig | None + guardrail_id: str | None + guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None + guardrail_stream_processing_mode: Literal["sync", "async"] | None + guardrail_version: str | None + guardrail_redact_input: bool | None + guardrail_redact_input_message: str | None + guardrail_redact_output: bool | None + guardrail_redact_output_message: str | None + guardrail_latest_message: bool | None + max_tokens: int | None model_id: str - include_tool_result_status: Optional[Literal["auto"] | bool] - stop_sequences: Optional[list[str]] - streaming: Optional[bool] - temperature: Optional[float] - top_p: Optional[float] + include_tool_result_status: Literal["auto"] | bool | None + service_tier: str | None + stop_sequences: list[str] | None + streaming: bool | None + strict_tools: bool | None + temperature: float | None + top_p: float | None + use_native_token_count: bool def __init__( self, *, - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, - region_name: Optional[str] = None, - endpoint_url: Optional[str] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, + endpoint_url: str | None = None, **model_config: Unpack[BedrockConfig], ): """Initialize provider instance. @@ -168,6 +204,17 @@ def __init__( logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name) + @property + def _cache_strategy(self) -> str | None: + """The cache strategy for this model based on its model ID. + + Returns the appropriate cache strategy name, or None if automatic caching is not supported for this model. + """ + model_id = self.config.get("model_id", "").lower() + if "claude" in model_id or "anthropic" in model_id: + return "anthropic" + return None + @override def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore """Update the Bedrock Model configuration with the provided arguments. @@ -185,13 +232,13 @@ def get_config(self) -> BedrockConfig: Returns: The Bedrock model configuration. """ - return self.config + return resolve_config_metadata(self.config, self.config.get("model_id", "")) def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format a Bedrock converse stream request. @@ -199,7 +246,6 @@ def _format_request( Args: messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. system_prompt_content: System prompt content blocks to provide context to the model. @@ -215,6 +261,7 @@ def _format_request( # Use system_prompt_content directly (copy for mutability) system_blocks: list[SystemContentBlock] = system_prompt_content.copy() if system_prompt_content else [] + # Add cache point if configured (backwards compatibility) if cache_prompt := self.config.get("cache_prompt"): warnings.warn( @@ -226,6 +273,7 @@ def _format_request( "modelId": self.config["model_id"], "messages": self._format_bedrock_messages(messages), "system": system_blocks, + **({"serviceTier": {"type": self.config["service_tier"]}} if self.config.get("service_tier") else {}), **( { "toolConfig": { @@ -235,16 +283,17 @@ def _format_request( "toolSpec": { "name": tool_spec["name"], "description": tool_spec["description"], - "inputSchema": tool_spec["inputSchema"], + "inputSchema": ( + {"json": ensure_strict_json_schema(tool_spec["inputSchema"]["json"])} + if self.config.get("strict_tools") + else tool_spec["inputSchema"] + ), + **({"strict": True} if self.config.get("strict_tools") else {}), } } for tool_spec in tool_specs ], - *( - [{"cachePoint": {"type": self.config["cache_tools"]}}] - if self.config.get("cache_tools") - else [] - ), + *self._build_tools_cache_point(), ], **({"toolChoice": tool_choice if tool_choice else {"auto": {}}}), } @@ -252,11 +301,7 @@ def _format_request( if tool_specs else {} ), - **( - {"additionalModelRequestFields": self.config["additional_request_fields"]} - if self.config.get("additional_request_fields") - else {} - ), + **(self._get_additional_request_fields(tool_choice)), **( {"additionalModelResponseFieldPaths": self.config["additional_response_field_paths"]} if self.config.get("additional_response_field_paths") @@ -295,6 +340,101 @@ def _format_request( ), } + def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict[str, Any]: + """Get additional request fields, removing thinking if tool_choice forces tool use. + + Bedrock's API does not allow thinking mode when tool_choice forces tool use. + When forcing a tool (e.g., for structured_output retry), we temporarily disable thinking. + + Args: + tool_choice: The tool choice configuration. + + Returns: + A dict containing additionalModelRequestFields if configured, or empty dict. + """ + additional_fields = self.config.get("additional_request_fields") + if not additional_fields: + return {} + + # Check if tool_choice is forcing tool use ("any" or specific "tool") + is_forcing_tool = tool_choice is not None and ("any" in tool_choice or "tool" in tool_choice) + + if is_forcing_tool and "thinking" in additional_fields: + # Create a copy without the thinking key + fields_without_thinking = {k: v for k, v in additional_fields.items() if k != "thinking"} + if fields_without_thinking: + return {"additionalModelRequestFields": fields_without_thinking} + return {} + + return {"additionalModelRequestFields": additional_fields} + + def _build_tools_cache_point(self) -> list[dict[str, Any]]: + """Build the cache point block appended to ``toolConfig.tools`` if ``cache_tools`` is configured. + + Returns: + A single-element list containing the cache point block, or an empty list if no cache_tools is set. + """ + cache_tools = self.config.get("cache_tools") + if not cache_tools: + return [] + + if isinstance(cache_tools, CacheToolsConfig): + cache_point: dict[str, Any] = {"type": cache_tools.type} + if cache_tools.ttl: + cache_point["ttl"] = cache_tools.ttl + else: + cache_point = {"type": cache_tools} + + return [{"cachePoint": cache_point}] + + def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: + """Inject a cache point at the end of the last user message. + + Args: + messages: List of messages to inject cache point into (modified in place). + """ + if not messages: + return + + last_user_idx: int | None = None + for msg_idx, msg in enumerate(messages): + content = msg.get("content", []) + for block_idx, block in reversed(list(enumerate(content))): + if "cachePoint" in block: + del content[block_idx] + logger.warning( + "msg_idx=<%s>, block_idx=<%s> | stripped existing cache point (auto mode manages cache points)", + msg_idx, + block_idx, + ) + if msg.get("role") == "user": + last_user_idx = msg_idx + + if last_user_idx is not None and messages[last_user_idx].get("content"): + cache_point: dict[str, Any] = {"type": "default"} + cache_config = self.config.get("cache_config") + if cache_config and cache_config.ttl: + cache_point["ttl"] = cache_config.ttl + messages[last_user_idx]["content"].append({"cachePoint": cache_point}) + logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx) + + def _find_last_user_text_message_index(self, messages: Messages) -> int | None: + """Find the index of the last user message containing text or image content. + + This is used for guardrail_latest_message to ensure that guardContent wrapping + targets the correct message even when toolResult messages follow. + + Args: + messages: List of messages to search + + Returns: + Index of the last user message with text/image content, or None if not found + """ + for idx, msg in reversed(list(enumerate(messages))): + if msg["role"] == "user" and any("text" in cb or "image" in cb for cb in msg.get("content", [])): + return idx + return None + def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format messages for Bedrock API compatibility. @@ -302,6 +442,8 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: - Filtering out SDK_UNKNOWN_MEMBER content blocks - Eagerly filtering content blocks to only include Bedrock-supported fields - Ensuring all message content blocks are properly formatted for the Bedrock API + - Optionally wrapping the last user message in guardrailConverseContent blocks + - Injecting cache points when cache_config is set with strategy="auto" Args: messages: List of messages to format @@ -321,7 +463,14 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: filtered_unknown_members = False dropped_deepseek_reasoning_content = False - for message in messages: + # Pre-compute the index of the last user message containing text or image content. + # This ensures guardContent wrapping is maintained across tool execution cycles, where + # the final message in the list is a toolResult (role=user) rather than text/image content. + last_user_text_idx = None + if self.config.get("guardrail_latest_message", False): + last_user_text_idx = self._find_last_user_text_message_index(messages) + + for idx, message in enumerate(messages): cleaned_content: list[dict[str, Any]] = [] for content_block in message["content"]: @@ -338,6 +487,16 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: # Format content blocks for Bedrock API compatibility formatted_content = self._format_request_message_content(content_block) + if formatted_content is None: + continue + + # Wrap text or image content in guardContent if this is the last user text/image message + if idx == last_user_text_idx and ("text" in formatted_content or "image" in formatted_content): + if "text" in formatted_content: + formatted_content = {"guardContent": {"text": {"text": formatted_content["text"]}}} + elif "image" in formatted_content: + formatted_content = {"guardContent": {"image": formatted_content["image"]}} + cleaned_content.append(formatted_content) # Create new message with cleaned content (skip if empty) @@ -353,6 +512,20 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: "Filtered DeepSeek reasoningContent content blocks from messages - https://api-docs.deepseek.com/guides/reasoning_model#multi-round-conversation" ) + # Inject cache point into cleaned_messages (not original messages) if cache_config is set + cache_config = self.config.get("cache_config") + if cache_config: + strategy: str | None = cache_config.strategy + if strategy == "auto": + strategy = self._cache_strategy + if not strategy: + logger.warning( + "model_id=<%s> | cache_config is enabled but this model does not support automatic caching", + self.config.get("model_id"), + ) + if strategy == "anthropic": + self._inject_cache_point(cleaned_messages) + return cleaned_messages def _should_include_tool_result_status(self) -> bool: @@ -366,7 +539,19 @@ def _should_include_tool_result_status(self) -> bool: else: # "auto" return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS) - def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + def _handle_location(self, location: SourceLocation) -> dict[str, Any] | None: + """Convert location content block to Bedrock format if its an S3Location.""" + if location["type"] == "s3": + s3_location = cast(S3Location, location) + formatted_document_s3: dict[str, Any] = {"uri": s3_location["uri"]} + if "bucketOwner" in s3_location: + formatted_document_s3["bucketOwner"] = s3_location["bucketOwner"] + return {"s3Location": formatted_document_s3} + else: + logger.warning("Non s3 location sources are not supported by Bedrock | skipping content block") + return None + + def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any] | None: """Format a Bedrock content block. Bedrock strictly validates content blocks and throws exceptions for unknown fields. @@ -383,12 +568,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An """ # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html if "cachePoint" in content: - return {"cachePoint": {"type": content["cachePoint"]["type"]}} + cache_point = content["cachePoint"] + result: dict[str, Any] = {"type": cache_point["type"]} + if "ttl" in cache_point: + result["ttl"] = cache_point["ttl"] + return {"cachePoint": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html if "document" in content: document = content["document"] - result: dict[str, Any] = {} + result = {} # Handle required fields (all optional due to total=False) if "name" in document: @@ -396,9 +585,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "format" in document: result["format"] = document["format"] - # Handle source + # Handle source - supports bytes or location if "source" in document: - result["source"] = {"bytes": document["source"]["bytes"]} + source = document["source"] + formatted_document_source: dict[str, Any] | None + if "location" in source: + formatted_document_source = self._handle_location(source["location"]) + if formatted_document_source is None: + return None + elif "bytes" in source: + formatted_document_source = {"bytes": source["bytes"]} + result["source"] = formatted_document_source # Handle optional fields if "citations" in document and document["citations"] is not None: @@ -419,10 +616,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "image" in content: image = content["image"] source = image["source"] - formatted_source = {} - if "bytes" in source: - formatted_source = {"bytes": source["bytes"]} - result = {"format": image["format"], "source": formatted_source} + formatted_image_source: dict[str, Any] | None + if "location" in source: + formatted_image_source = self._handle_location(source["location"]) + if formatted_image_source is None: + return None + elif "bytes" in source: + formatted_image_source = {"bytes": source["bytes"]} + result = {"format": image["format"], "source": formatted_image_source} return {"image": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html @@ -451,15 +652,25 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html if "toolResult" in content: tool_result = content["toolResult"] + # Normalize empty toolResult content arrays. + # Some model providers (e.g., Nemotron) reject toolResult blocks with + # content: [] via the Converse API, while others (e.g., Claude) accept + # them. Replace empty content with a minimal text block to ensure + # cross-model compatibility. This follows the same pattern as the + # TypeScript SDK's _formatMessages in bedrock.ts. + tool_result_content_list = tool_result.get("content") or [{"text": ""}] formatted_content: list[dict[str, Any]] = [] - for tool_result_content in tool_result["content"]: + for tool_result_content in tool_result_content_list: if "json" in tool_result_content: # Handle json field since not in ContentBlock but valid in ToolResultContent formatted_content.append({"json": tool_result_content["json"]}) else: - formatted_content.append( - self._format_request_message_content(cast(ContentBlock, tool_result_content)) + formatted_message_content = self._format_request_message_content( + cast(ContentBlock, tool_result_content) ) + if formatted_message_content is None: + continue + formatted_content.append(formatted_message_content) result = { "content": formatted_content, @@ -484,10 +695,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "video" in content: video = content["video"] source = video["source"] - formatted_source = {} - if "bytes" in source: - formatted_source = {"bytes": source["bytes"]} - result = {"format": video["format"], "source": formatted_source} + formatted_video_source: dict[str, Any] | None + if "location" in source: + formatted_video_source = self._handle_location(source["location"]) + if formatted_video_source is None: + return None + elif "bytes" in source: + formatted_video_source = {"bytes": source["bytes"]} + result = {"format": video["format"], "source": formatted_video_source} return {"video": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html @@ -500,16 +715,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An for citation in citations["citations"]: filtered_citation: dict[str, Any] = {} if "location" in citation: - location = citation["location"] - filtered_location = {} - # Filter location fields to only include Bedrock-supported ones - if "documentIndex" in location: - filtered_location["documentIndex"] = location["documentIndex"] - if "start" in location: - filtered_location["start"] = location["start"] - if "end" in location: - filtered_location["end"] = location["end"] - filtered_citation["location"] = filtered_location + filtered_citation["location"] = citation["location"] if "sourceContent" in citation: filtered_source_content: list[dict[str, Any]] = [] for source_content in citation["sourceContent"]: @@ -590,15 +796,101 @@ def _generate_redaction_events(self) -> list[StreamEvent]: return events + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using Bedrock's native CountTokens API. + + Uses the same message format as the Converse API to get accurate token counts + directly from the Bedrock service. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. + """ + if self.config.get("use_native_token_count") is not True: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + + model_id: str = self.config["model_id"] + + if model_id in _SKIP_COUNT_TOKENS_MODELS: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + + try: + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + request = self._format_request(messages, tool_specs, system_prompt_content) + converse_input: dict[str, Any] = {} + if "messages" in request: + converse_input["messages"] = request["messages"] + if "system" in request: + converse_input["system"] = request["system"] + if "toolConfig" in request: + converse_input["toolConfig"] = request["toolConfig"] + + response = await asyncio.to_thread( + self.client.count_tokens, + modelId=self.config["model_id"], + input={"converse": converse_input}, + ) + input_tokens = response.get("inputTokens") + if input_tokens is None: + raise ProviderTokenCountError("Bedrock count_tokens returned None for inputTokens") + total_tokens: int = input_tokens + + logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens) + return total_tokens + except Exception as e: + if ( + isinstance(e, ClientError) + and e.response.get("Error", {}).get("Code") == "AccessDeniedException" + ): + logger.warning( + "model_id=<%s> | bedrock:CountTokens permission denied," + " falling back to heuristic estimation: %s", + model_id, + e, + ) + _SKIP_COUNT_TOKENS_MODELS.add(model_id) + elif ( + isinstance(e, ClientError) + and e.response.get("Error", {}).get("Code") == "ValidationException" + and "doesn't support counting tokens" in str(e) + ): + logger.debug( + "model_id=<%s> | model does not support CountTokens, caching for future calls," + " falling back to estimation", + model_id, + ) + _SKIP_COUNT_TOKENS_MODELS.add(model_id) + else: + logger.debug( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + model_id, + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + @override async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Bedrock model. @@ -622,13 +914,13 @@ async def stream( ModelThrottledException: If the model service is throttling requests. """ - def callback(event: Optional[StreamEvent] = None) -> None: + def callback(event: StreamEvent | None = None) -> None: loop.call_soon_threadsafe(queue.put_nowait, event) if event is None: return loop = asyncio.get_event_loop() - queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue() + queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue() # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None if system_prompt and system_prompt_content is None: @@ -650,8 +942,8 @@ def _stream( self, callback: Callable[..., None], messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, tool_choice: ToolChoice | None = None, ) -> None: """Stream conversation with the Bedrock model. @@ -681,8 +973,6 @@ def _stream( logger.debug("got response from model") if streaming: response = self.client.converse_stream(**request) - # Track tool use events to fix stopReason for streaming responses - has_tool_use = False for chunk in response["stream"]: if ( "metadata" in chunk @@ -694,24 +984,7 @@ def _stream( for event in self._generate_redaction_events(): callback(event) - # Track if we see tool use events - if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"): - has_tool_use = True - - # Fix stopReason for streaming responses that contain tool use - if ( - has_tool_use - and "messageStop" in chunk - and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn" - ): - # Create corrected chunk with tool_use stopReason - modified_chunk = chunk.copy() - modified_chunk["messageStop"] = message_stop.copy() - modified_chunk["messageStop"]["stopReason"] = "tool_use" - logger.warning("Override stop reason from end_turn to tool_use") - callback(modified_chunk) - else: - callback(chunk) + callback(chunk) else: response = self.client.converse(**request) @@ -837,30 +1110,24 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera } for citation in content["citationsContent"]["citations"]: - # Then emit citation metadata (for structure) - - citation_metadata: CitationsDelta = { - "title": citation["title"], - "location": citation["location"], - "sourceContent": citation["sourceContent"], - } + # Emit citation metadata, only including fields that are present + # Nova grounding may omit title/sourceContent + citation_metadata: CitationsDelta = {} + if "title" in citation: + citation_metadata["title"] = citation["title"] + if "location" in citation: + citation_metadata["location"] = citation["location"] + if "sourceContent" in citation: + citation_metadata["sourceContent"] = citation["sourceContent"] yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}} # Yield contentBlockStop event yield {"contentBlockStop": {}} # Yield messageStop event - # Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side - current_stop_reason = response["stopReason"] - if current_stop_reason == "end_turn": - message_content = response["output"]["message"]["content"] - if any("toolUse" in content for content in message_content): - current_stop_reason = "tool_use" - logger.warning("Override stop reason from end_turn to tool_use") - yield { "messageStop": { - "stopReason": current_stop_reason, + "stopReason": response["stopReason"], "additionalModelResponseFields": response.get("additionalModelResponseFields"), } } @@ -904,11 +1171,11 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: @override async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: @@ -953,7 +1220,7 @@ async def structured_output( yield {"output": output_model(**output_response)} @staticmethod - def _get_default_model_with_warning(region_name: str, model_config: Optional[BedrockConfig] = None) -> str: + def _get_default_model_with_warning(region_name: str, model_config: BedrockConfig | None = None) -> str: """Get the default Bedrock modelId based on region. If the region is not **known** to support inference then we show a helpful warning @@ -965,13 +1232,13 @@ def _get_default_model_with_warning(region_name: str, model_config: Optional[Bed region_name (str): region for bedrock model model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init """ - if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"): - return DEFAULT_BEDROCK_MODEL_ID - model_config = model_config or {} if model_config.get("model_id"): return model_config["model_id"] + if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"): + return DEFAULT_BEDROCK_MODEL_ID + prefix_inference_map = {"ap": "apac"} # some inference endpoints can be a bit different than the region prefix prefix = "-".join(region_name.split("-")[:-2]).lower() # handles `us-east-1` or `us-gov-east-1` @@ -993,4 +1260,10 @@ def _get_default_model_with_warning(region_name: str, model_config: Optional[Bed stacklevel=2, ) - return _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix)) + default_model_id = _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix)) + warnings.warn( + f"You're using default model '{default_model_id}', which is subject to change. " + "Specify a model explicitly to pin the model target.", + stacklevel=2, + ) + return default_model_id diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index c24d91a0d..43e4f0349 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -3,21 +3,25 @@ - Docs: https://ai.google.dev/api """ +import base64 import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +import secrets +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import pydantic from google import genai from typing_extensions import Required, Unpack, override -from ..types.content import ContentBlock, Messages -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages, SystemContentBlock +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException, ProviderTokenCountError from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec -from ._validation import validate_config_keys -from .model import Model +from ._defaults import resolve_config_metadata +from ._validation import _has_location_source, validate_config_keys +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -30,7 +34,7 @@ class GeminiModel(Model): - Docs: https://ai.google.dev/api """ - class GeminiConfig(TypedDict, total=False): + class GeminiConfig(BaseModelConfig, total=False): """Configuration options for Gemini models. Attributes: @@ -40,31 +44,62 @@ class GeminiConfig(TypedDict, total=False): params: Additional model parameters (e.g., temperature). For a complete list of supported parameters, see https://ai.google.dev/api/generate-content#generationconfig. + gemini_tools: Gemini-specific tools that are not FunctionDeclarations + (e.g., GoogleSearch, CodeExecution, ComputerUse, UrlContext, FileSearch). + Use the standard tools interface for function calling tools. + For a complete list of supported tools, see + https://ai.google.dev/api/caching#Tool + use_native_token_count: Whether to use the native Gemini count_tokens API. + When True, count_tokens() calls the Gemini API for accurate counts. + When False (default), skips the API call and uses the local estimator. """ model_id: Required[str] params: dict[str, Any] + gemini_tools: list[genai.types.Tool] + use_native_token_count: bool def __init__( self, *, - client_args: Optional[dict[str, Any]] = None, + client: genai.Client | None = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[GeminiConfig], ) -> None: """Initialize provider instance. Args: + client: Pre-configured Gemini client to reuse across requests. + When provided, this client will be reused for all requests and will NOT be closed + by the model. The caller is responsible for managing the client lifecycle. + This is useful for: + - Injecting custom client wrappers + - Reusing connection pools within a single event loop/worker + - Centralizing observability, retries, and networking policy + Note: The client should not be shared across different asyncio event loops. client_args: Arguments for the underlying Gemini client (e.g., api_key). For a complete list of supported arguments, see https://googleapis.github.io/python-genai/. **model_config: Configuration options for the Gemini model. + + Raises: + ValueError: If both `client` and `client_args` are provided. """ validate_config_keys(model_config, GeminiModel.GeminiConfig) self.config = GeminiModel.GeminiConfig(**model_config) - logger.debug("config=<%s> | initializing", self.config) + # Validate that only one client configuration method is provided + if client is not None and client_args is not None and len(client_args) > 0: + raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.") + self._custom_client = client self.client_args = client_args or {} + # Validate gemini_tools if provided + if "gemini_tools" in self.config: + self._validate_gemini_tools(self.config["gemini_tools"]) + + logger.debug("config=<%s> | initializing", self.config) + @override def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: ignore[override] """Update the Gemini model configuration with the provided arguments. @@ -72,6 +107,10 @@ def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + # Validate gemini_tools if provided + if "gemini_tools" in model_config: + self._validate_gemini_tools(model_config["gemini_tools"]) + self.config.update(model_config) @override @@ -81,15 +120,39 @@ def get_config(self) -> GeminiConfig: Returns: The Gemini model configuration. """ - return self.config + return resolve_config_metadata(self.config, self.config["model_id"]) + + def _get_client(self) -> genai.Client: + """Get a Gemini client for making requests. + + This method handles client lifecycle management: + - If an injected client was provided during initialization, it returns that client + without managing its lifecycle (caller is responsible for cleanup). + - Otherwise, creates a new genai.Client from client_args. - def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part: + Returns: + genai.Client: A Gemini client instance. + """ + if self._custom_client is not None: + # Use the injected client (caller manages lifecycle) + return self._custom_client + else: + # Create a new client from client_args + return genai.Client(**self.client_args) + + def _format_request_content_part( + self, content: ContentBlock, tool_use_id_to_name: dict[str, str] + ) -> genai.types.Part: """Format content block into a Gemini part instance. - Docs: https://googleapis.github.io/python-genai/genai.html#genai.types.Part Args: content: Message content to format. + tool_use_id_to_name: Mapping of tool use id to tool name. + Store the mapping from toolUseId to name for later use in toolResult formatting. This mapping is built + as we format the request, ensuring that when we encounter toolResult blocks (which come after toolUse + blocks in the message history), we can look up the function name. Returns: Gemini part. @@ -116,23 +179,27 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par return genai.types.Part( text=content["reasoningContent"]["reasoningText"]["text"], thought=True, - thought_signature=thought_signature.encode("utf-8") if thought_signature else None, + thought_signature=base64.b64decode(thought_signature) if thought_signature else None, ) if "text" in content: return genai.types.Part(text=content["text"]) if "toolResult" in content: + tool_use_id = content["toolResult"]["toolUseId"] + function_name = tool_use_id_to_name.get(tool_use_id, tool_use_id) + return genai.types.Part( function_response=genai.types.FunctionResponse( - id=content["toolResult"]["toolUseId"], - name=content["toolResult"]["toolUseId"], + id=tool_use_id, + name=function_name, response={ "output": [ tool_result_content if "json" in tool_result_content else self._format_request_content_part( - cast(ContentBlock, tool_result_content) + cast(ContentBlock, tool_result_content), + tool_use_id_to_name, ).to_json_dict() for tool_result_content in content["toolResult"]["content"] ], @@ -141,12 +208,18 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ) if "toolUse" in content: + tool_use_id = content["toolUse"]["toolUseId"] + tool_use_id_to_name[tool_use_id] = content["toolUse"]["name"] + + reasoning_signature = content["toolUse"].get("reasoningSignature") + return genai.types.Part( function_call=genai.types.FunctionCall( args=content["toolUse"]["input"], - id=content["toolUse"]["toolUseId"], + id=tool_use_id, name=content["toolUse"]["name"], ), + thought_signature=base64.b64decode(reasoning_signature) if reasoning_signature else None, ) raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @@ -162,15 +235,30 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten Returns: Gemini content list. """ - return [ - genai.types.Content( - parts=[self._format_request_content_part(content) for content in message["content"]], - role="user" if message["role"] == "user" else "model", + # Gemini FunctionResponses are constructed from tool result blocks. Function name is required but is not + # available in tool result blocks, hence the mapping. + tool_use_id_to_name: dict[str, str] = {} + + contents = [] + for message in messages: + parts = [] + for content in message["content"]: + # Check for location sources and skip with warning + if _has_location_source(content): + logger.warning("Location sources are not supported by Gemini | skipping content block") + continue + parts.append(self._format_request_content_part(content, tool_use_id_to_name)) + + contents.append( + genai.types.Content( + parts=parts, + role="user" if message["role"] == "user" else "model", + ) ) - for message in messages - ] - def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[genai.types.Tool | Any]: + return contents + + def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai.types.Tool | Any]: """Format tool specs into Gemini tools. - Docs: https://googleapis.github.io/python-genai/genai.html#genai.types.Tool @@ -181,7 +269,7 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge Return: Gemini tool list. """ - return [ + tools = [ genai.types.Tool( function_declarations=[ genai.types.FunctionDeclaration( @@ -193,12 +281,15 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge ], ), ] + if self.config.get("gemini_tools"): + tools.extend(self.config["gemini_tools"]) + return tools def _format_request_config( self, - tool_specs: Optional[list[ToolSpec]], - system_prompt: Optional[str], - params: Optional[dict[str, Any]], + tool_specs: list[ToolSpec] | None, + system_prompt: str | None, + params: dict[str, Any] | None, ) -> genai.types.GenerateContentConfig: """Format Gemini request config. @@ -221,9 +312,9 @@ def _format_request_config( def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]], - system_prompt: Optional[str], - params: Optional[dict[str, Any]], + tool_specs: list[ToolSpec] | None, + system_prompt: str | None, + params: dict[str, Any] | None, ) -> dict[str, Any]: """Format a Gemini streaming request. @@ -264,17 +355,22 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: case "content_start": match event["data_type"]: case "tool": - # Note: toolUseId is the only identifier available in a tool result. However, Gemini requires - # that name be set in the equivalent FunctionResponse type. Consequently, we assign - # function name to toolUseId in our tool use block. And another reason, function_call is - # not guaranteed to have id populated. + function_call = event["data"].function_call + # Use Gemini's provided ID or generate one if missing + tool_use_id = function_call.id or f"tooluse_{secrets.token_urlsafe(16)}" + + tool_use_start: ContentBlockStartToolUse = { + "name": function_call.name, + "toolUseId": tool_use_id, + } + if event["data"].thought_signature: + tool_use_start["reasoningSignature"] = base64.b64encode( + event["data"].thought_signature + ).decode("ascii") return { "contentBlockStart": { "start": { - "toolUse": { - "name": event["data"].function_call.name, - "toolUseId": event["data"].function_call.name, - }, + "toolUse": tool_use_start, }, }, } @@ -298,7 +394,11 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: "reasoningContent": { "text": event["data"].text, **( - {"signature": event["data"].thought_signature.decode("utf-8")} + { + "signature": base64.b64encode(event["data"].thought_signature).decode( + "ascii" + ) + } if event["data"].thought_signature else {} ), @@ -339,11 +439,73 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: # pragma: no cover raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using Gemini's native count_tokens API. + + Uses the Gemini count_tokens API for message contents. The Gemini API does not support + counting system_instruction or tools, so those are estimated via the base class heuristic. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. + """ + if self.config.get("use_native_token_count") is not True: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + + try: + contents = list(self._format_request_content(messages)) + + client = self._get_client().aio + response = await client.models.count_tokens( + model=self.config["model_id"], + contents=contents, + ) + if response.total_tokens is None: + raise ProviderTokenCountError("Gemini count_tokens returned None for total_tokens") + total_tokens: int = response.total_tokens + + # The google-genai SDK explicitly raises ValueError for system_instruction, tools, and + # generation_config in CountTokensConfig on the non-Vertex (mldev) backend. + # Use heuristic for these. + extra = await super().count_tokens( + messages=[], + tool_specs=tool_specs, + system_prompt=system_prompt, + system_prompt_content=system_prompt_content, + ) + total_tokens += extra + + logger.debug( + "model_id=<%s>, total_tokens=<%d> | native token count", + self.config["model_id"], + total_tokens, + ) + return total_tokens + except Exception as e: + logger.debug( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + self.config["model_id"], + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: @@ -365,14 +527,17 @@ async def stream( """ request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params")) - client = genai.Client(**self.client_args).aio + client = self._get_client().aio + try: response = await client.models.generate_content_stream(**request) yield self._format_chunk({"chunk_type": "message_start"}) - yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) + data_type: str | None = None tool_used = False + candidate = None + event = None async for event in response: candidates = event.candidates candidate = candidates[0] if candidates else None @@ -387,22 +552,30 @@ async def stream( tool_used = True if part.text: + new_data_type = "reasoning_content" if part.thought else "text" + if new_data_type != data_type: + if data_type is not None: + yield self._format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + yield self._format_chunk({"chunk_type": "content_start", "data_type": new_data_type}) + data_type = new_data_type yield self._format_chunk( { "chunk_type": "content_delta", - "data_type": "reasoning_content" if part.thought else "text", + "data_type": data_type, "data": part, }, ) - yield self._format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + if data_type is not None: + yield self._format_chunk({"chunk_type": "content_stop", "data_type": data_type}) yield self._format_chunk( { "chunk_type": "message_stop", "data": "TOOL_USE" if tool_used else (candidate.finish_reason if candidate else "STOP"), } ) - yield self._format_chunk({"chunk_type": "metadata", "data": event.usage_metadata}) + if event: + yield self._format_chunk({"chunk_type": "metadata", "data": event.usage_metadata}) except genai.errors.ClientError as error: if not error.message: @@ -427,8 +600,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model using Gemini's native structured output. - Docs: https://ai.google.dev/gemini-api/docs/structured-output @@ -448,6 +621,30 @@ async def structured_output( "response_schema": output_model.model_json_schema(), } request = self._format_request(prompt, None, system_prompt, params) - client = genai.Client(**self.client_args).aio + client = self._get_client().aio response = await client.models.generate_content(**request) yield {"output": output_model.model_validate(response.parsed)} + + @staticmethod + def _validate_gemini_tools(gemini_tools: list[genai.types.Tool]) -> None: + """Validate that gemini_tools does not contain FunctionDeclarations. + + Gemini-specific tools should only include tools that cannot be represented + as FunctionDeclarations (e.g., GoogleSearch, CodeExecution, ComputerUse). + Standard function calling tools should use the tools interface instead. + + Args: + gemini_tools: List of Gemini tools to validate + + Raises: + ValueError: If any tool contains function_declarations + """ + for tool in gemini_tools: + # Check if the tool has function_declarations attribute and it's not empty + if hasattr(tool, "function_declarations") and tool.function_declarations: + raise ValueError( + "gemini_tools should not contain FunctionDeclarations. " + "Use the standard tools interface for function calling tools. " + "gemini_tools is reserved for Gemini-specific tools like " + "GoogleSearch, CodeExecution, ComputerUse, UrlContext, and FileSearch." + ) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 1f1e999d2..9fbdff794 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -5,7 +5,8 @@ import json import logging -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import litellm from litellm.exceptions import ContextWindowExceededError @@ -18,19 +19,24 @@ from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException from ..types.streaming import MetadataEvent, StreamEvent -from ..types.tools import ToolChoice, ToolSpec +from ..types.tools import ToolChoice, ToolSpec, ToolUse from ._validation import validate_config_keys +from .model import BaseModelConfig from .openai import OpenAIModel logger = logging.getLogger(__name__) +# Separator used by LiteLLM to embed thought signatures inside tool call IDs. +# See: https://ai.google.dev/gemini-api/docs/thought-signatures +_THOUGHT_SIGNATURE_SEPARATOR = "__thought__" + T = TypeVar("T", bound=BaseModel) class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" - class LiteLLMConfig(TypedDict, total=False): + class LiteLLMConfig(BaseModelConfig, total=False): """Configuration options for LiteLLM models. Attributes: @@ -42,9 +48,9 @@ class LiteLLMConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None: + def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Unpack[LiteLLMConfig]) -> None: """Initialize provider instance. Args: @@ -113,6 +119,61 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> return super().format_request_message_content(content) + @override + @classmethod + def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]: + """Format a LiteLLM compatible tool call, encoding thought signatures into the tool call ID. + + Gemini thinking models attach a thought_signature to each function call. LiteLLM's OpenAI-compatible + interface embeds this signature inside the tool call ID using the ``__thought__`` separator. When + ``reasoningSignature`` is present and the tool call ID does not already contain the separator, this + method encodes it so LiteLLM can reconstruct the Gemini-native format on the next request. + + Args: + tool_use: Tool use requested by the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + LiteLLM compatible tool call dict with thought signature encoded in the ID when present. + """ + tool_call = super().format_request_message_tool_call(tool_use, **kwargs) + + reasoning_signature = tool_use.get("reasoningSignature") + if reasoning_signature and _THOUGHT_SIGNATURE_SEPARATOR not in tool_call["id"]: + tool_call["id"] = f"{tool_call['id']}{_THOUGHT_SIGNATURE_SEPARATOR}{reasoning_signature}" + + return tool_call + + @staticmethod + def _extract_thought_signature(data: Any) -> str | None: + """Extract thought signature from a tool call event data. + + LiteLLM surfaces Gemini thought signatures in two ways: + + 1. ``provider_specific_fields.thought_signature`` — a structured field set by LiteLLM's Gemini response + transformer. Checked first as it doesn't depend on matching an internal string constant. + 2. ``__thought__`` separator encoded in the tool call ID. Used as fallback since it relies on a copy of + LiteLLM's internal ``THOUGHT_SIGNATURE_SEPARATOR`` constant. + + Args: + data: Tool call event data object. + + Returns: + The extracted thought signature, or None if not present. + """ + # Preferred: structured field that doesn't depend on matching an internal separator string + psf = getattr(data, "provider_specific_fields", None) or {} + if isinstance(psf, dict) and psf.get("thought_signature"): + return str(psf["thought_signature"]) + + # Fallback: extract from encoded ID (relies on hardcoded copy of LiteLLM's separator) + tool_call_id = getattr(data, "id", None) or "" + if isinstance(tool_call_id, str) and _THOUGHT_SIGNATURE_SEPARATOR in tool_call_id: + _, signature = tool_call_id.split(_THOUGHT_SIGNATURE_SEPARATOR, 1) + return signature + + return None + def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]: """Handle switching to a new content stream. @@ -137,9 +198,9 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> @classmethod def _format_system_messages( cls, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format system messages for LiteLLM with cache point support. @@ -160,11 +221,14 @@ def _format_system_messages( for block in system_prompt_content or []: if "text" in block: system_content.append({"type": "text", "text": block["text"]}) - elif "cachePoint" in block and block["cachePoint"].get("type") == "default": + elif "cachePoint" in block and block["cachePoint"]["type"] == "default": # Apply cache control to the immediately preceding content block # for LiteLLM/Anthropic compatibility if system_content: - system_content[-1]["cache_control"] = {"type": "ephemeral"} + cache_control: dict[str, Any] = {"type": "ephemeral"} + if ttl := block["cachePoint"].get("ttl"): + cache_control["ttl"] = ttl + system_content[-1]["cache_control"] = cache_control # Create single system message with content array rather than mulitple system messages return [{"role": "system", "content": system_content}] if system_content else [] @@ -174,9 +238,9 @@ def _format_system_messages( def format_request_messages( cls, messages: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format a LiteLLM compatible messages array with cache point support. @@ -193,14 +257,15 @@ def format_request_messages( formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) formatted_messages.extend(cls._format_regular_messages(messages)) - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + return [message for message in formatted_messages if "content" in message or "tool_calls" in message] @override def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: """Format a LiteLLM response event into a standardized message chunk. - This method overrides OpenAI's format_chunk to handle the metadata case - with prompt caching support. All other chunk types use the parent implementation. + Extends OpenAI's format_chunk to: + 1. Handle metadata with prompt caching support. + 2. Extract thought signatures that LiteLLM embeds in tool call IDs for Gemini thinking models. Args: event: A response event from the LiteLLM model. @@ -236,6 +301,17 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: usage=usage_data, ) ) + + # Extract thought signature from tool call content_start events. + # The full encoded ID is kept in toolUseId so that tool result messages continue to match. + if event["chunk_type"] == "content_start" and event.get("data_type") == "tool": + signature = self._extract_thought_signature(event.get("data")) + chunk = super().format_chunk(event) + if signature: + tool_use_dict = cast(dict, chunk["contentBlockStart"]["start"]["toolUse"]) + tool_use_dict["reasoningSignature"] = signature + return chunk + # For all other cases, use the parent implementation return super().format_chunk(event) @@ -243,11 +319,11 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -269,80 +345,34 @@ async def stream( ) logger.debug("request=<%s>", request) - logger.debug("invoking model") - try: - if kwargs.get("stream") is False: - raise ValueError("stream parameter cannot be explicitly set to False") - response = await litellm.acompletion(**self.client_args, **request) - except ContextWindowExceededError as e: - logger.warning("litellm client raised context window overflow") - raise ContextWindowOverflowException(e) from e + # Check if streaming is disabled in the params + config = self.get_config() + params = config.get("params") or {} + is_streaming = params.get("stream", True) - logger.debug("got response from model") - yield self.format_chunk({"chunk_type": "message_start"}) + litellm_request = {**request} - tool_calls: dict[int, list[Any]] = {} - data_type: str | None = None + litellm_request["stream"] = is_streaming - async for event in response: - # Defensive: skip events with empty or missing choices - if not getattr(event, "choices", None): - continue - choice = event.choices[0] + logger.debug("invoking model with stream=%s", litellm_request.get("stream")) - if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: - chunks, data_type = self._stream_switch_content("reasoning_content", data_type) - for chunk in chunks: + try: + if is_streaming: + async for chunk in self._handle_streaming_response(litellm_request): yield chunk - - yield self.format_chunk( - { - "chunk_type": "content_delta", - "data_type": data_type, - "data": choice.delta.reasoning_content, - } - ) - - if choice.delta.content: - chunks, data_type = self._stream_switch_content("text", data_type) - for chunk in chunks: + else: + async for chunk in self._handle_non_streaming_response(litellm_request): yield chunk + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow") + raise ContextWindowOverflowException(e) from e - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content} - ) - - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) - - if choice.finish_reason: - if data_type: - yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) - break - - for tool_deltas in tool_calls.values(): - yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) - - for tool_delta in tool_deltas: - yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) - - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) - - yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) - - # Skip remaining events as we don't have use for anything except the final usage payload - async for event in response: - _ = event - - if event.usage: - yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) - - logger.debug("finished streaming response from model") + logger.debug("finished processing response from model") @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Some models do not support native structured output via response_format. @@ -368,7 +398,7 @@ async def structured_output( yield {"output": result} async def _structured_output_using_response_schema( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None ) -> T: """Get structured output using native response_format support.""" response = await litellm.acompletion( @@ -396,7 +426,7 @@ async def _structured_output_using_response_schema( raise ValueError(f"Failed to parse or load content into model: {e}") from e async def _structured_output_using_tool( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None ) -> T: """Get structured output using tool calling fallback.""" tool_spec = convert_pydantic_to_tool_spec(output_model) @@ -422,6 +452,181 @@ async def _structured_output_using_tool( except (json.JSONDecodeError, TypeError, ValueError) as e: raise ValueError(f"Failed to parse or load content into model: {e}") from e + async def _process_choice_content( + self, choice: Any, data_type: str | None, tool_calls: dict[int, list[Any]], is_streaming: bool = True + ) -> AsyncGenerator[tuple[str | None, StreamEvent], None]: + """Process content from a choice object (streaming or non-streaming). + + Args: + choice: The choice object from the response. + data_type: Current data type being processed. + tool_calls: Dictionary to collect tool calls. + is_streaming: Whether this is from a streaming response. + + Yields: + Tuples of (updated_data_type, stream_event). + """ + # Get the content source - this is the only difference between streaming/non-streaming + # We use duck typing here: both choice.delta and choice.message have the same interface + # (reasoning_content, content, tool_calls attributes) but different object structures + content_source = choice.delta if is_streaming else choice.message + + # Process reasoning content + if hasattr(content_source, "reasoning_content") and content_source.reasoning_content: + chunks, data_type = self._stream_switch_content("reasoning_content", data_type) + for chunk in chunks: + yield data_type, chunk + chunk = self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": content_source.reasoning_content, + } + ) + yield data_type, chunk + + # Process text content + if hasattr(content_source, "content") and content_source.content: + chunks, data_type = self._stream_switch_content("text", data_type) + for chunk in chunks: + yield data_type, chunk + chunk = self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "text", + "data": content_source.content, + } + ) + yield data_type, chunk + + # Process tool calls + if hasattr(content_source, "tool_calls") and content_source.tool_calls: + if is_streaming: + # Streaming: tool calls have index attribute for out-of-order delivery + for tool_call in content_source.tool_calls: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + else: + # Non-streaming: tool calls arrive in order, use enumerated index + for i, tool_call in enumerate(content_source.tool_calls): + tool_calls.setdefault(i, []).append(tool_call) + + async def _process_tool_calls(self, tool_calls: dict[int, list[Any]]) -> AsyncGenerator[StreamEvent, None]: + """Process and yield tool call events. + + Args: + tool_calls: Dictionary of tool calls indexed by their position. + + Yields: + Formatted tool call chunks. + """ + for tool_deltas in tool_calls.values(): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) + + for tool_delta in tool_deltas: + yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + async def _handle_non_streaming_response( + self, litellm_request: dict[str, Any] + ) -> AsyncGenerator[StreamEvent, None]: + """Handle non-streaming response from LiteLLM. + + Args: + litellm_request: The formatted request for LiteLLM. + + Yields: + Formatted message chunks from the model. + """ + response = await litellm.acompletion(**self.client_args, **litellm_request) + + logger.debug("got non-streaming response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[int, list[Any]] = {} + data_type: str | None = None + finish_reason: str | None = None + + if hasattr(response, "choices") and response.choices and len(response.choices) > 0: + choice = response.choices[0] + + if hasattr(choice, "message") and choice.message: + # Process content using shared logic + async for updated_data_type, chunk in self._process_choice_content( + choice, data_type, tool_calls, is_streaming=False + ): + data_type = updated_data_type + yield chunk + + if hasattr(choice, "finish_reason"): + finish_reason = choice.finish_reason + + # Stop the current content block if we have one + if data_type: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + + # Process tool calls + async for chunk in self._process_tool_calls(tool_calls): + yield chunk + + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + # Add usage information if available + if hasattr(response, "usage"): + yield self.format_chunk({"chunk_type": "metadata", "data": response.usage}) + + async def _handle_streaming_response(self, litellm_request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]: + """Handle streaming response from LiteLLM. + + Args: + litellm_request: The formatted request for LiteLLM. + + Yields: + Formatted message chunks from the model. + """ + # For streaming, use the streaming API + response = await litellm.acompletion(**self.client_args, **litellm_request) + + logger.debug("got response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[int, list[Any]] = {} + data_type: str | None = None + finish_reason: str | None = None + + async for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue + choice = event.choices[0] + + # Process content using shared logic + async for updated_data_type, chunk in self._process_choice_content( + choice, data_type, tool_calls, is_streaming=True + ): + data_type = updated_data_type + yield chunk + + if choice.finish_reason: + finish_reason = choice.finish_reason + if data_type: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + break + + # Process tool calls + async for chunk in self._process_tool_calls(tool_calls): + yield chunk + + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + # Skip remaining events as we don't have use for anything except the final usage payload + async for event in response: + _ = event + if usage := getattr(event, "usage", None): + yield self.format_chunk({"chunk_type": "metadata", "data": usage}) + + logger.debug("finished streaming response from model") + def _apply_proxy_prefix(self) -> None: """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True. diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 013cd2c7d..71db9b78d 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,19 +8,20 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import llama_api_client from llama_api_client import LlamaAPIClient from pydantic import BaseModel -from typing_extensions import TypedDict, Unpack, override +from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -30,7 +31,7 @@ class LlamaAPIModel(Model): """Llama API model provider implementation.""" - class LlamaConfig(TypedDict, total=False): + class LlamaConfig(BaseModelConfig, total=False): """Configuration options for Llama API models. Attributes: @@ -43,16 +44,16 @@ class LlamaConfig(TypedDict, total=False): """ model_id: str - repetition_penalty: Optional[float] - temperature: Optional[float] - top_p: Optional[float] - max_completion_tokens: Optional[int] - top_k: Optional[int] + repetition_penalty: float | None + temperature: float | None + top_p: float | None + max_completion_tokens: int | None + top_k: int | None def __init__( self, *, - client_args: Optional[dict[str, Any]] = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[LlamaConfig], ) -> None: """Initialize provider instance. @@ -159,7 +160,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "content": [self._format_request_message_content(content) for content in contents], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a LlamaAPI compatible messages array. Args: @@ -175,12 +176,18 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s for message in messages: contents = message["content"] + # Filter out location sources and unsupported block types + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse"]): + continue + if _has_location_source(content): + logger.warning("Location sources are not supported by LlamaAPI | skipping content block") + continue + filtered_contents.append(content) + formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = "" - formatted_contents = [ - self._format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] + formatted_contents = [self._format_request_message_content(content) for content in filtered_contents] formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) for content in contents @@ -206,7 +213,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format a Llama API chat streaming request. @@ -328,8 +335,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -416,8 +423,8 @@ async def stream( @override def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 22a3a3873..5dd25729d 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -14,15 +14,10 @@ import logging import mimetypes import time +from collections.abc import AsyncGenerator from typing import ( Any, - AsyncGenerator, - Dict, - Optional, - Type, - TypedDict, TypeVar, - Union, cast, ) @@ -30,12 +25,12 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -90,7 +85,7 @@ class LlamaCppModel(Model): >>> response = agent(image_content) """ - class LlamaCppConfig(TypedDict, total=False): + class LlamaCppConfig(BaseModelConfig, total=False): """Configuration options for llama.cpp models. Attributes: @@ -130,15 +125,19 @@ class LlamaCppConfig(TypedDict, total=False): - cache_prompt: Cache the prompt for faster generation - slot_id: Slot ID for parallel inference - samplers: Custom sampler order + use_native_token_count: Whether to use the native llama.cpp /tokenize endpoint. + When True, count_tokens() calls the server's tokenize endpoint for accurate counts. + When False (default), skips the API call and uses the local estimator. """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None + use_native_token_count: bool def __init__( self, base_url: str = "http://localhost:8080", - timeout: Optional[Union[float, tuple[float, float]]] = None, + timeout: float | tuple[float, float] | None = None, **model_config: Unpack[LlamaCppConfig], ) -> None: """Initialize llama.cpp provider instance. @@ -196,7 +195,7 @@ def get_config(self) -> LlamaCppConfig: """ return self.config # type: ignore[return-value] - def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) -> dict[str, Any]: + def _format_message_content(self, content: ContentBlock | dict[str, Any]) -> dict[str, Any]: """Format a content block for llama.cpp. Args: @@ -233,7 +232,7 @@ def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) # Handle audio content (not in standard ContentBlock but supported by llama.cpp) if "audio" in content: - audio_content = cast(Dict[str, Any], content) + audio_content = cast(dict[str, Any], content) audio_data = base64.b64encode(audio_content["audio"]["source"]["bytes"]).decode("utf-8") audio_format = audio_content["audio"].get("format", "wav") return { @@ -284,7 +283,7 @@ def _format_tool_message(self, tool_result: dict[str, Any]) -> dict[str, Any]: "content": [self._format_message_content(content) for content in contents], } - def _format_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format messages for llama.cpp. Args: @@ -303,11 +302,17 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No for message in messages: contents = message["content"] - formatted_contents = [ - self._format_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] + # Filter out location sources and unsupported block types + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse"]): + continue + if _has_location_source(content): + logger.warning("Location sources are not supported by llama.cpp | skipping content block") + continue + filtered_contents.append(content) + + formatted_contents = [self._format_message_content(content) for content in filtered_contents] formatted_tool_calls = [ self._format_tool_call( { @@ -343,8 +348,8 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, ) -> dict[str, Any]: """Format a request for the llama.cpp server. @@ -428,7 +433,7 @@ def _format_request( request[param] = value # Collect llama.cpp-specific parameters for extra_body - extra_body: Dict[str, Any] = {} + extra_body: dict[str, Any] = {} for param, value in params.items(): if param in llamacpp_specific_params: extra_body[param] = value @@ -507,12 +512,69 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using llama.cpp's native /tokenize endpoint. + + Sends the formatted prompt to the llama.cpp server's tokenization endpoint + to get an accurate token count. Requires a llama.cpp server version that supports + chat-template-aware tokenization via the ``messages`` field in /tokenize requests. + Older server versions that only accept ``{"content": "string"}`` are not supported + and will fall back to estimation. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. + """ + if self.config.get("use_native_token_count") is not True: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + + try: + # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, + # matching the behavior of stream(). The caller always provides system_prompt alongside + # system_prompt_content, so the plain string is always available. + request = self._format_request(messages, tool_specs, system_prompt) + payload = { + "messages": request["messages"], + **({"tools": request["tools"]} if request.get("tools") else {}), + } + + response = await self.client.post("/tokenize", json=payload) + response.raise_for_status() + data = response.json() + total_tokens: int = len(data.get("tokens", [])) + + logger.debug( + "model_id=<%s>, total_tokens=<%d> | native token count", + self.config.get("model_id", "default"), + total_tokens, + ) + return total_tokens + except Exception as e: + logger.debug( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + self.config.get("model_id", "default"), + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + @override async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -552,7 +614,7 @@ async def stream( yield self._format_chunk({"chunk_type": "message_start"}) yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) - tool_calls: Dict[int, list] = {} + tool_calls: dict[int, list] = {} usage_data = None finish_reason = None @@ -706,11 +768,11 @@ async def stream( @override async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output using llama.cpp's native JSON schema support. This implementation uses llama.cpp's json_schema parameter to constrain @@ -753,7 +815,7 @@ async def structured_output( if "text" in delta: response_text += delta["text"] # Forward events to caller - yield cast(Dict[str, Union[T, Any]], event) + yield cast(dict[str, T | Any], event) # Parse and validate the JSON response data = json.loads(response_text.strip()) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index b6459d63f..2ae00cef9 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -6,18 +6,20 @@ import base64 import json import logging -from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union +from collections.abc import AsyncGenerator, Iterable +from typing import Any, TypeVar import mistralai from pydantic import BaseModel -from typing_extensions import TypedDict, Unpack, override +from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model +from ._defaults import resolve_config_metadata +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -35,7 +37,7 @@ class MistralModel(Model): - System prompts """ - class MistralConfig(TypedDict, total=False): + class MistralConfig(BaseModelConfig, total=False): """Configuration parameters for Mistral models. Attributes: @@ -47,16 +49,16 @@ class MistralConfig(TypedDict, total=False): """ model_id: str - max_tokens: Optional[int] - temperature: Optional[float] - top_p: Optional[float] - stream: Optional[bool] + max_tokens: int | None + temperature: float | None + top_p: float | None + stream: bool | None def __init__( self, - api_key: Optional[str] = None, + api_key: str | None = None, *, - client_args: Optional[dict[str, Any]] = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[MistralConfig], ) -> None: """Initialize provider instance. @@ -113,9 +115,9 @@ def get_config(self) -> MistralConfig: Returns: The Mistral model configuration. """ - return self.config + return resolve_config_metadata(self.config, self.config["model_id"]) - def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]: + def _format_request_message_content(self, content: ContentBlock) -> str | dict[str, Any]: """Format a Mistral content block. Args: @@ -187,7 +189,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "tool_call_id": tool_result["toolUseId"], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a Mistral compatible messages array. Args: @@ -211,6 +213,11 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s tool_messages: list[dict[str, Any]] = [] for content in contents: + # Check for location sources and skip with warning + if _has_location_source(content): + logger.warning("Location sources are not supported by Mistral | skipping content block") + continue + if "text" in content: formatted_content = self._format_request_message_content(content) if isinstance(formatted_content, str): @@ -236,7 +243,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return formatted_messages def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format a Mistral chat streaming request. @@ -395,8 +402,8 @@ def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, An async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -490,8 +497,8 @@ async def stream( yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) - if hasattr(chunk, "usage"): - yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage}) + if hasattr(chunk, "data") and hasattr(chunk.data, "usage") and chunk.data.usage: + yield self.format_chunk({"chunk_type": "metadata", "data": chunk.data.usage}) except Exception as e: if "rate" in str(e).lower() or "429" in str(e): @@ -502,8 +509,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index b2fa73802..77ef1df40 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -1,20 +1,160 @@ """Abstract base class for Agent model providers.""" import abc +import json import logging -from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union +import math +from collections.abc import AsyncGenerator, AsyncIterable, Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar from pydantic import BaseModel -from ..types.content import Messages, SystemContentBlock +from ..hooks.events import AfterInvocationEvent +from ..plugins.plugin import Plugin +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec +if TYPE_CHECKING: + from ..agent.agent import Agent + logger = logging.getLogger(__name__) T = TypeVar("T", bound=BaseModel) +def _heuristic_estimate_text(text: str) -> int: + """Estimate token count from text using characters / 4 heuristic.""" + return math.ceil(len(text) / 4) + + +def _heuristic_estimate_json(obj: Any) -> int: + """Estimate token count from a JSON-serializable object using characters / 2 heuristic.""" + try: + return math.ceil(len(json.dumps(obj)) / 2) + except (TypeError, ValueError): + return 0 + + +def _count_content_block_tokens( + block: ContentBlock, count_text: Callable[[str], int], count_json: Callable[[Any], int] +) -> int: + """Count tokens for a single content block. + + Args: + block: The content block to count tokens for. + count_text: Function that returns token count for a text string. + count_json: Function that returns token count for a JSON-serializable object. + """ + total = 0 + + if "text" in block: + total += count_text(block["text"]) + + if "toolUse" in block: + tool_use = block["toolUse"] + total += count_text(tool_use.get("name", "")) + total += count_json(tool_use.get("input", {})) + + if "toolResult" in block: + tool_result = block["toolResult"] + for item in tool_result.get("content", []): + if "text" in item: + total += count_text(item["text"]) + + if "reasoningContent" in block: + reasoning = block["reasoningContent"] + if "reasoningText" in reasoning: + reasoning_text = reasoning["reasoningText"] + if "text" in reasoning_text: + total += count_text(reasoning_text["text"]) + + if "guardContent" in block: + guard = block["guardContent"] + if "text" in guard and "text" in guard["text"]: + total += count_text(guard["text"]["text"]) + + if "citationsContent" in block: + citations = block["citationsContent"] + if "content" in citations: + for citation_item in citations["content"]: + if "text" in citation_item: + total += count_text(citation_item["text"]) + + return total + + +def _estimate_tokens_with_heuristic( + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, +) -> int: + """Estimate tokens using character-based heuristics (text: chars/4, JSON: chars/2). + + Dependency-free fallback when tiktoken is not installed. + """ + total = 0 + + if system_prompt_content: + for block in system_prompt_content: + if "text" in block: + total += _heuristic_estimate_text(block["text"]) + elif system_prompt: + total += _heuristic_estimate_text(system_prompt) + + for message in messages: + for block in message["content"]: + total += _count_content_block_tokens(block, _heuristic_estimate_text, _heuristic_estimate_json) + + if tool_specs: + for spec in tool_specs: + total += _heuristic_estimate_json(spec) + + return total + + +class BaseModelConfig(TypedDict, total=False): + """Base configuration shared by all model providers. + + Attributes: + context_window_limit: Maximum context window size in tokens for the model. + This value represents the total token capacity shared between input and output. + """ + + context_window_limit: int | None + + +@dataclass +class CacheConfig: + """Configuration for prompt caching. + + Attributes: + strategy: Caching strategy to use. + - "auto": Automatically detect model support and inject cachePoint to maximize cache coverage + - "anthropic": Inject cachePoint in Anthropic-compatible format without model support check + ttl: Optional TTL duration for cache entries (e.g. "5m", "1h"). + When specified, auto-injected cache points will include this TTL value. + """ + + strategy: Literal["auto", "anthropic"] = "auto" + ttl: str | None = None + + +@dataclass +class CacheToolsConfig: + """Configuration for the toolConfig cache point. + + Attributes: + type: Cache point type (e.g. "default"). + ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h"). + """ + + type: str = "default" + ttl: str | None = None + + class Model(abc.ABC): """Abstract base class for Agent model providers. @@ -22,6 +162,25 @@ class Model(abc.ABC): standardized way to configure and process requests for different AI model providers. """ + @property + def stateful(self) -> bool: + """Whether the model manages conversation state server-side. + + Returns: + False by default. Model providers that support server-side state should override this. + """ + return False + + @property + def context_window_limit(self) -> int | None: + """Maximum context window size in tokens, or None if not configured.""" + config = self.get_config() + return ( + config.get("context_window_limit") + if isinstance(config, dict) + else getattr(config, "context_window_limit", None) + ) + @abc.abstractmethod # pragma: no cover def update_config(self, **model_config: Any) -> None: @@ -45,8 +204,8 @@ def get_config(self) -> Any: @abc.abstractmethod # pragma: no cover def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: @@ -68,11 +227,12 @@ def structured_output( def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: """Stream conversation with the model. @@ -89,6 +249,7 @@ def stream( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. system_prompt_content: System prompt content blocks for advanced features like caching. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -98,3 +259,62 @@ def stream( ModelThrottledException: When the model service is throttling requests from the client. """ pass + + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Estimate token count for the given input before sending to the model. + + Used for proactive context management (e.g., triggering compression at a threshold). + Uses tiktoken's cl100k_base encoding when available, otherwise falls back to a + heuristic (characters / 4 for text, characters / 2 for JSON). Accuracy varies by + model provider. Not intended for billing or precise quota calculations. + + Subclasses may override this method to provide model-specific token counting + using native APIs for improved accuracy. + + Args: + messages: List of message objects to estimate tokens for. + tool_specs: List of tool specifications to include in the estimate. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. Takes priority over system_prompt. + + Returns: + Estimated total input tokens. + """ + return _estimate_tokens_with_heuristic(messages, tool_specs, system_prompt, system_prompt_content) + + +class _ModelPlugin(Plugin): + """Plugin that manages model-related lifecycle hooks.""" + + @property + def name(self) -> str: + """A stable string identifier for this plugin.""" + return "strands:model" + + @staticmethod + def _on_after_invocation(event: AfterInvocationEvent) -> None: + """Handle post-invocation model management tasks. + + Performs the following: + - Clears messages when the model is managing conversation state server-side. + """ + if event.agent.model.stateful: + event.agent.messages.clear() + logger.debug( + "response_id=<%s> | cleared messages for server-managed conversation", + event.agent._model_state.get("response_id"), + ) + + def init_agent(self, agent: "Agent") -> None: + """Register model lifecycle hooks with the agent. + + Args: + agent: The agent instance to register hooks with. + """ + agent.add_hook(self._on_after_invocation, AfterInvocationEvent) diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 574b24200..cf7108c3a 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,17 +5,19 @@ import json import logging -from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast +import uuid +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import ollama from pydantic import BaseModel -from typing_extensions import TypedDict, Unpack, override +from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolSpec -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -32,7 +34,7 @@ class OllamaModel(Model): - Tool/function calling """ - class OllamaConfig(TypedDict, total=False): + class OllamaConfig(BaseModelConfig, total=False): """Configuration parameters for Ollama models. Attributes: @@ -46,20 +48,20 @@ class OllamaConfig(TypedDict, total=False): top_p: Controls diversity via nucleus sampling (alternative to temperature). """ - additional_args: Optional[dict[str, Any]] - keep_alive: Optional[str] - max_tokens: Optional[int] + additional_args: dict[str, Any] | None + keep_alive: str | None + max_tokens: int | None model_id: str - options: Optional[dict[str, Any]] - stop_sequences: Optional[list[str]] - temperature: Optional[float] - top_p: Optional[float] + options: dict[str, Any] | None + stop_sequences: list[str] | None + temperature: float | None + top_p: float | None def __init__( self, - host: Optional[str], + host: str | None, *, - ollama_client_args: Optional[dict[str, Any]] = None, + ollama_client_args: dict[str, Any] | None = None, **model_config: Unpack[OllamaConfig], ) -> None: """Initialize provider instance. @@ -123,7 +125,7 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> "tool_calls": [ { "function": { - "name": content["toolUse"]["toolUseId"], + "name": content["toolUse"]["name"], "arguments": content["toolUse"]["input"], } } @@ -147,7 +149,7 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format an Ollama compatible messages array. Args: @@ -159,15 +161,19 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s """ system_message = [{"role": "system", "content": system_prompt}] if system_prompt else [] - return system_message + [ - formatted_message - for message in messages - for content in message["content"] - for formatted_message in self._format_request_message_contents(message["role"], content) - ] + formatted_messages = [] + for message in messages: + for content in message["content"]: + # Check for location sources and skip with warning + if _has_location_source(content): + logger.warning("Location sources are not supported by Ollama | skipping content block") + continue + formatted_messages.extend(self._format_request_message_contents(message["role"], content)) + + return system_message + formatted_messages def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format an Ollama chat streaming request. @@ -241,7 +247,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: return {"contentBlockStart": {"start": {}}} tool_name = event["data"].function.name - return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}} + tool_use_id = f"tooluse_{uuid.uuid4().hex[:24]}" + return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_use_id}}}} case "content_delta": if event["data_type"] == "text": @@ -268,12 +275,12 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: return { "metadata": { "usage": { - "inputTokens": event["data"].eval_count, - "outputTokens": event["data"].prompt_eval_count, + "inputTokens": event["data"].prompt_eval_count, + "outputTokens": event["data"].eval_count, "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count, }, "metrics": { - "latencyMs": event["data"].total_duration / 1e6, + "latencyMs": int(event["data"].total_duration / 1e6), }, }, } @@ -285,8 +292,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -339,8 +346,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 435c82cab..94d4b0b90 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -7,7 +7,9 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator, AsyncIterator +from contextlib import asynccontextmanager +from typing import Any, Protocol, TypeVar, cast import openai from openai.types.chat.parsed_chat_completion import ParsedChatCompletion @@ -15,16 +17,28 @@ from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys -from .model import Model +from ._defaults import resolve_config_metadata +from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args +from ._validation import _has_location_source, validate_config_keys +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) T = TypeVar("T", bound=BaseModel) +# Alternative context overflow error messages +# These are commonly returned by OpenAI-compatible endpoints wrapping other providers +# (e.g., Databricks serving Bedrock models) +_CONTEXT_OVERFLOW_MESSAGES = [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", +] + class Client(Protocol): """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" @@ -41,7 +55,7 @@ class OpenAIModel(Model): client: Client - class OpenAIConfig(TypedDict, total=False): + class OpenAIConfig(BaseModelConfig, total=False): """Configuration options for OpenAI models. Attributes: @@ -53,22 +67,73 @@ class OpenAIConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None: + def __init__( + self, + client: Client | None = None, + client_args: dict[str, Any] | None = None, + bedrock_mantle_config: BedrockMantleConfig | None = None, + **model_config: Unpack[OpenAIConfig], + ) -> None: """Initialize provider instance. Args: - client_args: Arguments for the OpenAI client. + client: Pre-configured OpenAI-compatible client to reuse across requests. + When provided, this client will be reused for all requests and will NOT be closed + by the model. The caller is responsible for managing the client lifecycle. + This is useful for: + - Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI) + - Reusing connection pools within a single event loop/worker + - Centralizing observability, retries, and networking policy + - Pointing to custom model gateways + Note: The client should not be shared across different asyncio event loops. + client_args: Arguments for the OpenAI client (legacy approach). For a complete list of supported arguments, see https://pypi.org/project/openai/. + May be combined with ``bedrock_mantle_config``; when both are set, + ``bedrock_mantle_config`` derives ``base_url`` and ``api_key`` (which must not + appear in ``client_args``). + bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle + (OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted + keys. When set, a fresh bearer token is minted on every request. Cannot be + combined with a pre-built ``client``. **model_config: Configuration options for the OpenAI model. + + Raises: + ValueError: If ``client`` is combined with ``client_args`` or ``bedrock_mantle_config``. """ validate_config_keys(model_config, self.OpenAIConfig) self.config = dict(model_config) + + # client_args + bedrock_mantle_config is allowed; the config derives base_url / api_key. + client_args_provided = client_args is not None and len(client_args) > 0 + if client is not None and client_args_provided: + raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.") + if bedrock_mantle_config is not None and client is not None: + raise ValueError("'bedrock_mantle_config' cannot be combined with a pre-built 'client'.") + if bedrock_mantle_config is not None and client_args: + conflicting = [k for k in ("api_key", "base_url") if k in client_args] + if conflicting: + raise ValueError( + f"client_args must not contain {conflicting} when bedrock_mantle_config is set; " + "these are derived from the Mantle config automatically." + ) + + self._custom_client = client self.client_args = client_args or {} + self._bedrock_mantle_config = bedrock_mantle_config logger.debug("config=<%s> | initializing", self.config) + def _resolve_client_args(self) -> dict[str, Any]: + """Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request. + + Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set. + """ + if self._bedrock_mantle_config is not None: + return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args) + return self.client_args + @override def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override] """Update the OpenAI model configuration with the provided arguments. @@ -86,7 +151,9 @@ def get_config(self) -> OpenAIConfig: Returns: The OpenAI model configuration. """ - return cast(OpenAIModel.OpenAIConfig, self.config) + return cast( + OpenAIModel.OpenAIConfig, resolve_config_metadata(self.config, str(self.config.get("model_id", ""))) + ) @classmethod def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: @@ -170,12 +237,102 @@ def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> ], ) + # Merge adjacent text blocks while preserving the order of non-text + # (image/document) content. When all content is text, join into a + # single string for broad compatibility with OpenAI-compatible + # endpoints (e.g., Kimi K2.5, vLLM, Ollama). + # See https://github.com/strands-agents/sdk-python/issues/1696 + merged: list[dict[str, Any]] = [] + has_non_text = False + for content_block in contents: + if "text" in content_block: + # Merge with the previous entry if it is also text (adjacent) + if merged and merged[-1].get("type") == "text": + merged[-1]["text"] += "\n" + content_block["text"] + else: + merged.append({"type": "text", "text": content_block["text"]}) + elif "image" in content_block or "document" in content_block: + has_non_text = True + merged.append(cls.format_request_message_content(content_block)) + + content: str | list[dict[str, Any]] + if has_non_text: + # Keep array format when images/documents are present so that + # _split_tool_message_images can extract them into a user message. + content = merged + else: + # All text — the loop already merged adjacent blocks with "\n", + # so extract the single resulting entry. + content = merged[0]["text"] if merged else "" + return { "role": "tool", "tool_call_id": tool_result["toolUseId"], - "content": [cls.format_request_message_content(content) for content in contents], + "content": content, } + @classmethod + def _split_tool_message_images(cls, tool_message: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]: + """Split a tool message into text-only tool message and optional user message with images. + + OpenAI API restricts images to user role messages only. This method extracts any image + content from a tool message and returns it separately as a user message. + + Args: + tool_message: A formatted tool message that may contain images. + + Returns: + A tuple of (tool_message_without_images, user_message_with_images_or_None). + """ + if tool_message.get("role") != "tool": + return tool_message, None + + content = tool_message.get("content", []) + if not isinstance(content, list): + return tool_message, None + + # Separate image and non-image content + text_content = [] + image_content = [] + + for item in content: + if isinstance(item, dict) and item.get("type") == "image_url": + image_content.append(item) + else: + text_content.append(item) + + # If no images found, return original message + if not image_content: + return tool_message, None + + # Let the user know that we are modifying the messages for OpenAI compatibility + logger.warning( + "tool_call_id=<%s> | Moving image from tool message to a new user message for OpenAI compatibility", + tool_message["tool_call_id"], + ) + + # Append a message to the text content to inform the model about the upcoming image + text_content.append( + { + "type": "text", + "text": ( + "Tool successfully returned an image. The image is being provided in the following user message." + ), + } + ) + + # Create the clean tool message with the updated text content + tool_message_clean = { + "role": "tool", + "tool_call_id": tool_message["tool_call_id"], + "content": text_content, + } + + # Create user message with only images + user_message_with_images = {"role": "user", "content": image_content} + + return tool_message_clean, user_message_with_images + @classmethod def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]: """Format a tool choice for OpenAI compatibility. @@ -203,9 +360,9 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str @classmethod def _format_system_messages( cls, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format system messages for OpenAI-compatible providers. @@ -251,11 +408,17 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic "reasoningContent is not supported in multi-turn conversations with the Chat Completions API." ) - formatted_contents = [ - cls.format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]) - ] + # Filter out content blocks that shouldn't be formatted + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]): + continue + if _has_location_source(content): + logger.warning("Location sources are not supported by OpenAI | skipping content block") + continue + filtered_contents.append(content) + + formatted_contents = [cls.format_request_message_content(content) for content in filtered_contents] formatted_tool_calls = [ cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content ] @@ -267,11 +430,21 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic formatted_message = { "role": message["role"], - "content": formatted_contents, + **({"content": formatted_contents} if formatted_contents else {}), **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), } formatted_messages.append(formatted_message) - formatted_messages.extend(formatted_tool_messages) + + # Process tool messages to extract images into separate user messages + # OpenAI API requires images to be in user role messages only + # All tool messages must be grouped together before any user messages with images + user_messages_with_images = [] + for tool_msg in formatted_tool_messages: + tool_msg_clean, user_msg_with_images = cls._split_tool_message_images(tool_msg) + formatted_messages.append(tool_msg_clean) + if user_msg_with_images: + user_messages_with_images.append(user_msg_with_images) + formatted_messages.extend(user_messages_with_images) return formatted_messages @@ -279,9 +452,9 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic def format_request_messages( cls, messages: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format an OpenAI compatible messages array. @@ -298,7 +471,7 @@ def format_request_messages( formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) formatted_messages.extend(cls._format_regular_messages(messages)) - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + return [message for message in formatted_messages if "content" in message or "tool_calls" in message] def format_request( self, @@ -406,13 +579,19 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: return {"messageStop": {"stopReason": "end_turn"}} case "metadata": + usage_data: Usage = { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + } + + if tokens_details := getattr(event["data"], "prompt_tokens_details", None): + if cached := getattr(tokens_details, "cached_tokens", None): + usage_data["cacheReadInputTokens"] = cached + return { "metadata": { - "usage": { - "inputTokens": event["data"].prompt_tokens, - "outputTokens": event["data"].completion_tokens, - "totalTokens": event["data"].total_tokens, - }, + "usage": usage_data, "metrics": { "latencyMs": 0, # TODO }, @@ -422,12 +601,39 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + @asynccontextmanager + async def _get_client(self) -> AsyncIterator[Any]: + """Get an OpenAI client for making requests. + + This context manager handles client lifecycle management: + - If an injected client was provided during initialization, it yields that client + without closing it (caller manages lifecycle). + - Otherwise, creates a new AsyncOpenAI client from client_args and automatically + closes it when the context exits. + + Note: We create a new client per request to avoid connection sharing in the underlying + httpx client, as the asyncio event loop does not allow connections to be shared. + For more details, see https://github.com/encode/httpx/discussions/2959. + + Yields: + Client: An OpenAI-compatible client instance. + """ + if self._custom_client is not None: + # Use the injected client (caller manages lifecycle) + yield self._custom_client + else: + # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying + # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please + # refer to https://github.com/encode/httpx/discussions/2959. + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: + yield client + @override async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -457,7 +663,7 @@ async def stream( # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to # https://github.com/encode/httpx/discussions/2959. - async with openai.AsyncOpenAI(**self.client_args) as client: + async with self._get_client() as client: try: response = await client.chat.completions.create(**request) except openai.BadRequestError as e: @@ -472,6 +678,14 @@ async def stream( # Rate limits (including TPM) require waiting/retrying, not context reduction logger.warning("OpenAI threw rate limit error") raise ModelThrottledException(str(e)) from e + except openai.APIError as e: + # Check for alternative context overflow error messages + error_message = str(e) + if any(overflow_msg in error_message for overflow_msg in _CONTEXT_OVERFLOW_MESSAGES): + logger.warning("context window overflow error detected") + raise ContextWindowOverflowException(error_message) from e + # Re-raise other APIError exceptions + raise logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) @@ -556,8 +770,8 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: @@ -576,7 +790,7 @@ async def structured_output( # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to # https://github.com/encode/httpx/discussions/2959. - async with openai.AsyncOpenAI(**self.client_args) as client: + async with self._get_client() as client: try: response: ParsedChatCompletion = await client.beta.chat.completions.parse( model=self.get_config()["model_id"], @@ -595,6 +809,14 @@ async def structured_output( # Rate limits (including TPM) require waiting/retrying, not context reduction logger.warning("OpenAI threw rate limit error") raise ModelThrottledException(str(e)) from e + except openai.APIError as e: + # Check for alternative context overflow error messages + error_message = str(e) + if any(overflow_msg in error_message for overflow_msg in _CONTEXT_OVERFLOW_MESSAGES): + logger.warning("context window overflow error detected") + raise ContextWindowOverflowException(error_message) from e + # Re-raise other APIError exceptions + raise parsed: T | None = None # Find the first choice with tool_calls diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py new file mode 100644 index 000000000..8914fb01c --- /dev/null +++ b/src/strands/models/openai_responses.py @@ -0,0 +1,854 @@ +"""OpenAI model provider using the Responses API. + +Built-in tools (e.g. web_search, file_search, code_interpreter) can be passed via the +``params`` configuration and will be merged with any agent function tools in the request. + +All built-in tools produce text responses that stream correctly. Limitations on tool-specific +metadata: + +- web_search (supported): Full support including URL citations. +- file_search (partial): File citation annotations not emitted (no matching CitationLocation variant). +- code_interpreter (partial): Executed code and stdout/stderr not surfaced. +- mcp (partial): Approval flow and ``mcp_list_tools``/``mcp_call`` events not surfaced. +- shell (partial): Local (client-executed) mode not supported. +- tool_search (not supported): Requires ``defer_loading`` on function tools, which is not supported. +- image_generation (not supported): Requires image content block delta support in the event loop. +- computer_use_preview (not supported): Requires a developer-managed screenshot/action loop. + +Docs: https://platform.openai.com/docs/api-reference/responses +""" + +import base64 +import json +import logging +import mimetypes +from collections.abc import AsyncGenerator +from importlib.metadata import version as get_package_version +from types import SimpleNamespace +from typing import Any, Protocol, TypedDict, TypeVar, cast + +from packaging.version import Version +from pydantic import BaseModel +from typing_extensions import Unpack, override + +# Validate OpenAI SDK version at import time - Responses API requires v2.0.0+ +# A major version bump is proposed in https://github.com/strands-agents/sdk-python/pull/1370 +_MIN_OPENAI_VERSION = Version("2.0.0") + +try: + _openai_version = Version(get_package_version("openai")) + if _openai_version < _MIN_OPENAI_VERSION: + raise ImportError( + f"OpenAIResponsesModel requires openai>={_MIN_OPENAI_VERSION} (found {_openai_version}). " + "Install/upgrade with: pip install -U openai. " + "For older SDKs, use OpenAIModel (Chat Completions)." + ) +except ImportError: + # Re-raise ImportError as-is (covers both our explicit raise above and missing openai package) + raise +except Exception as e: + raise ImportError( + f"OpenAIResponsesModel requires openai>={_MIN_OPENAI_VERSION}. Install with: pip install -U openai" + ) from e + +import openai # noqa: E402 - must import after version check + +from ..types.citations import WebLocationDict # noqa: E402 +from ..types.content import ContentBlock, Messages, Role, SystemContentBlock # noqa: E402 +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 +from ..types.streaming import StreamEvent # noqa: E402 +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 +from ._defaults import resolve_config_metadata # noqa: E402 +from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args # noqa: E402 +from ._validation import validate_config_keys # noqa: E402 +from .model import BaseModelConfig, Model # noqa: E402 + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + +# Maximum file size for media content in tool results (20MB) +_MAX_MEDIA_SIZE_BYTES = 20 * 1024 * 1024 +_MAX_MEDIA_SIZE_LABEL = "20MB" +_DEFAULT_MIME_TYPE = "application/octet-stream" +_CONTEXT_WINDOW_OVERFLOW_MSG = "OpenAI Responses API threw context window overflow error" +_RATE_LIMIT_MSG = "OpenAI Responses API threw rate limit error" + + +def _encode_media_to_data_url(data: bytes, format_ext: str, media_type: str = "image") -> str: + """Encode media bytes to a base64 data URL with size validation. + + Args: + data: Raw bytes of the media content. + format_ext: File format extension (e.g., "png", "pdf"). + media_type: Type of media for error messages ("image" or "document"). + + Returns: + Base64-encoded data URL string. + + Raises: + ValueError: If the media size exceeds the maximum allowed size. + """ + if len(data) > _MAX_MEDIA_SIZE_BYTES: + raise ValueError( + f"{media_type.capitalize()} size {len(data)} bytes exceeds maximum of" + f" {_MAX_MEDIA_SIZE_BYTES} bytes ({_MAX_MEDIA_SIZE_LABEL})" + ) + mime_type = mimetypes.types_map.get(f".{format_ext}", _DEFAULT_MIME_TYPE) + encoded_data = base64.b64encode(data).decode("utf-8") + return f"data:{mime_type};base64,{encoded_data}" + + +class _ToolCallInfo(TypedDict): + """Internal type for tracking tool call information during streaming.""" + + name: str + arguments: str + call_id: str + item_id: str + + +class Client(Protocol): + """Protocol defining the OpenAI Responses API interface for the underlying provider client.""" + + @property + # pragma: no cover + def responses(self) -> Any: + """Responses interface.""" + ... + + +class OpenAIResponsesModel(Model): + """OpenAI Responses API model provider implementation.""" + + client: Client + client_args: dict[str, Any] + + class OpenAIResponsesConfig(BaseModelConfig, total=False): + """Configuration options for OpenAI Responses API models. + + Attributes: + model_id: Model ID (e.g., "gpt-4o"). + For a complete list of supported models, see https://platform.openai.com/docs/models. + params: Model parameters (e.g., max_output_tokens, temperature, etc.). + For a complete list of supported parameters, see + https://platform.openai.com/docs/api-reference/responses/create. + stateful: Whether to enable server-side conversation state management. + When True, the server stores conversation history and the client does not need to + send the full message history with each request. Defaults to False. + use_native_token_count: Whether to use the native OpenAI input_tokens.count API. + When True, count_tokens() calls the OpenAI API for accurate counts. + When False (default), skips the API call and uses the local estimator. + """ + + model_id: str + params: dict[str, Any] | None + stateful: bool + use_native_token_count: bool + + def __init__( + self, + client_args: dict[str, Any] | None = None, + bedrock_mantle_config: BedrockMantleConfig | None = None, + **model_config: Unpack[OpenAIResponsesConfig], + ) -> None: + """Initialize provider instance. + + Args: + client_args: Arguments for the OpenAI client. + For a complete list of supported arguments, see https://pypi.org/project/openai/. + May be combined with ``bedrock_mantle_config``; when both are set, the config + derives ``base_url`` and ``api_key`` (which must not appear in ``client_args``). + bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle + (OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted + keys. When set, a fresh bearer token is minted on every request. + **model_config: Configuration options for the OpenAI Responses API model. + """ + validate_config_keys(model_config, self.OpenAIResponsesConfig) + self.config = dict(model_config) + + self.client_args = client_args or {} + self._bedrock_mantle_config = bedrock_mantle_config + + if bedrock_mantle_config is not None and client_args: + conflicting = [k for k in ("api_key", "base_url") if k in client_args] + if conflicting: + raise ValueError( + f"client_args must not contain {conflicting} when bedrock_mantle_config is set; " + "these are derived from the Mantle config automatically." + ) + + logger.debug("config=<%s> | initializing", self.config) + + def _resolve_client_args(self) -> dict[str, Any]: + """Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request. + + Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set. + """ + if self._bedrock_mantle_config is not None: + return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args) + return self.client_args + + @property + @override + def stateful(self) -> bool: + """Whether server-side conversation storage is enabled. + + Derived from the ``stateful`` configuration option. + """ + return bool(self.config.get("stateful")) + + @override + def update_config(self, **model_config: Unpack[OpenAIResponsesConfig]) -> None: # type: ignore[override] + """Update the OpenAI Responses API model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + validate_config_keys(model_config, self.OpenAIResponsesConfig) + self.config.update(model_config) + + @override + def get_config(self) -> OpenAIResponsesConfig: + """Get the OpenAI Responses API model configuration. + + Returns: + The OpenAI Responses API model configuration. + """ + return cast( + OpenAIResponsesModel.OpenAIResponsesConfig, + resolve_config_metadata(self.config, str(self.config.get("model_id", ""))), + ) + + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using the OpenAI Responses API input_tokens.count endpoint. + + Uses the same message format as the Responses API to get accurate token counts + directly from the OpenAI service. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. + """ + if self.config.get("use_native_token_count") is not True: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + + try: + # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, + # matching the behavior of stream(). The caller always provides system_prompt alongside + # system_prompt_content, so the plain string is always available. + request = self._format_request(messages, tool_specs, system_prompt) + # Keep only fields accepted by input_tokens.count + count_tokens_fields = {"model", "input", "instructions", "tools"} + request = {k: request[k] for k in request.keys() & count_tokens_fields} + + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: + response = await client.responses.input_tokens.count(**request) + total_tokens: int = response.input_tokens + + logger.debug( + "model_id=<%s>, total_tokens=<%d> | native token count", + self.config["model_id"], + total_tokens, + ) + return total_tokens + except Exception as e: + logger.debug( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + self.config["model_id"], + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + + @override + async def stream( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + *, + tool_choice: ToolChoice | None = None, + model_state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the OpenAI Responses API model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. + model_state: Runtime state for model providers (e.g., server-side response ids). + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the request is throttled by OpenAI (rate limits). + """ + logger.debug("formatting request for OpenAI Responses API") + request = self._format_request(messages, tool_specs, system_prompt, tool_choice, model_state) + logger.debug("formatted request=<%s>", request) + + logger.debug("invoking OpenAI Responses API model") + + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: + try: + response = await client.responses.create(**request) + + logger.debug("streaming response from OpenAI Responses API model") + + yield self._format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[str, _ToolCallInfo] = {} + final_usage = None + data_type: str | None = None + stop_reason: str | None = None + + async for event in response: + if hasattr(event, "type"): + if event.type == "response.created": + # Capture response id for server-side conversation chaining + if hasattr(event, "response"): + response_id = getattr(event.response, "id", None) + if model_state is not None and response_id: + model_state["response_id"] = response_id + + elif event.type in ( + "response.reasoning_text.delta", + "response.reasoning_summary_text.delta", + ): + # Reasoning content streaming: + # - reasoning_text: full chain-of-thought (gpt-oss models) + # - reasoning_summary_text: condensed summary (o-series models) + chunks, data_type = self._stream_switch_content("reasoning_content", data_type) + for chunk in chunks: + yield chunk + if hasattr(event, "delta") and isinstance(event.delta, str): + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": event.delta, + } + ) + + elif event.type == "response.output_text.delta": + # Text content streaming + chunks, data_type = self._stream_switch_content("text", data_type) + for chunk in chunks: + yield chunk + if hasattr(event, "delta") and isinstance(event.delta, str): + yield self._format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": event.delta} + ) + + elif event.type == "response.output_text.annotation.added": + if hasattr(event, "annotation"): + if event.annotation.get("type") == "url_citation": + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "citation", + "data": event.annotation, + } + ) + else: + logger.warning( + "annotation_type=<%s> | unsupported annotation type", + event.annotation.get("type"), + ) + + elif event.type == "response.output_item.added": + # Tool call started + if ( + hasattr(event, "item") + and hasattr(event.item, "type") + and event.item.type == "function_call" + ): + call_id = getattr(event.item, "call_id", "unknown") + tool_calls[call_id] = { + "name": getattr(event.item, "name", ""), + "arguments": "", + "call_id": call_id, + "item_id": getattr(event.item, "id", ""), + } + + elif event.type == "response.function_call_arguments.delta": + # Tool arguments streaming - accumulate deltas by item_id + if hasattr(event, "delta") and hasattr(event, "item_id"): + for _call_id, call_info in tool_calls.items(): + if call_info["item_id"] == event.item_id: + call_info["arguments"] += event.delta + break + + elif event.type == "response.function_call_arguments.done": + # Tool arguments complete - use final arguments as source of truth + if hasattr(event, "arguments") and hasattr(event, "item_id"): + for _call_id, call_info in tool_calls.items(): + if call_info["item_id"] == event.item_id: + call_info["arguments"] = event.arguments + break + + elif event.type == "response.incomplete": + # Response stopped early (e.g., max tokens reached) + if hasattr(event, "response"): + if hasattr(event.response, "usage"): + final_usage = event.response.usage + # Check if stopped due to max_output_tokens + if ( + hasattr(event.response, "incomplete_details") + and event.response.incomplete_details + and getattr(event.response.incomplete_details, "reason", None) + == "max_output_tokens" + ): + stop_reason = "length" + break + + elif event.type == "response.completed": + # Response complete + if hasattr(event, "response") and hasattr(event.response, "usage"): + final_usage = event.response.usage + break + except openai.APIError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning(_CONTEXT_WINDOW_OVERFLOW_MSG) + raise ContextWindowOverflowException(str(e)) from e + if isinstance(e, openai.RateLimitError): + logger.warning(_RATE_LIMIT_MSG) + raise ModelThrottledException(str(e)) from e + raise + + # Close current content block if we had any + if data_type: + yield self._format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + + # Emit tool calls with complete arguments. + # We emit a single delta per tool containing the full arguments rather than streaming + # incremental argument deltas. The Responses API streams argument chunks via separate + # events (response.function_call_arguments.delta) which we accumulate above, then use + # the final arguments from response.function_call_arguments.done. This approach ensures + # we emit valid, complete JSON arguments rather than partial fragments. + for call_info in tool_calls.values(): + tool_call = SimpleNamespace( + function=SimpleNamespace(name=call_info["name"], arguments=call_info["arguments"]), + id=call_info["call_id"], + ) + + yield self._format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call}) + yield self._format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call}) + yield self._format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + # Determine finish reason: tool_calls > max_tokens (length) > normal stop + if tool_calls: + finish_reason = "tool_calls" + elif stop_reason == "length": + finish_reason = "length" + else: + finish_reason = "stop" + yield self._format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + if final_usage: + yield self._format_chunk({"chunk_type": "metadata", "data": final_usage}) + + logger.debug("finished streaming response from OpenAI Responses API model") + + @override + async def structured_output( + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: + """Get structured output from the OpenAI Responses API model. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the request is throttled by OpenAI (rate limits). + """ + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: + try: + response = await client.responses.parse( + model=self.get_config()["model_id"], + input=self._format_request(prompt, system_prompt=system_prompt)["input"], + text_format=output_model, + ) + except openai.BadRequestError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning(_CONTEXT_WINDOW_OVERFLOW_MSG) + raise ContextWindowOverflowException(str(e)) from e + raise + except openai.RateLimitError as e: + logger.warning(_RATE_LIMIT_MSG) + raise ModelThrottledException(str(e)) from e + + if response.output_parsed: + yield {"output": response.output_parsed} + else: + raise ValueError("No valid parsed output found in the OpenAI Responses API response.") + + def _format_request( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: ToolChoice | None = None, + model_state: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Format an OpenAI Responses API compatible response streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. + model_state: Runtime state for model providers (e.g., server-side response ids). + + Returns: + An OpenAI Responses API compatible response streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible + format. + """ + input_items = self._format_request_messages(messages) + request: dict[str, Any] = { + "model": self.config["model_id"], + "input": input_items, + "stream": True, + **cast(dict[str, Any], self.config.get("params", {})), + "store": self.stateful, + } + + response_id = model_state.get("response_id") if model_state else None + if response_id and self.stateful: + request["previous_response_id"] = response_id + + if system_prompt: + request["instructions"] = system_prompt + + # Add tools if provided + if tool_specs: + # Merge with any built-in tools (e.g. web_search) already in the request from params + request.setdefault("tools", []).extend( + { + "type": "function", + "name": tool_spec["name"], + "description": tool_spec.get("description", ""), + "parameters": tool_spec["inputSchema"]["json"], + } + for tool_spec in tool_specs + ) + request.update(self._format_request_tool_choice(tool_choice)) + + return request + + @classmethod + def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]: + """Format a tool choice for OpenAI Responses API compatibility. + + Args: + tool_choice: Tool choice configuration. + + Returns: + OpenAI Responses API compatible tool choice format. + """ + if not tool_choice: + return {} + + match tool_choice: + case {"auto": _}: + return {"tool_choice": "auto"} + case {"any": _}: + return {"tool_choice": "required"} + case {"tool": {"name": tool_name}}: + return {"tool_choice": {"type": "function", "name": tool_name}} + case _: + # Default to auto for unknown formats + return {"tool_choice": "auto"} + + @classmethod + def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages: list[dict[str, Any]] = [] + + for message in messages: + role = message["role"] + contents = message["content"] + + if any("reasoningContent" in content for content in contents): + logger.warning( + "reasoningContent is not yet supported in multi-turn conversations with the Responses API" + ) + + formatted_contents = [ + cls._format_request_message_content(content, role=role) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]) + ] + + formatted_tool_calls = [ + cls._format_request_message_tool_call(content["toolUse"]) + for content in contents + if "toolUse" in content + ] + + formatted_tool_messages = [ + cls._format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + if formatted_contents: + formatted_messages.append( + { + "role": role, # "user" | "assistant" + "content": formatted_contents, + } + ) + + formatted_messages.extend(formatted_tool_calls) + formatted_messages.extend(formatted_tool_messages) + + return [ + message + for message in formatted_messages + if message.get("content") or message.get("type") in ["function_call", "function_call_output"] + ] + + @classmethod + def _format_request_message_content(cls, content: ContentBlock, *, role: Role = "user") -> dict[str, Any]: + """Format an OpenAI compatible content block. + + Args: + content: Message content. + role: Message role ("user" or "assistant"). Controls text content + type: "input_text" for user, "output_text" for assistant. + + Returns: + OpenAI compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to an OpenAI-compatible format. + ValueError: If the image or document size exceeds the maximum allowed size (20MB). + """ + if "document" in content: + doc = content["document"] + data_url = _encode_media_to_data_url(doc["source"]["bytes"], doc["format"], "document") + return {"type": "input_file", "file_url": data_url} + + if "image" in content: + img = content["image"] + data_url = _encode_media_to_data_url(img["source"]["bytes"], img["format"], "image") + return {"type": "input_image", "image_url": data_url} + + if "text" in content: + text_type = "output_text" if role == "assistant" else "input_text" + return {"type": text_type, "text": content["text"]} + + if "citationsContent" in content: + text = "".join(c["text"] for c in content["citationsContent"].get("content", []) if "text" in c) + text_type = "output_text" if role == "assistant" else "input_text" + return {"type": text_type, "text": text} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def _format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + """Format an OpenAI compatible tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + OpenAI compatible tool call. + """ + return { + "type": "function_call", + "call_id": tool_use["toolUseId"], + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + } + + @classmethod + def _format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format an OpenAI compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + OpenAI compatible tool message. + + Raises: + ValueError: If the image or document size exceeds the maximum allowed size (20MB). + + Note: + The Responses API's function_call_output can be either a string (typically JSON encoded) + or an array of content objects when returning images/files. + See: https://platform.openai.com/docs/guides/function-calling + """ + output_parts: list[dict[str, Any]] = [] + has_media = False + + for content in tool_result["content"]: + if "json" in content: + output_parts.append({"type": "input_text", "text": json.dumps(content["json"])}) + elif "text" in content: + output_parts.append({"type": "input_text", "text": content["text"]}) + elif "image" in content: + has_media = True + img = content["image"] + data_url = _encode_media_to_data_url(img["source"]["bytes"], img["format"], "image") + output_parts.append({"type": "input_image", "image_url": data_url}) + elif "document" in content: + has_media = True + doc = content["document"] + data_url = _encode_media_to_data_url(doc["source"]["bytes"], doc["format"], "document") + output_parts.append({"type": "input_file", "file_url": data_url}) + + # Return array if has media content, otherwise join as string for simpler text-only cases + output: list[dict[str, Any]] | str + if has_media: + output = output_parts + else: + output = "\n".join(part.get("text", "") for part in output_parts) if output_parts else "" + + return { + "type": "function_call_output", + "call_id": tool_result["toolUseId"], + "output": output, + } + + def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]: + """Handle switching to a new content stream. + + Args: + data_type: The next content data type. + prev_data_type: The previous content data type. + + Returns: + Tuple containing: + - Stop block for previous content and the start block for the next content. + - Next content data type. + """ + chunks: list[StreamEvent] = [] + if data_type != prev_data_type: + if prev_data_type is not None: + chunks.append(self._format_chunk({"chunk_type": "content_stop", "data_type": prev_data_type})) + chunks.append(self._format_chunk({"chunk_type": "content_start", "data_type": data_type})) + + return chunks, data_type + + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. + """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + + if event["data_type"] == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + + if event["data_type"] == "citation": + web_location: WebLocationDict = {"web": {"url": event["data"].get("url", "")}} + return { + "contentBlockDelta": { + "delta": { + "citation": { + "title": event["data"].get("title", ""), + "location": web_location, + } + } + } + } + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + # Responses API uses input_tokens/output_tokens naming convention + return { + "metadata": { + "usage": { + "inputTokens": getattr(event["data"], "input_tokens", 0), + "outputTokens": getattr(event["data"], "output_tokens", 0), + "totalTokens": getattr(event["data"], "total_tokens", 0), + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 7f8b8ff51..0d206fd0b 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -3,8 +3,9 @@ import json import logging import os +from collections.abc import AsyncGenerator from dataclasses import dataclass -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, Literal, TypedDict, TypeVar import boto3 from botocore.config import Config as BotocoreConfig @@ -16,6 +17,7 @@ from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from .model import BaseModelConfig from .openai import OpenAIModel T = TypeVar("T", bound=BaseModel) @@ -37,7 +39,7 @@ class UsageMetadata: total_tokens: int completion_tokens: int prompt_tokens: int - prompt_tokens_details: Optional[int] = 0 + prompt_tokens_details: int | None = 0 @dataclass @@ -49,8 +51,8 @@ class FunctionCall: arguments: Arguments to pass to the function """ - name: Union[str, dict[Any, Any]] - arguments: Union[str, dict[Any, Any]] + name: str | dict[Any, Any] + arguments: str | dict[Any, Any] def __init__(self, **kwargs: dict[str, str]): """Initialize function call. @@ -108,14 +110,14 @@ class SageMakerAIPayloadSchema(TypedDict, total=False): max_tokens: int stream: bool - temperature: Optional[float] - top_p: Optional[float] - top_k: Optional[int] - stop: Optional[list[str]] - tool_results_as_user_messages: Optional[bool] - additional_args: Optional[dict[str, Any]] - - class SageMakerAIEndpointConfig(TypedDict, total=False): + temperature: float | None + top_p: float | None + top_k: int | None + stop: list[str] | None + tool_results_as_user_messages: bool | None + additional_args: dict[str, Any] | None + + class SageMakerAIEndpointConfig(BaseModelConfig, total=False): """Configuration options for SageMaker models. Attributes: @@ -127,17 +129,17 @@ class SageMakerAIEndpointConfig(TypedDict, total=False): endpoint_name: str region_name: str - inference_component_name: Union[str, None] - target_model: Union[Optional[str], None] - target_variant: Union[Optional[str], None] - additional_args: Optional[dict[str, Any]] + inference_component_name: str | None + target_model: str | None | None + target_variant: str | None | None + additional_args: dict[str, Any] | None def __init__( self, endpoint_config: SageMakerAIEndpointConfig, payload_config: SageMakerAIPayloadSchema, - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, ): """Initialize provider instance. @@ -199,8 +201,8 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> dict[str, Any]: @@ -300,8 +302,8 @@ def format_request( async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -353,7 +355,7 @@ async def stream( logger.info("choice=<%s>", json.dumps(choice, indent=2)) # Handle text content - if choice["delta"].get("content", None): + if choice["delta"].get("content"): if not text_content_started: yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) text_content_started = True @@ -367,7 +369,7 @@ async def stream( ) # Handle reasoning content - if choice["delta"].get("reasoning_content", None): + if choice["delta"].get("reasoning_content"): if not reasoning_content_started: yield self.format_chunk( {"chunk_type": "content_start", "data_type": "reasoning_content"} @@ -392,7 +394,7 @@ async def stream( finish_reason = choice["finish_reason"] break - if choice.get("usage", None): + if choice.get("usage"): yield self.format_chunk( {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])} ) @@ -412,7 +414,7 @@ async def stream( # Handle tool calling logger.info("tool_calls=<%s>", json.dumps(tool_calls, indent=2)) for tool_deltas in tool_calls.values(): - if not tool_deltas[0]["function"].get("name", None): + if not tool_deltas[0]["function"].get("name"): raise Exception("The model did not provide a tool name.") yield self.format_chunk( {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])} @@ -453,7 +455,7 @@ async def stream( yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) # Handle reasoning content - if message.get("reasoning_content", None): + if message.get("reasoning_content"): yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"}) yield self.format_chunk( { @@ -465,7 +467,7 @@ async def stream( yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"}) # Handle the tool calling, if any - if message.get("tool_calls", None) or message_stop_reason == "tool_calls": + if message.get("tool_calls") or message_stop_reason == "tool_calls": if not isinstance(message["tool_calls"], list): message["tool_calls"] = [message["tool_calls"]] for tool_call in message["tool_calls"]: @@ -484,9 +486,9 @@ async def stream( # Message close yield self.format_chunk({"chunk_type": "message_stop", "data": message_stop_reason}) # Handle usage metadata - if final_response_json.get("usage", None): + if final_response_json.get("usage"): yield self.format_chunk( - {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))} + {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage"))} ) except ( self.client.exceptions.InternalFailure, @@ -556,7 +558,7 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""), "type": "thinking", } - elif not content.get("reasoningContent", None): + elif not content.get("reasoningContent"): content.pop("reasoningContent", None) if "video" in content: @@ -572,8 +574,8 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index a54fc44c3..3e3276106 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import writerai from pydantic import BaseModel @@ -17,8 +18,8 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -28,7 +29,7 @@ class WriterModel(Model): """Writer API model provider implementation.""" - class WriterConfig(TypedDict, total=False): + class WriterConfig(BaseModelConfig, total=False): """Configuration options for Writer API. Attributes: @@ -41,13 +42,13 @@ class WriterConfig(TypedDict, total=False): """ model_id: str - max_tokens: Optional[int] - stop: Optional[Union[str, List[str]]] - stream_options: Dict[str, Any] - temperature: Optional[float] - top_p: Optional[float] + max_tokens: int | None + stop: str | list[str] | None + stream_options: dict[str, Any] + temperature: float | None + top_p: float | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[WriterConfig]): + def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Unpack[WriterConfig]): """Initialize provider instance. Args: @@ -201,7 +202,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "content": formatted_contents, } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a Writer compatible messages array. Args: @@ -217,11 +218,21 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s for message in messages: contents = message["content"] + # Filter out location sources + filtered_contents = [] + for content in contents: + if _has_location_source(content): + logger.warning("Location sources are not supported by Writer | skipping content block") + continue + filtered_contents.append(content) + # Only palmyra V5 support multiple content. Other models support only '{"content": "text_content"}' if self.get_config().get("model_id", "") == "palmyra-x5": - formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision(contents) + formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision( + filtered_contents + ) else: - formatted_contents = self._format_request_message_contents(contents) + formatted_contents = self._format_request_message_contents(filtered_contents) formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) @@ -245,7 +256,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> Any: """Format a streaming request to the underlying model. @@ -353,8 +364,8 @@ def format_chunk(self, event: Any) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -431,8 +442,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py index e251e9318..ad99944a8 100644 --- a/src/strands/multiagent/__init__.py +++ b/src/strands/multiagent/__init__.py @@ -8,7 +8,7 @@ standardized communication between agents. """ -from .base import MultiAgentBase, MultiAgentResult +from .base import MultiAgentBase, MultiAgentResult, Status from .graph import GraphBuilder, GraphResult from .swarm import Swarm, SwarmResult @@ -17,6 +17,7 @@ "GraphResult", "MultiAgentBase", "MultiAgentResult", + "Status", "Swarm", "SwarmResult", ] diff --git a/src/strands/multiagent/a2a/_converters.py b/src/strands/multiagent/a2a/_converters.py new file mode 100644 index 000000000..7808ae325 --- /dev/null +++ b/src/strands/multiagent/a2a/_converters.py @@ -0,0 +1,179 @@ +"""Conversion functions between Strands and A2A types.""" + +from typing import cast +from uuid import uuid4 + +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent, TextPart + +from ...agent.agent_result import AgentResult +from ...telemetry.metrics import EventLoopMetrics +from ...types.a2a import A2AResponse +from ...types.agent import AgentInput +from ...types.content import ContentBlock, Message +from ...types.event_loop import StopReason + +# Mapping from A2A TaskState to Strands stop_reason +_STATE_TO_STOP_REASON: dict[TaskState, StopReason] = { + TaskState.completed: "end_turn", + TaskState.failed: "end_turn", + TaskState.canceled: "end_turn", + TaskState.rejected: "end_turn", + TaskState.input_required: "interrupt", + TaskState.auth_required: "interrupt", +} + + +def convert_input_to_message(prompt: AgentInput) -> A2AMessage: + """Convert AgentInput to A2A Message. + + Args: + prompt: Input in various formats (string, message list, or content blocks). + + Returns: + A2AMessage ready to send to the remote agent. + + Raises: + ValueError: If prompt format is unsupported. + """ + message_id = uuid4().hex + + if isinstance(prompt, str): + return A2AMessage( + kind="message", + role=Role.user, + parts=[Part(TextPart(kind="text", text=prompt))], + message_id=message_id, + ) + + if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)): + # Check for interrupt responses - not supported in A2A + if "interruptResponse" in prompt[0]: + raise ValueError("InterruptResponseContent is not supported for A2AAgent") + + if "role" in prompt[0]: + for msg in reversed(prompt): + if msg.get("role") == "user": + content = cast(list[ContentBlock], msg.get("content", [])) + parts = convert_content_blocks_to_parts(content) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + else: + parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt)) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + + raise ValueError(f"Unsupported input type: {type(prompt)}") + + +def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]: + """Convert Strands ContentBlocks to A2A Parts. + + Args: + content_blocks: List of Strands content blocks. + + Returns: + List of A2A Part objects. + """ + parts = [] + for block in content_blocks: + if "text" in block: + parts.append(Part(TextPart(kind="text", text=block["text"]))) + return parts + + +def _extract_task_state(response: A2AResponse) -> TaskState | None: + """Extract the task state from an A2A response. + + Args: + response: A2A response (either A2AMessage or tuple of task and update event). + + Returns: + The TaskState if available, None otherwise. + """ + if isinstance(response, tuple) and len(response) == 2: + _task, update_event = response + if isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "state"): + return update_event.status.state + return None + + +def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: + """Convert A2A response to AgentResult. + + Maps A2A task lifecycle states to appropriate Strands stop_reasons: + - completed → end_turn + - failed → end_turn (with error content) + - canceled → end_turn (with cancellation info) + - rejected → end_turn (with rejection info) + - input_required → interrupt (agent needs user input) + - auth_required → interrupt (agent needs authentication) + + Args: + response: A2A response (either A2AMessage or tuple of task and update event). + + Returns: + AgentResult with extracted content and metadata. + """ + content: list[ContentBlock] = [] + task_state = _extract_task_state(response) + stop_reason: StopReason = _STATE_TO_STOP_REASON.get(task_state, "end_turn") if task_state else "end_turn" + + if isinstance(response, tuple) and len(response) == 2: + task, update_event = response + + # Handle artifact updates + if isinstance(update_event, TaskArtifactUpdateEvent): + if update_event.artifact and hasattr(update_event.artifact, "parts") and update_event.artifact.parts: + for part in update_event.artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + # Handle status updates with messages + elif isinstance(update_event, TaskStatusUpdateEvent): + if ( + update_event.status + and hasattr(update_event.status, "message") + and update_event.status.message + and update_event.status.message.parts + ): + for part in update_event.status.message.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + + # Use task.artifacts when no content was extracted from the event + if not content and task and hasattr(task, "artifacts") and task.artifacts is not None: + for artifact in task.artifacts: + if hasattr(artifact, "parts") and artifact.parts: + for part in artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + elif isinstance(response, A2AMessage): + for part in response.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + + message: Message = { + "role": "assistant", + "content": content, + } + + # Build state dict with A2A metadata + state: dict[str, str] = {} + if task_state is not None: + state["a2a_task_state"] = task_state.value + + return AgentResult( + stop_reason=stop_reason, + message=message, + metrics=EventLoopMetrics(), + state=state, + ) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 52b6d2ef1..7526386e8 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -8,10 +8,13 @@ streamed requests to the A2AServer. """ +import asyncio import base64 import json import logging import mimetypes +import uuid +import warnings from typing import Any, Literal from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -40,7 +43,9 @@ class StrandsA2AExecutor(AgentExecutor): """Executor that adapts a Strands Agent to the A2A protocol. This executor uses streaming mode to handle the execution of agent requests - and converts Strands Agent responses to A2A protocol events. + and converts Strands Agent responses to A2A protocol events. It supports the + full A2A task lifecycle including error handling (failed state), cancellation, + and interrupt-based input_required flows. """ # Default formats for each file type when MIME type is unavailable or unrecognized @@ -49,13 +54,21 @@ class StrandsA2AExecutor(AgentExecutor): # Handle special cases where format differs from extension FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"} - def __init__(self, agent: SAAgent): + # A2A-compliant streaming mode + _current_artifact_id: str | None + _is_first_chunk: bool + + def __init__(self, agent: SAAgent, *, enable_a2a_compliant_streaming: bool = False): """Initialize a StrandsA2AExecutor. Args: agent: The Strands Agent instance to adapt to the A2A protocol. + enable_a2a_compliant_streaming: If True, uses A2A-compliant streaming with + artifact updates. If False, uses legacy status updates streaming behavior + for backwards compatibility. Defaults to False. """ self.agent = agent + self.enable_a2a_compliant_streaming = enable_a2a_compliant_streaming async def execute( self, @@ -65,14 +78,18 @@ async def execute( """Execute a request using the Strands Agent and send the response as A2A events. This method executes the user's input using the Strands Agent in streaming mode - and converts the agent's response to A2A events. + and converts the agent's response to A2A events. If the agent raises an exception, + the task transitions to the `failed` state. If the agent returns with interrupts, + the task transitions to the `input_required` state. Args: context: The A2A request context, containing the user's input and task metadata. event_queue: The A2A event queue used to send response events back to the client. Raises: - ServerError: If an error occurs during agent execution + ServerError: If an unrecoverable error occurs during agent execution setup + (e.g., missing input). Agent execution errors are handled gracefully + by transitioning the task to the failed state. """ task = context.current_task if not task: @@ -83,8 +100,34 @@ async def execute( try: await self._execute_streaming(context, updater) - except Exception as e: - raise ServerError(error=InternalError()) from e + except ServerError: + # Re-raise ServerErrors (setup failures like missing input) + raise + except asyncio.CancelledError: + # asyncio.CancelledError is a BaseException (not Exception) — raised when + # the asyncio task is cancelled (e.g., HTTP client disconnect, server shutdown). + # We transition to canceled state so the task doesn't remain a zombie in "working". + logger.warning("task_id=<%s> | asyncio task cancelled, transitioning to canceled state", task.id) + try: + await updater.cancel( + message=updater.new_agent_message( + parts=[Part(root=TextPart(text="Task cancelled due to connection termination"))] + ) + ) + except RuntimeError: + # Task already in terminal state + logger.debug("task_id=<%s> | task already in terminal state, cannot transition to canceled", task.id) + raise + except Exception: + # Agent execution failures transition to failed state + logger.exception("task_id=<%s> | agent execution failed, transitioning to failed state", task.id) + try: + await updater.failed( + message=updater.new_agent_message(parts=[Part(root=TextPart(text="Agent execution failed"))]) + ) + except RuntimeError: + # Task already in terminal state (e.g., completed before error in cleanup) + logger.debug("task_id=<%s> | task already in terminal state, cannot transition to failed", task.id) async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None: """Execute request in streaming mode. @@ -95,21 +138,88 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater Args: context: The A2A request context, containing the user's input and other metadata. updater: The task updater for managing task state and sending updates. + + Raises: + ServerError: If input conversion fails (missing or empty content). """ # Convert A2A message parts to Strands ContentBlocks if context.message and hasattr(context.message, "parts"): content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts) if not content_blocks: - raise ValueError("No content blocks available") + raise ServerError( + error=InternalError(message="No valid content found in request message parts") + ) from None else: - raise ValueError("No content blocks available") + raise ServerError(error=InternalError(message="Request message is missing or has no parts")) from None + + if not self.enable_a2a_compliant_streaming: + warnings.warn( + "The default A2A response stream implemented in the strands sdk does not conform to " + "what is expected in the A2A spec. Please set the `enable_a2a_compliant_streaming` " + "boolean to `True` on your `A2AServer` class to properly conform to the spec. " + "In the next major version release, this will be the default behavior.", + UserWarning, + stacklevel=3, + ) + + if self.enable_a2a_compliant_streaming: + self._current_artifact_id = str(uuid.uuid4()) + self._is_first_chunk = True + + # Pass the A2A RequestContext through invocation state so downstream + # tools and hooks can access request metadata, task info, configuration, etc. + invocation_state: dict[str, Any] = {"a2a_request_context": context} try: - async for event in self.agent.stream_async(content_blocks): - await self._handle_streaming_event(event, updater) + result: SAAgentResult | None = None + async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state): + if "result" in event: + result = event["result"] + else: + await self._handle_streaming_event(event, updater) + + # Check if agent returned with interrupts (input_required) + # Note: stop_reason="interrupt" is the authoritative signal. Even if interrupts + # list is empty (edge case), the agent still indicated it needs input. + if result is not None and result.stop_reason == "interrupt": + await self._handle_interrupt_result(result, updater) + else: + await self._handle_agent_result(result, updater) except Exception: logger.exception("Error in streaming execution") raise + finally: + if self.enable_a2a_compliant_streaming: + self._current_artifact_id = None + self._is_first_chunk = True + + async def _handle_interrupt_result(self, result: SAAgentResult, updater: TaskUpdater) -> None: + """Handle an agent result that contains interrupts. + + When the Strands Agent returns with stop_reason="interrupt", this maps to + the A2A `input_required` state. The interrupt details are communicated to + the client via the status message. + + Args: + result: The agent result containing interrupts. + updater: The task updater for managing task state. + """ + # Build a descriptive message about what input is needed + interrupt_descriptions = [] + for interrupt in result.interrupts or []: + desc = f"- {interrupt.name}" + if interrupt.reason: + desc += f": {interrupt.reason}" + interrupt_descriptions.append(desc) + + if interrupt_descriptions: + input_message = "Agent requires input:\n" + "\n".join(interrupt_descriptions) + else: + # Edge case: stop_reason="interrupt" but no interrupt details provided. + # Still transition to input_required — the agent signaled it needs input. + input_message = "Agent requires additional input to continue" + + await updater.requires_input(message=updater.new_agent_message(parts=[Part(root=TextPart(text=input_message))])) async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None: """Handle a single streaming event from the Strands Agent. @@ -125,28 +235,57 @@ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpda logger.debug("Streaming event: %s", event) if "data" in event: if text_content := event["data"]: - await updater.update_status( - TaskState.working, - new_agent_text_message( - text_content, - updater.context_id, - updater.task_id, - ), - ) - elif "result" in event: - await self._handle_agent_result(event["result"], updater) + if self.enable_a2a_compliant_streaming: + await updater.add_artifact( + [Part(root=TextPart(text=text_content))], + artifact_id=self._current_artifact_id, + name="agent_response", + append=not self._is_first_chunk, + ) + self._is_first_chunk = False + else: + # Legacy use update_status with agent message + await updater.update_status( + TaskState.working, + new_agent_text_message( + text_content, + updater.context_id, + updater.task_id, + ), + ) async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None: """Handle the final result from the Strands Agent. - Processes the agent's final result, extracts text content from the response, - and adds it as an artifact to the task before marking the task as complete. + For A2A-compliant streaming: sends the final artifact chunk marker and marks + the task as complete. If no data chunks were previously sent, includes the + result content. + + For legacy streaming: adds the final result as a simple artifact without + artifact_id tracking. Args: result: The agent result object containing the final response, or None if no result. updater: The task updater for managing task state and adding the final artifact. """ - if final_content := str(result): + if self.enable_a2a_compliant_streaming: + if self._is_first_chunk: + final_content = str(result) if result else "" + await updater.add_artifact( + [Part(root=TextPart(text=final_content))], + artifact_id=self._current_artifact_id, + name="agent_response", + last_chunk=True, + ) + else: + await updater.add_artifact( + [Part(root=TextPart(text=""))], + artifact_id=self._current_artifact_id, + name="agent_response", + append=True, + last_chunk=True, + ) + elif final_content := str(result): await updater.add_artifact( [Part(root=TextPart(text=final_content))], name="agent_response", @@ -156,20 +295,42 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """Cancel an ongoing execution. - This method is called when a request cancellation is requested. Currently, - cancellation is not supported by the Strands Agent executor, so this method - always raises an UnsupportedOperationError. + Transitions the task to the canceled state and attempts to stop the agent. + The agent's cancel() method is called to signal cooperative cancellation + of in-flight execution. + + Note: This transitions the A2A task state. The underlying agent execution + may still complete its current model call before stopping. Args: context: The A2A request context. event_queue: The A2A event queue. Raises: - ServerError: Always raised with an UnsupportedOperationError, as cancellation - is not currently supported. + ServerError: If no current task exists or the task is already in a terminal state. """ - logger.warning("Cancellation requested but not supported") - raise ServerError(error=UnsupportedOperationError()) + task = context.current_task + if not task: + logger.warning("context_id=<%s> | cancel requested but no current task found", context.context_id) + raise ServerError(error=UnsupportedOperationError()) from None + + # Cooperatively cancel the agent's execution (best-effort). + # Agent.cancel() is always available since self.agent is typed as Agent. + try: + self.agent.cancel() + except Exception: + logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id) + + updater = TaskUpdater(event_queue, task.id, task.context_id) + + try: + await updater.cancel( + message=updater.new_agent_message(parts=[Part(root=TextPart(text="Task cancelled by client request"))]) + ) + except RuntimeError: + # TaskUpdater raises RuntimeError when task is already in a terminal state + logger.warning("task_id=<%s> | cannot cancel, already in terminal state", task.id) + raise ServerError(error=UnsupportedOperationError()) from None def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]: """Classify file type based on MIME type. @@ -313,15 +474,13 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten elif uri_data: # For URI files, create a text representation since Strands ContentBlocks expect bytes content_blocks.append( - ContentBlock( - text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data) - ) + ContentBlock(text=f"[File: {file_name} ({mime_type})] - Referenced file at: {uri_data}") ) elif isinstance(part_root, DataPart): # Handle DataPart - convert structured data to JSON text try: data_text = json.dumps(part_root.data, indent=2) - content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text)) + content_blocks.append(ContentBlock(text=f"[Structured Data]\n{data_text}")) except Exception: logger.exception("Failed to serialize data part") except Exception: diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index bbfbc824d..fd90e9787 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -42,6 +42,7 @@ def __init__( queue_manager: QueueManager | None = None, push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, + enable_a2a_compliant_streaming: bool = False, ): """Initialize an A2A-compatible server from a Strands agent. @@ -66,6 +67,9 @@ def __init__( no push notification configuration is used. push_sender: Custom push notification sender implementation. If None, no push notifications are sent. + enable_a2a_compliant_streaming: If True, uses A2A-compliant streaming with + artifact updates. If False, uses legacy status updates streaming behavior + for backwards compatibility. Defaults to False. """ self.host = host self.port = port @@ -75,6 +79,7 @@ def __init__( # Parse the provided URL to extract components for mounting self.public_base_url, self.mount_path = self._parse_public_url(http_url) self.http_url = http_url.rstrip("/") + "/" + self._http_url_explicit = True # Override mount path if serve_at_root is requested if serve_at_root: @@ -84,13 +89,16 @@ def __init__( self.public_base_url = f"http://{host}:{port}" self.http_url = f"{self.public_base_url}/" self.mount_path = "" + self._http_url_explicit = False self.strands_agent = agent self.name = self.strands_agent.name self.description = self.strands_agent.description self.capabilities = AgentCapabilities(streaming=True) self.request_handler = DefaultRequestHandler( - agent_executor=StrandsA2AExecutor(self.strands_agent), + agent_executor=StrandsA2AExecutor( + self.strands_agent, enable_a2a_compliant_streaming=enable_a2a_compliant_streaming + ), task_store=task_store or InMemoryTaskStore(), queue_manager=queue_manager, push_config_store=push_config_store, @@ -176,16 +184,21 @@ def agent_skills(self, skills: list[AgentSkill]) -> None: """ self._agent_skills = skills - def to_starlette_app(self) -> Starlette: + def to_starlette_app(self, *, app_kwargs: dict[str, Any] | None = None) -> Starlette: """Create a Starlette application for serving this agent via HTTP. Automatically handles path-based mounting if a mount path was derived from the http_url parameter. + Args: + app_kwargs: Additional keyword arguments to pass to the Starlette constructor. + Returns: Starlette: A Starlette application configured to serve this agent. """ - a2a_app = A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + a2a_app = A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build( + **app_kwargs or {} + ) if self.mount_path: # Create parent app and mount the A2A app at the specified path @@ -196,16 +209,21 @@ def to_starlette_app(self) -> Starlette: return a2a_app - def to_fastapi_app(self) -> FastAPI: + def to_fastapi_app(self, *, app_kwargs: dict[str, Any] | None = None) -> FastAPI: """Create a FastAPI application for serving this agent via HTTP. Automatically handles path-based mounting if a mount path was derived from the http_url parameter. + Args: + app_kwargs: Additional keyword arguments to pass to the FastAPI constructor. + Returns: FastAPI: A FastAPI application configured to serve this agent. """ - a2a_app = A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + a2a_app = A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build( + **app_kwargs or {} + ) if self.mount_path: # Create parent app and mount the A2A app at the specified path @@ -237,12 +255,25 @@ def serve( port: The port number to bind the server to. Defaults to 9000. **kwargs: Additional keyword arguments to pass to uvicorn.run. """ + # Update host/port if overridden, and recalculate URLs if http_url wasn't explicitly set + if host is not None: + self.host = host + if port is not None: + self.port = port + + if host is not None or port is not None: + # Only update the URL if it wasn't explicitly set via http_url parameter + # (i.e., if the URL was auto-generated from host/port in __init__) + if not self._http_url_explicit: + self.public_base_url = f"http://{self.host}:{self.port}" + self.http_url = f"{self.public_base_url}/" + try: logger.info("Starting Strands A2A server...") if app_type == "fastapi": - uvicorn.run(self.to_fastapi_app(), host=host or self.host, port=port or self.port, **kwargs) + uvicorn.run(self.to_fastapi_app(), host=self.host, port=self.port, **kwargs) else: - uvicorn.run(self.to_starlette_app(), host=host or self.host, port=port or self.port, **kwargs) + uvicorn.run(self.to_starlette_app(), host=self.host, port=self.port, **kwargs) except KeyboardInterrupt: logger.warning("Strands A2A server shutdown requested (KeyboardInterrupt).") except Exception: diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index f163d05b5..14c4d0d14 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -6,12 +6,14 @@ import logging import warnings from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Mapping from dataclasses import dataclass, field from enum import Enum -from typing import Any, AsyncIterator, Mapping, Union +from typing import Any, Union from .._async import run_async from ..agent import AgentResult +from ..hooks.registry import HookCallback from ..interrupt import Interrupt from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput @@ -95,7 +97,7 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": raise TypeError("NodeResult.from_dict: missing 'result'") raw = data["result"] - result: Union[AgentResult, "MultiAgentResult", Exception] + result: AgentResult | MultiAgentResult | Exception if isinstance(raw, dict) and raw.get("type") == "agent_result": result = AgentResult.from_dict(raw) elif isinstance(raw, dict) and raw.get("type") == "exception": @@ -253,6 +255,20 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: """Restore orchestrator state from a session dict.""" raise NotImplementedError + def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None: + """Register a hook callback with the orchestrator. + + Subclasses that support hooks should override this method to register + the callback with their hook registry. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. + """ + raise NotImplementedError(f"{type(self).__name__} must implement add_hook() to support plugins") + def _parse_trace_attributes( self, attributes: Mapping[str, AttributeValue] | None = None ) -> dict[str, AttributeValue]: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 6156d332c..146a31563 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -18,27 +18,33 @@ import copy import logging import time +from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, cast from opentelemetry import trace as trace_api from .._async import run_async from ..agent import Agent +from ..agent.base import AgentBase from ..agent.state import AgentState -from ..experimental.hooks.multiagent import ( +from ..hooks.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks import HookProvider, HookRegistry +from ..hooks.registry import HookCallback, HookProvider, HookRegistry +from ..interrupt import Interrupt, _InterruptState +from ..plugins.multiagent_plugin import MultiAgentPlugin +from ..plugins.multiagent_registry import _MultiAgentPluginRegistry from ..session import SessionManager from ..telemetry import get_tracer from ..types._events import ( MultiAgentHandoffEvent, MultiAgentNodeCancelEvent, + MultiAgentNodeInterruptEvent, MultiAgentNodeStartEvent, MultiAgentNodeStopEvent, MultiAgentNodeStreamEvent, @@ -47,6 +53,7 @@ from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput +from ..types.session import decode_bytes_values, encode_bytes_values from ..types.traces import AttributeValue from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status @@ -63,10 +70,15 @@ class GraphState: status: Current execution status of the graph. completed_nodes: Set of nodes that have completed execution. failed_nodes: Set of nodes that failed during execution. + interrupted_nodes: Set of nodes that user interrupted during execution. execution_order: List of nodes in the order they were executed. task: The original input prompt/query provided to the graph execution. This represents the actual work to be performed by the graph as a whole. Entry point nodes receive this task as their input if they have no dependencies. + start_time: Timestamp when the current invocation started. + Resets on each invocation, even when resuming from interrupt. + execution_time: Execution time of current invocation in milliseconds. + Excludes time spent waiting for interrupt responses. """ # Task (with default empty string) @@ -76,6 +88,7 @@ class GraphState: status: Status = Status.PENDING completed_nodes: set["GraphNode"] = field(default_factory=set) failed_nodes: set["GraphNode"] = field(default_factory=set) + interrupted_nodes: set["GraphNode"] = field(default_factory=set) execution_order: list["GraphNode"] = field(default_factory=list) start_time: float = field(default_factory=time.time) @@ -90,14 +103,14 @@ class GraphState: # Graph structure info total_nodes: int = 0 - edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list) + edges: list[tuple["GraphNode", "GraphNode"]] = field(default_factory=list) entry_points: list["GraphNode"] = field(default_factory=list) def should_continue( self, - max_node_executions: Optional[int], - execution_timeout: Optional[float], - ) -> Tuple[bool, str]: + max_node_executions: int | None, + execution_timeout: float | None, + ) -> tuple[bool, str]: """Check if the graph should continue execution. Returns: (should_continue, reason) @@ -108,7 +121,7 @@ def should_continue( # Check timeout (only if set) if execution_timeout is not None: - elapsed = time.time() - self.start_time + elapsed = self.execution_time / 1000 + time.time() - self.start_time if elapsed > execution_timeout: return False, f"Execution timed out: {execution_timeout}s" @@ -122,8 +135,9 @@ class GraphResult(MultiAgentResult): total_nodes: int = 0 completed_nodes: int = 0 failed_nodes: int = 0 + interrupted_nodes: int = 0 execution_order: list["GraphNode"] = field(default_factory=list) - edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list) + edges: list[tuple["GraphNode", "GraphNode"]] = field(default_factory=list) entry_points: list["GraphNode"] = field(default_factory=list) @@ -148,22 +162,17 @@ def should_traverse(self, state: GraphState) -> bool: @dataclass class GraphNode: - """Represents a node in the graph. - - The execution_status tracks the node's lifecycle within graph orchestration: - - PENDING: Node hasn't started executing yet - - EXECUTING: Node is currently running - - COMPLETED/FAILED: Node finished executing (regardless of result quality) - """ + """Represents a node in the graph.""" node_id: str - executor: Agent | MultiAgentBase + executor: AgentBase | MultiAgentBase dependencies: set["GraphNode"] = field(default_factory=set) execution_status: Status = Status.PENDING result: NodeResult | None = None execution_time: int = 0 _initial_messages: Messages = field(default_factory=list, init=False) _initial_state: AgentState = field(default_factory=AgentState, init=False) + _initial_model_state: dict[str, Any] = field(default_factory=dict, init=False) def __post_init__(self) -> None: """Capture initial executor state after initialization.""" @@ -174,6 +183,9 @@ def __post_init__(self) -> None: if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"): self._initial_state = AgentState(self.executor.state.get()) + if hasattr(self.executor, "_model_state"): + self._initial_model_state = copy.deepcopy(self.executor._model_state) + def reset_executor_state(self) -> None: """Reset GraphNode executor state to initial state when graph was created. @@ -186,6 +198,9 @@ def reset_executor_state(self) -> None: if hasattr(self.executor, "state"): self.executor.state = AgentState(self._initial_state.get()) + if hasattr(self.executor, "_model_state"): + self.executor._model_state = copy.deepcopy(self._initial_model_state) + # Reset execution status self.execution_status = Status.PENDING self.result = None @@ -202,7 +217,7 @@ def __eq__(self, other: Any) -> bool: def _validate_node_executor( - executor: Agent | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None + executor: AgentBase | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None ) -> None: """Validate a node executor for graph compatibility. @@ -233,16 +248,17 @@ def __init__(self) -> None: self.entry_points: set[GraphNode] = set() # Configuration options - self._max_node_executions: Optional[int] = None - self._execution_timeout: Optional[float] = None - self._node_timeout: Optional[float] = None + self._max_node_executions: int | None = None + self._execution_timeout: float | None = None + self._node_timeout: float | None = None self._reset_on_revisit: bool = False self._id: str = _DEFAULT_GRAPH_ID - self._session_manager: Optional[SessionManager] = None - self._hooks: Optional[list[HookProvider]] = None + self._session_manager: SessionManager | None = None + self._hooks: list[HookProvider] | None = None + self._plugins: list[MultiAgentPlugin] | None = None - def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode: - """Add an Agent or MultiAgentBase instance as a node to the graph.""" + def add_node(self, executor: AgentBase | MultiAgentBase, node_id: str | None = None) -> GraphNode: + """Add an AgentBase or MultiAgentBase instance as a node to the graph.""" _validate_node_executor(executor, self.nodes) # Auto-generate node_id if not provided @@ -357,6 +373,15 @@ def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder": self._hooks = hooks return self + def set_plugins(self, plugins: list[MultiAgentPlugin]) -> "GraphBuilder": + """Set plugins for the graph. + + Args: + plugins: List of multi-agent plugins for extending graph behavior + """ + self._plugins = plugins + return self + def build(self) -> "Graph": """Build and validate the graph with configured settings.""" if not self.nodes: @@ -385,6 +410,7 @@ def build(self) -> "Graph": session_manager=self._session_manager, hooks=self._hooks, id=self._id, + plugins=self._plugins, ) def _validate_graph(self) -> None: @@ -408,14 +434,15 @@ def __init__( nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_points: set[GraphNode], - max_node_executions: Optional[int] = None, - execution_timeout: Optional[float] = None, - node_timeout: Optional[float] = None, + max_node_executions: int | None = None, + execution_timeout: float | None = None, + node_timeout: float | None = None, reset_on_revisit: bool = False, - session_manager: Optional[SessionManager] = None, - hooks: Optional[list[HookProvider]] = None, + session_manager: SessionManager | None = None, + hooks: list[HookProvider] | None = None, id: str = _DEFAULT_GRAPH_ID, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, + plugins: list[MultiAgentPlugin] | None = None, ) -> None: """Initialize Graph with execution limits and reset behavior. @@ -431,6 +458,7 @@ def __init__( hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) id: Unique graph id (default: None) trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None) + plugins: List of multi-agent plugins for extending graph behavior (default: None) """ super().__init__() @@ -445,6 +473,7 @@ def __init__( self.node_timeout = node_timeout self.reset_on_revisit = reset_on_revisit self.state = GraphState() + self._interrupt_state = _InterruptState() self.tracer = get_tracer() self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes) self.session_manager = session_manager @@ -455,12 +484,28 @@ def __init__( for hook in hooks: self.hooks.add_hook(hook) + self._plugin_registry = _MultiAgentPluginRegistry(self) + if plugins: + for plugin in plugins: + self._plugin_registry.add_and_init(plugin) + self._resume_next_nodes: list[GraphNode] = [] self._resume_from_session = False self.id = id run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) + def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None: + """Register a hook callback with the graph. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. + """ + self.hooks.add_callback(event_type, callback) + def __call__( self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> GraphResult: @@ -519,6 +564,8 @@ async def stream_async( - multi_agent_node_stop: When a node stops execution - result: Final graph result """ + self._interrupt_state.resume(task) + if invocation_state is None: invocation_state = {} @@ -528,7 +575,7 @@ async def stream_async( # Initialize state start_time = time.time() - if not self._resume_from_session: + if not self._resume_from_session and not self._interrupt_state.activated: # Initialize state self.state = GraphState( status=Status.EXECUTING, @@ -544,6 +591,8 @@ async def stream_async( span = self.tracer.start_multiagent_span(task, "graph", custom_trace_attributes=self.trace_attributes) with trace_api.use_span(span, end_on_exit=True): + interrupts = [] + try: logger.debug( "max_node_executions=<%s>, execution_timeout=<%s>s, node_timeout=<%s>s | graph execution config", @@ -553,6 +602,9 @@ async def stream_async( ) async for event in self._execute_graph(invocation_state): + if isinstance(event, MultiAgentNodeInterruptEvent): + interrupts.extend(event.interrupts) + yield event.as_dict() # Set final status based on execution results @@ -564,7 +616,7 @@ async def stream_async( logger.debug("status=<%s> | graph execution completed", self.state.status) # Yield final result (consistent with Agent's AgentResultEvent format) - result = self._build_result() + result = self._build_result(interrupts) # Use the same event format as Agent for consistency yield MultiAgentResultEvent(result=result).as_dict() @@ -574,7 +626,7 @@ async def stream_async( self.state.status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - start_time) * 1000) + self.state.execution_time += round((time.time() - start_time) * 1000) await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self)) self._resume_from_session = False self._resume_next_nodes.clear() @@ -591,9 +643,59 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) + def _activate_interrupt( + self, node: GraphNode, interrupts: list[Interrupt], from_hook: bool = False + ) -> MultiAgentNodeInterruptEvent: + """Activate the interrupt state. + + Args: + node: The interrupted node. + interrupts: The interrupts raised by the user. + from_hook: Whether the interrupt originated from a hook (e.g., BeforeNodeCallEvent). + + Returns: + MultiAgentNodeInterruptEvent + """ + logger.debug("node=<%s>, from_hook=<%s> | node interrupted", node.node_id, from_hook) + + node.execution_status = Status.INTERRUPTED + + self.state.status = Status.INTERRUPTED + self.state.interrupted_nodes.add(node) + + self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) + self._interrupt_state.activate() + + self._interrupt_state.context[node.node_id] = { + "from_hook": from_hook, + "interrupt_ids": [interrupt.id for interrupt in interrupts], + } + + if isinstance(node.executor, Agent): + self._interrupt_state.context[node.node_id].update( + { + "interrupt_state": node.executor._interrupt_state.to_dict(), + "state": node.executor.state.get(), + "messages": node.executor.messages, + "model_state": node.executor._model_state, + } + ) + + return MultiAgentNodeInterruptEvent(node.node_id, interrupts) + async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute graph and yield TypedEvent objects.""" - ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) + if self._interrupt_state.activated: + ready_nodes = [self.nodes[node_id] for node_id in self._interrupt_state.context["completed_nodes"]] + ready_nodes.extend(self.state.interrupted_nodes) + + self.state.interrupted_nodes.clear() + + elif self._resume_from_session: + ready_nodes = self._resume_next_nodes + + else: + ready_nodes = list(self.entry_points) while ready_nodes: # Check execution limits before continuing @@ -613,6 +715,14 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato async for event in self._execute_nodes_parallel(current_batch, invocation_state): yield event + if self.state.status == Status.INTERRUPTED: + self._interrupt_state.context["completed_nodes"] = [ + node.node_id for node in current_batch if node.execution_status == Status.COMPLETED + ] + return + + self._interrupt_state.deactivate() + # Find newly ready nodes after batch execution # We add all nodes in current batch as completed batch, # because a failure would throw exception and code would not make it here @@ -641,6 +751,9 @@ async def _execute_nodes_parallel( Uses a shared queue where each node's stream runs independently and pushes events as they occur, enabling true real-time event propagation without round-robin delays. """ + if self._interrupt_state.activated: + nodes = [node for node in nodes if node.execution_status == Status.INTERRUPTED] + event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue() # Start all node streams as independent tasks @@ -754,9 +867,16 @@ async def _handle_node_timeout(self, node: GraphNode, event_queue: asyncio.Queue return timeout_exception def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: - """Find nodes that became ready after the last execution.""" + """Find nodes that became ready after the last execution. + + Only evaluates destination nodes of outbound edges from the completed batch, + instead of iterating over all nodes in the graph. + """ + # Collect unique candidate nodes reachable from the completed batch + candidates = {edge.to_node for edge in self.edges if edge.from_node in completed_batch} + newly_ready = [] - for _node_id, node in self.nodes.items(): + for node in candidates: if self._is_node_ready_with_conditions(node, completed_batch): newly_ready.append(node) return newly_ready @@ -792,17 +912,20 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) logger.debug("node_id=<%s> | executing node", node.node_id) # Emit node start event - start_event = MultiAgentNodeStartEvent( - node_id=node.node_id, node_type="agent" if isinstance(node.executor, Agent) else "multiagent" - ) + node_type = "multiagent" if isinstance(node.executor, MultiAgentBase) else "agent" + start_event = MultiAgentNodeStartEvent(node_id=node.node_id, node_type=node_type) yield start_event - before_event, _ = await self.hooks.invoke_callbacks_async( + before_event, interrupts = await self.hooks.invoke_callbacks_async( BeforeNodeCallEvent(self, node.node_id, invocation_state) ) start_time = time.time() try: + if interrupts: + yield self._activate_interrupt(node, interrupts, from_hook=True) + return + if before_event.cancel_node: cancel_message = ( before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user" @@ -833,14 +956,15 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) node_result = NodeResult( result=multi_agent_result, execution_time=multi_agent_result.execution_time, - status=Status.COMPLETED, + status=multi_agent_result.status, accumulated_usage=multi_agent_result.accumulated_usage, accumulated_metrics=multi_agent_result.accumulated_metrics, execution_count=multi_agent_result.execution_count, + interrupts=multi_agent_result.interrupts, ) - elif isinstance(node.executor, Agent): - # For agents, stream their events and collect result + elif isinstance(node.executor, AgentBase): + # For AgentBase implementations (Agent, A2AAgent, etc.), stream events and collect result agent_response = None async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): # Forward agent events with node context @@ -854,13 +978,6 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if agent_response is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") - # Check for interrupt (from main branch) - if agent_response.stop_reason == "interrupt": - node.executor.messages.pop() # remove interrupted tool use message - node.executor._interrupt_state.deactivate() - - raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs") - # Extract metrics with defaults response_metrics = getattr(agent_response, "metrics", None) usage = getattr( @@ -868,21 +985,31 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) ) metrics = getattr(response_metrics, "accumulated_metrics", Metrics(latencyMs=0)) + # Handle stop_reason and interrupts (use getattr for AgentBase compatibility) + stop_reason = getattr(agent_response, "stop_reason", "end_turn") + interrupts = getattr(agent_response, "interrupts", None) or [] + node_result = NodeResult( result=agent_response, execution_time=round((time.time() - start_time) * 1000), - status=Status.COMPLETED, + status=Status.INTERRUPTED if stop_reason == "interrupt" else Status.COMPLETED, accumulated_usage=usage, accumulated_metrics=metrics, execution_count=1, + interrupts=interrupts, ) else: raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") - # Mark as completed - node.execution_status = Status.COMPLETED node.result = node_result node.execution_time = node_result.execution_time + + if node_result.status == Status.INTERRUPTED: + yield self._activate_interrupt(node, node_result.interrupts) + return + + # Mark as completed + node.execution_status = Status.COMPLETED self.state.completed_nodes.add(node) self.state.results[node.node_id] = node_result self.state.execution_order.append(node) @@ -936,7 +1063,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) raise finally: - await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state)) + if node.execution_status != Status.INTERRUPTED: + await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state)) def _accumulate_metrics(self, node_result: NodeResult) -> None: """Accumulate metrics from a node result.""" @@ -949,6 +1077,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None: def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: """Build input text for a node based on dependency outputs. + If resuming from an interrupt, return user responses. + Example formatted output: ``` Original Task: Analyze the quarterly sales data and create a summary report @@ -963,6 +1093,30 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: - Agent: Data validation complete. All records verified, no anomalies detected. ``` """ + if self._interrupt_state.activated: + context = self._interrupt_state.context + if node.node_id in context: + node_context = context[node.node_id] + + # Only route responses if the interrupt originated from the node's execution + if not node_context["from_hook"]: + # Filter responses to only those for this node's interrupts + node_responses = [ + response + for response in context["responses"] + if response["interruptResponse"]["interruptId"] in node_context["interrupt_ids"] + ] + + # Restore Agent-specific state for interrupt resumption + # Only Agent (not generic AgentBase) supports interrupt state restoration + if isinstance(node.executor, Agent): + node.executor.messages = node_context["messages"] + node.executor.state = AgentState(node_context["state"]) + node.executor._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"]) + node.executor._model_state = node_context.get("model_state", {}) + + return node_responses + # Get satisfied dependencies dependency_results = {} for edge in self.edges: @@ -1006,8 +1160,15 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: return node_input - def _build_result(self) -> GraphResult: - """Build graph result from current state.""" + def _build_result(self, interrupts: list[Interrupt]) -> GraphResult: + """Build graph result from current state. + + Args: + interrupts: List of interrupts collected during execution. + + Returns: + GraphResult with current state. + """ return GraphResult( status=self.state.status, results=self.state.results, @@ -1018,9 +1179,11 @@ def _build_result(self) -> GraphResult: total_nodes=self.state.total_nodes, completed_nodes=len(self.state.completed_nodes), failed_nodes=len(self.state.failed_nodes), + interrupted_nodes=len(self.state.interrupted_nodes), execution_order=self.state.execution_order, edges=self.state.edges, entry_points=self.state.entry_points, + interrupts=interrupts, ) def serialize_state(self) -> dict[str, Any]: @@ -1033,10 +1196,14 @@ def serialize_state(self) -> dict[str, Any]: "status": self.state.status.value, "completed_nodes": [n.node_id for n in self.state.completed_nodes], "failed_nodes": [n.node_id for n in self.state.failed_nodes], + "interrupted_nodes": [n.node_id for n in self.state.interrupted_nodes], "node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()}, "next_nodes_to_execute": next_nodes, - "current_task": self.state.task, + "current_task": encode_bytes_values(self.state.task), "execution_order": [n.node_id for n in self.state.execution_order], + "_internal_state": { + "interrupt_state": self._interrupt_state.to_dict(), + }, } def deserialize_state(self, payload: dict[str, Any]) -> None: @@ -1052,6 +1219,10 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: payload: Dictionary containing persisted state data including status, completed nodes, results, and next nodes to execute. """ + if "_internal_state" in payload: + internal_state = payload["_internal_state"] + self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) + if not payload.get("next_nodes_to_execute"): # Reset all nodes for node in self.nodes.values(): @@ -1098,17 +1269,27 @@ def _from_dict(self, payload: dict[str, Any]) -> None: self.state.failed_nodes = set( self.nodes[node_id] for node_id in (payload.get("failed_nodes") or []) if node_id in self.nodes ) + for node in self.state.failed_nodes: + node.execution_status = Status.FAILED + + self.state.interrupted_nodes = set( + self.nodes[node_id] for node_id in (payload.get("interrupted_nodes") or []) if node_id in self.nodes + ) + for node in self.state.interrupted_nodes: + node.execution_status = Status.INTERRUPTED - # Restore completed nodes from persisted data - completed_node_ids = payload.get("completed_nodes") or [] - self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes} + self.state.completed_nodes = set( + self.nodes[node_id] for node_id in (payload.get("completed_nodes") or []) if node_id in self.nodes + ) + for node in self.state.completed_nodes: + node.execution_status = Status.COMPLETED # Execution order (only nodes that still exist) order_node_ids = payload.get("execution_order") or [] self.state.execution_order = [self.nodes[node_id] for node_id in order_node_ids if node_id in self.nodes] # Task - self.state.task = payload.get("current_task", self.state.task) + self.state.task = decode_bytes_values(payload.get("current_task", self.state.task)) # next nodes to execute next_nodes = [self.nodes[nid] for nid in (payload.get("next_nodes_to_execute") or []) if nid in self.nodes] diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7eec49649..2eeb38694 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -17,24 +17,28 @@ import copy import json import logging +import sys import time +from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, Optional, cast from opentelemetry import trace as trace_api from .._async import run_async from ..agent import Agent from ..agent.state import AgentState -from ..experimental.hooks.multiagent import ( +from ..hooks.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks import HookProvider, HookRegistry +from ..hooks.registry import HookCallback, HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState +from ..plugins.multiagent_plugin import MultiAgentPlugin +from ..plugins.multiagent_registry import _MultiAgentPluginRegistry from ..session import SessionManager from ..telemetry import get_tracer from ..tools.decorator import tool @@ -50,6 +54,7 @@ from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput +from ..types.session import decode_bytes_values, encode_bytes_values from ..types.traces import AttributeValue from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status @@ -67,12 +72,14 @@ class SwarmNode: swarm: Optional["Swarm"] = None _initial_messages: Messages = field(default_factory=list, init=False) _initial_state: AgentState = field(default_factory=AgentState, init=False) + _initial_model_state: dict[str, Any] = field(default_factory=dict, init=False) def __post_init__(self) -> None: """Capture initial executor state after initialization.""" # Deep copy the initial messages and state to preserve them self._initial_messages = copy.deepcopy(self.executor.messages) self._initial_state = AgentState(self.executor.state.get()) + self._initial_model_state = copy.deepcopy(self.executor._model_state) def __hash__(self) -> int: """Return hash for SwarmNode based on node_id.""" @@ -102,10 +109,12 @@ def reset_executor_state(self) -> None: self.executor.messages = context["messages"] self.executor.state = AgentState(context["state"]) self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"]) + self.executor._model_state = context.get("model_state", {}) return self.executor.messages = copy.deepcopy(self._initial_messages) self.executor.state = AgentState(self._initial_state.get()) + self.executor._model_state = copy.deepcopy(self._initial_model_state) @dataclass @@ -184,7 +193,7 @@ def should_continue( execution_timeout: float, repetitive_handoff_detection_window: int, repetitive_handoff_min_unique_agents: int, - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: """Check if the swarm should continue. Returns: (should_continue, reason) @@ -198,7 +207,7 @@ def should_continue( return False, f"Max iterations reached: {max_iterations}" # Check timeout - elapsed = time.time() - self.start_time + elapsed = self.execution_time / 1000 + time.time() - self.start_time if elapsed > execution_timeout: return False, f"Execution timed out: {execution_timeout}s" @@ -239,10 +248,11 @@ def __init__( node_timeout: float = 300.0, repetitive_handoff_detection_window: int = 0, repetitive_handoff_min_unique_agents: int = 0, - session_manager: Optional[SessionManager] = None, - hooks: Optional[list[HookProvider]] = None, + session_manager: SessionManager | None = None, + hooks: list[HookProvider] | None = None, id: str = _DEFAULT_SWARM_ID, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, + plugins: list[MultiAgentPlugin] | None = None, ) -> None: """Initialize Swarm with agents and configuration. @@ -261,6 +271,7 @@ def __init__( session_manager: Session manager for persisting graph state and execution history (default: None) hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None) + plugins: List of multi-agent plugins for extending swarm behavior (default: None) """ super().__init__() self.id = id @@ -293,12 +304,28 @@ def __init__( if self.session_manager: self.hooks.add_hook(self.session_manager) + self._plugin_registry = _MultiAgentPluginRegistry(self) + if plugins: + for plugin in plugins: + self._plugin_registry.add_and_init(plugin) + self._resume_from_session = False self._setup_swarm(nodes) self._inject_swarm_tools() run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) + def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None: + """Register a hook callback with the swarm. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. + """ + self.hooks.add_callback(event_type, callback) + def __call__( self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> SwarmResult: @@ -405,7 +432,7 @@ async def stream_async( self.state.completion_status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - self.state.start_time) * 1000) + self.state.execution_time += round((time.time() - self.state.start_time) * 1000) await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self, invocation_state)) self._resume_from_session = False @@ -433,28 +460,25 @@ async def _stream_with_timeout( Exception: If total execution time exceeds timeout """ if timeout is None: - # No timeout - just pass through async for event in async_generator: yield event + elif sys.version_info >= (3, 11): + try: + async with asyncio.timeout(timeout): + async for event in async_generator: + yield event + except asyncio.TimeoutError as err: + raise Exception(timeout_message) from err else: - # Track start time for total timeout - start_time = asyncio.get_event_loop().time() - - while True: - # Calculate remaining time from total timeout budget - elapsed = asyncio.get_event_loop().time() - start_time - remaining = timeout - elapsed - - if remaining <= 0: + # Python 3.10 fallback: timeout is only checked between yielded events. + # A generator that hangs mid-await won't be interrupted until the next event. + # Remove once Python 3.10 support is dropped (Oct 2026). + start_time = asyncio.get_running_loop().time() + async for event in async_generator: + elapsed = asyncio.get_running_loop().time() - start_time + if elapsed > timeout: raise Exception(timeout_message) - - try: - event = await asyncio.wait_for(async_generator.__anext__(), timeout=remaining) - yield event - except StopAsyncIteration: - break - except asyncio.TimeoutError as err: - raise Exception(timeout_message) from err + yield event def _setup_swarm(self, nodes: list[Agent]) -> None: """Initialize swarm configuration.""" @@ -695,6 +719,7 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M "interrupt_state": node.executor._interrupt_state.to_dict(), "state": node.executor.state.get(), "messages": node.executor.messages, + "model_state": node.executor._model_state, } self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) @@ -781,9 +806,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato break finally: - await self.hooks.invoke_callbacks_async( - AfterNodeCallEvent(self, current_node.node_id, invocation_state) - ) + if self.state.completion_status != Status.INTERRUPTED: + await self.hooks.invoke_callbacks_async( + AfterNodeCallEvent(self, current_node.node_id, invocation_state) + ) logger.debug("node=<%s> | node execution completed", current_node.node_id) @@ -963,7 +989,7 @@ def serialize_state(self) -> dict[str, Any]: "node_history": [n.node_id for n in self.state.node_history], "node_results": {k: v.to_dict() for k, v in self.state.results.items()}, "next_nodes_to_execute": next_nodes, - "current_task": self.state.task, + "current_task": encode_bytes_values(self.state.task), "context": { "shared_context": getattr(self.state.shared_context, "context", {}) or {}, "handoff_node": self.state.handoff_node.node_id if self.state.handoff_node else None, @@ -1026,7 +1052,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None: logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id) raise self.state.results = results - self.state.task = payload.get("current_task", self.state.task) + self.state.task = decode_bytes_values(payload.get("current_task", self.state.task)) next_node_ids = payload.get("next_nodes_to_execute") or [] if next_node_ids: diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py new file mode 100644 index 000000000..7a3d5fa17 --- /dev/null +++ b/src/strands/plugins/__init__.py @@ -0,0 +1,16 @@ +"""Plugin system for extending agent and orchestrator functionality. + +This module provides a composable mechanism for building objects that can +extend agent and multi-agent orchestrator behavior through automatic hook +and tool registration. +""" + +from .decorator import hook +from .multiagent_plugin import MultiAgentPlugin +from .plugin import Plugin + +__all__ = [ + "MultiAgentPlugin", + "Plugin", + "hook", +] diff --git a/src/strands/plugins/_discovery.py b/src/strands/plugins/_discovery.py new file mode 100644 index 000000000..eda955030 --- /dev/null +++ b/src/strands/plugins/_discovery.py @@ -0,0 +1,103 @@ +"""Shared utility for discovering decorated methods on plugin instances. + +This module provides helper functions used by both Plugin and MultiAgentPlugin +to scan for @hook (and optionally @tool) decorated methods, and shared registry +utilities for plugin initialization and hook registration. +""" + +import inspect +import logging +from collections.abc import Awaitable, Callable +from typing import Any, cast + +from .._async import run_async +from ..hooks.registry import HookCallback +from ..tools.decorator import DecoratedFunctionTool + +logger = logging.getLogger(__name__) + + +def _discover_methods(instance: object, plugin_name: str, predicate: Callable[[object], bool], label: str) -> list[Any]: + """Scan an instance's class hierarchy for methods matching a predicate. + + Walks the MRO in reverse so parent class methods come first, but child + overrides win (only the child's version is included). + + Args: + instance: The plugin instance to scan. + plugin_name: The plugin name (used for debug logging). + predicate: Function that returns True for attributes to collect. + label: Label for debug logging (e.g., "hook", "tool"). + + Returns: + List of matching bound methods/descriptors in declaration order. + """ + results: list[Any] = [] + seen: set[str] = set() + + for cls in reversed(type(instance).__mro__): + for attr_name in cls.__dict__: + if attr_name in seen: + continue + seen.add(attr_name) + + try: + bound = getattr(instance, attr_name) + except Exception: + continue + + if predicate(bound): + results.append(bound) + logger.debug("plugin=<%s>, %s=<%s> | discovered", plugin_name, label, attr_name) + + return results + + +def discover_hooks(instance: object, plugin_name: str) -> list[HookCallback]: + """Scan an instance's class hierarchy for @hook decorated methods. + + Args: + instance: The plugin instance to scan. + plugin_name: The plugin name (used for debug logging). + + Returns: + List of bound hook callback methods in declaration order. + """ + return _discover_methods( + instance, + plugin_name, + predicate=lambda bound: hasattr(bound, "_hook_event_types") and callable(bound), + label="hook", + ) + + +def discover_tools(instance: object, plugin_name: str) -> list[DecoratedFunctionTool]: + """Scan an instance's class hierarchy for @tool decorated methods. + + Args: + instance: The plugin instance to scan. + plugin_name: The plugin name (used for debug logging). + + Returns: + List of DecoratedFunctionTool instances in declaration order. + """ + return _discover_methods( + instance, + plugin_name, + predicate=lambda bound: isinstance(bound, DecoratedFunctionTool), + label="tool", + ) + + +def call_init_method(init_method: Callable[..., Any], target: Any) -> None: + """Call a plugin's init method, handling both sync and async implementations. + + Args: + init_method: The init_agent or init_multi_agent method to call. + target: The agent or orchestrator instance to pass to the init method. + """ + if inspect.iscoroutinefunction(init_method): + async_init = cast(Callable[..., Awaitable[None]], init_method) + run_async(lambda: async_init(target)) + else: + init_method(target) diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py new file mode 100644 index 000000000..fc6f75e5b --- /dev/null +++ b/src/strands/plugins/decorator.py @@ -0,0 +1,69 @@ +"""Hook decorator for Plugin methods. + +Marks methods as hook callbacks for automatic registration when the plugin +is attached to an agent. Infers event types from type hints and supports +union types for multiple events. + +Example: + ```python + class MyPlugin(Plugin): + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(event) + ``` +""" + +from collections.abc import Callable +from typing import Generic, cast, overload + +from ..hooks._type_inference import infer_event_types +from ..hooks.registry import HookCallback, TEvent + + +class _WrappedHookCallable(HookCallback, Generic[TEvent]): + """Wrapped version of HookCallback that includes a `_hook_event_types` attribute.""" + + _hook_event_types: list[type[TEvent]] + + +# Handle @hook +@overload +def hook(__func: HookCallback) -> _WrappedHookCallable: ... + + +# Handle @hook() +@overload +def hook() -> Callable[[HookCallback], _WrappedHookCallable]: ... + + +def hook( + func: HookCallback | None = None, +) -> _WrappedHookCallable | Callable[[HookCallback], _WrappedHookCallable]: + """Mark a method as a hook callback for automatic registration. + + Infers event type from the callback's type hint. Supports union types + for multiple events. Can be used as @hook or @hook(). + + Args: + func: The function to decorate. + + Returns: + The decorated function with hook metadata. + + Raises: + ValueError: If event type cannot be inferred from type hints. + """ + + def decorator(f: HookCallback[TEvent]) -> _WrappedHookCallable[TEvent]: + # Infer event types from type hints + event_types: list[type[TEvent]] = infer_event_types(f) + + # Store hook metadata on the function + f_wrapped = cast(_WrappedHookCallable, f) + f_wrapped._hook_event_types = event_types + + return f_wrapped + + if func is None: + return decorator + return decorator(func) diff --git a/src/strands/plugins/multiagent_plugin.py b/src/strands/plugins/multiagent_plugin.py new file mode 100644 index 000000000..89bd9e0e5 --- /dev/null +++ b/src/strands/plugins/multiagent_plugin.py @@ -0,0 +1,119 @@ +"""MultiAgentPlugin base class for extending multi-agent orchestrator functionality. + +This module defines the MultiAgentPlugin base class, which provides a composable way to +add behavior changes to multi-agent orchestrators (Swarm, Graph) through automatic hook +registration and custom initialization. + +MultiAgentPlugin is the orchestrator-level counterpart to Plugin (which targets individual agents). +A class can implement both Plugin and MultiAgentPlugin to provide functionality at both levels. +""" + +from abc import ABC, abstractmethod +from collections.abc import Awaitable +from typing import TYPE_CHECKING + +from ..hooks.registry import HookCallback +from ._discovery import discover_hooks + +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + + +class MultiAgentPlugin(ABC): + """Base class for objects that extend multi-agent orchestrator functionality. + + MultiAgentPlugins provide a composable way to add behavior changes to orchestrators + (Swarm, Graph). They support automatic discovery and registration of methods decorated + with @hook. + + Unlike agent-level Plugin, MultiAgentPlugin does not support @tool decorated methods + since orchestrators do not have tool registries. + + Attributes: + name: A stable string identifier for the plugin (must be provided by subclass) + hooks: Hooks attached to the orchestrator, auto-discovered from @hook decorated methods + + Example using decorators (recommended): + ```python + from strands.plugins import MultiAgentPlugin, hook + from strands.hooks import BeforeNodeCallEvent, AfterNodeCallEvent + + class MonitoringPlugin(MultiAgentPlugin): + name = "monitoring" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + print(f"Node {event.node_id} starting") + + @hook + def on_after_node(self, event: AfterNodeCallEvent): + print(f"Node {event.node_id} completed") + ``` + + Example with custom initialization: + ```python + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator: MultiAgentBase) -> None: + # Custom initialization logic + pass + ``` + + Dual-use example (both agent and orchestrator): + ```python + from strands.plugins import Plugin, MultiAgentPlugin, hook + from strands.hooks import BeforeInvocationEvent, BeforeNodeCallEvent + + class ObservabilityPlugin(Plugin, MultiAgentPlugin): + name = "observability" + + @hook + def on_agent_invocation(self, event: BeforeInvocationEvent): + print("Agent invocation started") + + @hook + def on_node_call(self, event: BeforeNodeCallEvent): + print(f"Node {event.node_id} starting") + + def init_agent(self, agent): + pass # Agent-level setup + + def init_multi_agent(self, orchestrator): + pass # Orchestrator-level setup + ``` + """ + + @property + @abstractmethod + def name(self) -> str: + """A stable string identifier for the plugin.""" + ... + + def __init__(self) -> None: + """Initialize the plugin and discover decorated hook methods. + + Scans the class for methods decorated with @hook and stores references + for later registration when the plugin is attached to an orchestrator. + + Uses a guard to prevent double-discovery when used with multiple inheritance + (e.g., a class that inherits from both Plugin and MultiAgentPlugin). + """ + if not hasattr(self, "_hooks"): + self._hooks: list[HookCallback] = discover_hooks(self, self.name) + + @property + def hooks(self) -> list[HookCallback]: + """List of hooks the plugin provides, auto-discovered from @hook decorated methods.""" + return self._hooks + + def init_multi_agent(self, orchestrator: "MultiAgentBase") -> None | Awaitable[None]: + """Initialize the plugin with the orchestrator instance. + + Override this method to add custom initialization logic. Decorated + hooks are automatically registered by the plugin registry. + + Args: + orchestrator: The multi-agent orchestrator instance to initialize with. + """ + return None diff --git a/src/strands/plugins/multiagent_registry.py b/src/strands/plugins/multiagent_registry.py new file mode 100644 index 000000000..365c8f9c5 --- /dev/null +++ b/src/strands/plugins/multiagent_registry.py @@ -0,0 +1,113 @@ +"""MultiAgentPlugin registry for managing plugins attached to a multi-agent orchestrator. + +This module provides the _MultiAgentPluginRegistry class for tracking and managing +plugins that have been initialized with an orchestrator instance. +""" + +import logging +import weakref +from typing import TYPE_CHECKING + +from ._discovery import call_init_method +from .multiagent_plugin import MultiAgentPlugin + +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + +logger = logging.getLogger(__name__) + + +class _MultiAgentPluginRegistry: + """Registry for managing plugins attached to a multi-agent orchestrator. + + The _MultiAgentPluginRegistry tracks plugins that have been initialized with an + orchestrator, providing methods to add plugins and invoke their initialization. + + The registry handles: + 1. Calling the plugin's init_multi_agent() method for custom initialization + 2. Auto-registering discovered @hook decorated methods with the orchestrator + + Example: + ```python + registry = _MultiAgentPluginRegistry(orchestrator) + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_event(self, event: BeforeNodeCallEvent): + pass # Auto-registered by registry + + def init_multi_agent(self, orchestrator: MultiAgentBase) -> None: + # Custom logic + pass + + plugin = MyPlugin() + registry.add_and_init(plugin) + ``` + """ + + def __init__(self, orchestrator: "MultiAgentBase") -> None: + """Initialize a plugin registry with an orchestrator reference. + + Args: + orchestrator: The orchestrator instance that plugins will be initialized with. + """ + self._orchestrator_ref = weakref.ref(orchestrator) + self._plugins: dict[str, MultiAgentPlugin] = {} + + @property + def _orchestrator(self) -> "MultiAgentBase": + """Return the orchestrator, raising ReferenceError if it has been garbage collected.""" + orchestrator = self._orchestrator_ref() + if orchestrator is None: + raise ReferenceError("Orchestrator has been garbage collected") + return orchestrator + + def add_and_init(self, plugin: MultiAgentPlugin) -> None: + """Add and initialize a plugin with the orchestrator. + + This method: + 1. Registers the plugin in the registry + 2. Calls the plugin's init_multi_agent method for custom initialization + 3. Auto-registers all discovered @hook methods with the orchestrator's hook registry + + Handles both sync and async init_multi_agent implementations automatically. + + Args: + plugin: The plugin to add and initialize. + + Raises: + ValueError: If a plugin with the same name is already registered. + """ + if plugin.name in self._plugins: + raise ValueError(f"plugin_name=<{plugin.name}> | plugin already registered") + + logger.debug("plugin_name=<%s> | registering and initializing multi-agent plugin", plugin.name) + self._plugins[plugin.name] = plugin + + # Call user's init_multi_agent for custom initialization + call_init_method(plugin.init_multi_agent, self._orchestrator) + + # Auto-register discovered hooks with the orchestrator + self._register_hooks(plugin) + + def _register_hooks(self, plugin: MultiAgentPlugin) -> None: + """Register all discovered hooks from the plugin with the orchestrator. + + Uses orchestrator.add_hook() so that the orchestrator can track + registrations through its public API. + + Args: + plugin: The plugin whose hooks should be registered. + """ + for hook_callback in plugin.hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + self._orchestrator.add_hook(hook_callback, event_type) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + plugin.name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py new file mode 100644 index 000000000..35633a30e --- /dev/null +++ b/src/strands/plugins/plugin.py @@ -0,0 +1,108 @@ +"""Plugin base class for extending agent functionality. + +This module defines the Plugin base class, which provides a composable way to +add behavior changes to agents through automatic hook and tool registration. +""" + +from abc import ABC, abstractmethod +from collections.abc import Awaitable +from typing import TYPE_CHECKING + +from ..hooks.registry import HookCallback +from ..tools.decorator import DecoratedFunctionTool +from ._discovery import discover_hooks, discover_tools + +if TYPE_CHECKING: + from ..agent import Agent + + +class Plugin(ABC): + """Base class for objects that extend agent functionality. + + Plugins provide a composable way to add behavior changes to agents. + They support automatic discovery and registration of methods decorated + with @hook and @tool decorators. + + Attributes: + name: A stable string identifier for the plugin (must be provided by subclass) + hooks: Hooks attached to the agent, auto-discovered from @hook decorated methods during __init__ + tools: Tools attached to the agent, auto-discovered from @tool decorated methods during __init__ + + Example using decorators (recommended): + ```python + from strands.plugins import Plugin, hook + from strands.hooks import BeforeModelCallEvent + from strands import tool + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(f"Model called: {event}") + + @tool + def my_tool(self, param: str) -> str: + '''A tool that does something.''' + return f"Result: {param}" + ``` + + Note: Decorated methods are registered in declaration order, with parent + class methods registered before child class methods. If a child overrides + a parent's decorated method, only the child's version is registered. + + Example with custom initialization: + ```python + class MyPlugin(Plugin): + name = "my-plugin" + + def init_agent(self, agent: Agent) -> None: + # Custom initialization logic - no super() needed + # Decorated hooks/tools are auto-registered by the plugin registry + agent.add_hook(self.custom_hook) + + def custom_hook(self, event: BeforeModelCallEvent): + print(event) + ``` + """ + + @property + @abstractmethod + def name(self) -> str: + """A stable string identifier for the plugin.""" + ... + + def __init__(self) -> None: + """Initialize the plugin and discover decorated methods. + + Scans the class for methods decorated with @hook and @tool and stores + references for later registration when the plugin is attached to an agent. + + Uses a guard to prevent double-discovery when used with multiple inheritance + (e.g., a class that inherits from both Plugin and MultiAgentPlugin). + """ + if not hasattr(self, "_hooks"): + self._hooks: list[HookCallback] = discover_hooks(self, self.name) + if not hasattr(self, "_tools"): + self._tools: list[DecoratedFunctionTool] = discover_tools(self, self.name) + + @property + def hooks(self) -> list[HookCallback]: + """List of hooks the plugin provides, auto-discovered from @hook decorated methods.""" + return self._hooks + + @property + def tools(self) -> list[DecoratedFunctionTool]: + """List of tools the plugin provides, auto-discovered from @tool decorated methods.""" + return self._tools + + def init_agent(self, agent: "Agent") -> None | Awaitable[None]: + """Initialize the agent instance. + + Override this method to add custom initialization logic. Decorated + hooks and tools are automatically registered by the plugin registry. + + Args: + agent: The agent instance to initialize. + """ + return None diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py new file mode 100644 index 000000000..ca5d654c9 --- /dev/null +++ b/src/strands/plugins/registry.py @@ -0,0 +1,133 @@ +"""Plugin registry for managing plugins attached to an agent. + +This module provides the _PluginRegistry class for tracking and managing +plugins that have been initialized with an agent instance. +""" + +import logging +import weakref +from typing import TYPE_CHECKING + +from ._discovery import call_init_method +from .plugin import Plugin + +if TYPE_CHECKING: + from ..agent import Agent + +logger = logging.getLogger(__name__) + + +class _PluginRegistry: + """Registry for managing plugins attached to an agent. + + The _PluginRegistry tracks plugins that have been initialized with an agent, + providing methods to add plugins and invoke their initialization. + + The registry handles: + 1. Calling the plugin's init_agent() method for custom initialization + 2. Auto-registering discovered @hook decorated methods with the agent + 3. Auto-registering discovered @tool decorated methods with the agent + + Example: + ```python + registry = _PluginRegistry(agent) + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_event(self, event: BeforeModelCallEvent): + pass # Auto-registered by registry + + def init_agent(self, agent: Agent) -> None: + # Custom logic only - no super() needed + pass + + plugin = MyPlugin() + registry.add_and_init(plugin) + ``` + """ + + def __init__(self, agent: "Agent") -> None: + """Initialize a plugin registry with an agent reference. + + Args: + agent: The agent instance that plugins will be initialized with. + """ + self._agent_ref = weakref.ref(agent) + self._plugins: dict[str, Plugin] = {} + + @property + def _agent(self) -> "Agent": + """Return the agent, raising ReferenceError if it has been garbage collected.""" + agent = self._agent_ref() + if agent is None: + raise ReferenceError("Agent has been garbage collected") + return agent + + def add_and_init(self, plugin: Plugin) -> None: + """Add and initialize a plugin with the agent. + + This method: + 1. Registers the plugin in the registry + 2. Calls the plugin's init_agent method for custom initialization + 3. Auto-registers all discovered @hook methods with the agent's hook registry + 4. Auto-registers all discovered @tool methods with the agent's tool registry + + Handles both sync and async init_agent implementations automatically. + + Args: + plugin: The plugin to add and initialize. + + Raises: + ValueError: If a plugin with the same name is already registered. + """ + if plugin.name in self._plugins: + raise ValueError(f"plugin_name=<{plugin.name}> | plugin already registered") + + logger.debug("plugin_name=<%s> | registering and initializing plugin", plugin.name) + self._plugins[plugin.name] = plugin + + # Call user's init_agent for custom initialization + call_init_method(plugin.init_agent, self._agent) + + # Auto-register discovered hooks with the agent + self._register_hooks(plugin) + + # Auto-register discovered tools with the agent's tool registry + self._register_tools(plugin) + + def _register_hooks(self, plugin: Plugin) -> None: + """Register all discovered hooks from the plugin with the agent. + + Uses agent.add_hook() rather than the hook registry directly, so that + the agent can track registrations through its public API. + + Args: + plugin: The plugin whose hooks should be registered. + """ + for hook_callback in plugin.hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + self._agent.add_hook(hook_callback, event_type) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + plugin.name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) + + def _register_tools(self, plugin: Plugin) -> None: + """Register all discovered tools from the plugin with the agent. + + Args: + plugin: The plugin whose tools should be registered. + """ + if plugin.tools: + self._agent.tool_registry.process_tools(list(plugin.tools)) + for tool in plugin.tools: + logger.debug( + "plugin=<%s>, tool=<%s> | registered tool", + plugin.name, + tool.tool_name, + ) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index fc80fc520..0b25d4b5d 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -5,7 +5,7 @@ import os import shutil import tempfile -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from .. import _identifier from ..types.exceptions import SessionException @@ -44,7 +44,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): def __init__( self, session_id: str, - storage_dir: Optional[str] = None, + storage_dir: str | None = None, **kwargs: Any, ): """Initialize FileSession with filesystem storage. @@ -108,7 +108,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> def _read_file(self, path: str) -> dict[str, Any]: """Read JSON file.""" try: - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return cast(dict[str, Any], json.load(f)) except json.JSONDecodeError as e: raise SessionException(f"Invalid JSON in file {path}: {str(e)}") from e @@ -140,7 +140,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: return session - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read session data.""" session_file = os.path.join(self._get_session_path(session_id), "session.json") if not os.path.exists(session_file): @@ -169,7 +169,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A session_data = session_agent.to_dict() self._write_file(agent_file, session_data) - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read agent data.""" agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json") if not os.path.exists(agent_file): @@ -199,7 +199,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio session_dict = session_message.to_dict() self._write_file(message_file, session_dict) - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read message data.""" message_path = self._get_message_path(session_id, agent_id, message_id) if not os.path.exists(message_path): @@ -220,7 +220,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio self._write_file(message_file, session_message.to_dict()) def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List messages for an agent with pagination.""" messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages") @@ -269,7 +269,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k session_data = multi_agent.serialize_state() self._write_file(multi_agent_file, session_data) - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read multi-agent state from filesystem.""" multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") if not os.path.exists(multi_agent_file): diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index a8ac099d9..c1032a85e 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -1,7 +1,8 @@ """Repository session manager implementation.""" +import copy import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..agent.state import AgentState from ..tools._tool_helpers import generate_missing_tool_result_content @@ -51,13 +52,19 @@ def __init__( # Create a session if it does not exist yet if session is None: logger.debug("session_id=<%s> | session not found, creating new session", self.session_id) + self._is_new_session = True session = Session(session_id=session_id, session_type=SessionType.AGENT) session_repository.create_session(session) + else: + self._is_new_session = False self.session = session # Keep track of the latest message of each agent in case we need to redact it. - self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} + self._latest_agent_message: dict[str, SessionMessage | None] = {} + + # Track the previously synced internal state for each agent to detect changes. + self._last_synced_internal_state: dict[str, dict[str, Any]] = {} def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: """Append a message to the agent's session. @@ -95,15 +102,70 @@ def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwarg def sync_agent(self, agent: "Agent", **kwargs: Any) -> None: """Serialize and update the agent into the session repository. + Only updates the agent if state has been modified or internal state has changed. + This optimization reduces unnecessary I/O operations when the agent processes + messages without modifying its state. + Args: agent: Agent to sync to the session. **kwargs: Additional keyword arguments for future extensibility. """ + # Get current versions and conversation manager state + current_state_version = agent.state._get_version() + current_interrupt_state_version = agent._interrupt_state._get_version() + current_conversation_manager_state = agent.conversation_manager.get_state() + current_model_state = agent._model_state + + # Check if we have a previous state to compare against + last_synced = self._last_synced_internal_state.get(agent.agent_id) + + # Determine if we need to update by comparing versions + if last_synced is None: + # First sync for this agent - always update + state_changed = True + internal_state_changed = True + conversation_manager_state_changed = True + else: + state_changed = current_state_version != last_synced.get("state_version") + internal_state_changed = current_interrupt_state_version != last_synced.get( + "interrupt_state_version" + ) or current_model_state != last_synced.get("model_state") + conversation_manager_state_changed = current_conversation_manager_state != last_synced.get( + "conversation_manager_state" + ) + + if not state_changed and not internal_state_changed and not conversation_manager_state_changed: + logger.debug( + "agent_id=<%s> | session_id=<%s> | skipping sync, no changes detected", + agent.agent_id, + self.session_id, + ) + return + + logger.debug( + "agent_id=<%s> | session_id=<%s> | state_changed=<%s>, internal_state_changed=<%s>, " + "conversation_manager_state_changed=<%s> | syncing agent", + agent.agent_id, + self.session_id, + state_changed, + internal_state_changed, + conversation_manager_state_changed, + ) + + # Perform the update self.session_repository.update_agent( self.session_id, SessionAgent.from_agent(agent), ) + # Update tracked versions after successful sync + self._last_synced_internal_state[agent.agent_id] = { + "state_version": current_state_version, + "interrupt_state_version": current_interrupt_state_version, + "conversation_manager_state": copy.deepcopy(current_conversation_manager_state), + "model_state": copy.deepcopy(current_model_state), + } + def initialize(self, agent: "Agent", **kwargs: Any) -> None: """Initialize an agent with a session. @@ -115,7 +177,11 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: raise SessionException("The `agent_id` of an agent must be unique in a session.") self._latest_agent_message[agent.agent_id] = None - session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + # Skip read_agent call for new sessions since no agents can exist yet + if self._is_new_session: + session_agent = None + else: + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) if session_agent is None: logger.debug( @@ -158,11 +224,23 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: if len(session_messages) > 0: self._latest_agent_message[agent.agent_id] = session_messages[-1] - # Restore the agents messages array including the optional prepend messages - agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages] + # Skip restoring messages when conversation is managed server-side + if agent.model.stateful: + logger.debug( + "agent_id=<%s> | session_id=<%s> | skipping message restore for server-managed conversation", + agent.agent_id, + self.session_id, + ) + else: + # Restore the agents messages array including the optional prepend messages + agent.messages = prepend_messages + [ + session_message.to_message() for session_message in session_messages + ] - # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 - agent.messages = self._fix_broken_tool_use(agent.messages) + # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 + agent.messages = self._fix_broken_tool_use(agent.messages) + + self._is_new_session = False def _fix_broken_tool_use(self, messages: list[Message]) -> list[Message]: """Fix broken tool use/result pairs in message history. @@ -244,13 +322,20 @@ def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> Non source: Multi-agent source object to restore state into **kwargs: Additional keyword arguments for future extensibility. """ - state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs) + # Skip read_multi_agent call for new sessions since no multi-agents can exist yet + if self._is_new_session: + state = None + else: + state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs) + if state is None: self.session_repository.create_multi_agent(self.session_id, source, **kwargs) else: logger.debug("session_id=<%s> | restoring multi-agent state", self.session_id) source.deserialize_state(state) + self._is_new_session = False + def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: """Initialize a bidirectional agent with a session. @@ -262,7 +347,11 @@ def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: raise SessionException("The `agent_id` of an agent must be unique in a session.") self._latest_agent_message[agent.agent_id] = None - session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + # Skip read_agent call for new sessions since no agents can exist yet + if self._is_new_session: + session_agent = None + else: + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) if session_agent is None: logger.debug( @@ -304,6 +393,8 @@ def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 agent.messages = self._fix_broken_tool_use(agent.messages) + self._is_new_session = False + def append_bidi_message(self, message: Message, agent: "BidiAgent", **kwargs: Any) -> None: """Append a message to the bidirectional agent's session. diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 7d081cf09..fad5e4fd0 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,7 +2,8 @@ import json import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Any, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -47,9 +48,9 @@ def __init__( session_id: str, bucket: str, prefix: str = "", - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, - region_name: Optional[str] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, **kwargs: Any, ): """Initialize S3SessionManager with S3 storage. @@ -94,7 +95,10 @@ def _get_session_path(self, session_id: str) -> str: ValueError: If session id contains a path separator. """ session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) - return f"{self.prefix}/{SESSION_PREFIX}{session_id}/" + prefix = self.prefix.strip("/") + if prefix: + return f"{prefix}/{SESSION_PREFIX}{session_id}/" + return f"{SESSION_PREFIX}{session_id}/" def _get_agent_path(self, session_id: str, agent_id: str) -> str: """Get agent S3 prefix. @@ -130,7 +134,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> agent_path = self._get_agent_path(session_id, agent_id) return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" - def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: + def _read_s3_object(self, key: str) -> dict[str, Any] | None: """Read JSON object from S3.""" try: response = self.client.get_object(Bucket=self.bucket, Key=key) @@ -144,7 +148,7 @@ def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: except json.JSONDecodeError as e: raise SessionException(f"Invalid JSON in S3 object {key}: {e}") from e - def _write_s3_object(self, key: str, data: Dict[str, Any]) -> None: + def _write_s3_object(self, key: str, data: dict[str, Any]) -> None: """Write JSON object to S3.""" try: content = json.dumps(data, indent=2, ensure_ascii=False) @@ -171,7 +175,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: self._write_s3_object(session_key, session_dict) return session - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read session data from S3.""" session_key = f"{self._get_session_path(session_id)}session.json" session_data = self._read_s3_object(session_key) @@ -209,7 +213,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" self._write_s3_object(agent_key, agent_dict) - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read agent data from S3.""" agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" agent_data = self._read_s3_object(agent_key) @@ -236,7 +240,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio message_key = self._get_message_path(session_id, agent_id, message_id) self._write_s3_object(message_key, message_dict) - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read message data from S3.""" message_key = self._get_message_path(session_id, agent_id, message_id) message_data = self._read_s3_object(message_key) @@ -257,9 +261,23 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio self._write_s3_object(message_key, session_message.to_dict()) def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any - ) -> List[SessionMessage]: - """List messages for an agent with pagination from S3.""" + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any + ) -> list[SessionMessage]: + """List messages for an agent with pagination from S3. + + Args: + session_id: ID of the session + agent_id: ID of the agent + limit: Optional limit on number of messages to return + offset: Optional offset for pagination + **kwargs: Additional keyword arguments + + Returns: + List of SessionMessage objects, sorted by message_id. + + Raises: + SessionException: If S3 error occurs during message retrieval. + """ messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" try: paginator = self.client.get_paginator("list_objects_v2") @@ -287,10 +305,38 @@ def list_messages( else: message_keys = message_keys[offset:] - # Load only the required message objects - messages: List[SessionMessage] = [] - for key in message_keys: - message_data = self._read_s3_object(key) + # Load message objects in parallel for better performance + messages: list[SessionMessage] = [] + if not message_keys: + return messages + + # Optimize for single worker case - avoid thread pool overhead + if len(message_keys) == 1: + for key in message_keys: + message_data = self._read_s3_object(key) + if message_data: + messages.append(SessionMessage.from_dict(message_data)) + return messages + + with ThreadPoolExecutor() as executor: + # Submit all read tasks + future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys} + + # Create a mapping from key to index to maintain order + key_to_index = {key: idx for idx, key in enumerate(message_keys)} + + # Initialize results list with None placeholders to maintain order + results: list[dict[str, Any] | None] = [None] * len(message_keys) + + # Process results as they complete + for future in as_completed(future_to_key): + key = future_to_key[future] + message_data = future.result() + # Store result at the correct index to maintain order + results[key_to_index[key]] = message_data + + # Convert results to SessionMessage objects, filtering out None values + for message_data in results: if message_data: messages.append(SessionMessage.from_dict(message_data)) @@ -312,7 +358,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k session_data = multi_agent.serialize_state() self._write_s3_object(multi_agent_key, session_data) - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read multi-agent state from S3.""" multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" return self._read_s3_object(multi_agent_key) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index ba4356089..cc954e17d 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -9,12 +9,14 @@ BidiAgentInitializedEvent, BidiMessageAddedEvent, ) -from ..experimental.hooks.multiagent.events import ( +from ..hooks.events import ( + AfterInvocationEvent, AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + AgentInitializedEvent, + MessageAddedEvent, MultiAgentInitializedEvent, ) -from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry from ..types.content import Message diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py index 3f5476bdf..0b6f2c705 100644 --- a/src/strands/session/session_repository.py +++ b/src/strands/session/session_repository.py @@ -1,7 +1,7 @@ """Session repository interface for agent session management.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..types.session import Session, SessionAgent, SessionMessage @@ -17,7 +17,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new Session.""" @abstractmethod - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read a Session.""" @abstractmethod @@ -25,7 +25,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A """Create a new Agent in a Session.""" @abstractmethod - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read an Agent.""" @abstractmethod @@ -37,7 +37,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio """Create a new Message for the Agent.""" @abstractmethod - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read a Message.""" @abstractmethod @@ -49,7 +49,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio @abstractmethod def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List Messages from an Agent with pagination.""" @@ -57,7 +57,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k """Create a new MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read the MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") diff --git a/src/strands/telemetry/config.py b/src/strands/telemetry/config.py index 0509c7440..93225335d 100644 --- a/src/strands/telemetry/config.py +++ b/src/strands/telemetry/config.py @@ -5,6 +5,7 @@ """ import logging +import os from importlib.metadata import version from typing import Any @@ -29,9 +30,11 @@ def get_otel_resource() -> Resource: Returns: Resource object with standard service information. """ + service_name = os.getenv("OTEL_SERVICE_NAME", "strands-agents").strip() + resource = Resource.create( { - "service.name": "strands-agents", + "service.name": service_name, "service.version": version("strands-agents"), "telemetry.sdk.name": "opentelemetry", "telemetry.sdk.language": "python", @@ -56,6 +59,7 @@ class StrandsTelemetry: Environment variables are handled by the underlying OpenTelemetry SDK: - OTEL_EXPORTER_OTLP_ENDPOINT: OTLP endpoint URL - OTEL_EXPORTER_OTLP_HEADERS: Headers for OTLP requests + - OTEL_SERVICE_NAME: Overrides resource service name Examples: Quick setup with method chaining: diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index abfbbffae..11690dd44 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -3,8 +3,9 @@ import logging import time import uuid +from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Optional import opentelemetry.metrics as metrics_api from opentelemetry.metrics import Counter, Histogram, Meter @@ -23,11 +24,11 @@ class Trace: def __init__( self, name: str, - parent_id: Optional[str] = None, - start_time: Optional[float] = None, - raw_name: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - message: Optional[Message] = None, + parent_id: str | None = None, + start_time: float | None = None, + raw_name: str | None = None, + metadata: dict[str, Any] | None = None, + message: Message | None = None, ) -> None: """Initialize a new trace. @@ -42,15 +43,15 @@ def __init__( """ self.id: str = str(uuid.uuid4()) self.name: str = name - self.raw_name: Optional[str] = raw_name - self.parent_id: Optional[str] = parent_id + self.raw_name: str | None = raw_name + self.parent_id: str | None = parent_id self.start_time: float = start_time if start_time is not None else time.time() - self.end_time: Optional[float] = None - self.children: List["Trace"] = [] - self.metadata: Dict[str, Any] = metadata or {} - self.message: Optional[Message] = message + self.end_time: float | None = None + self.children: list[Trace] = [] + self.metadata: dict[str, Any] = metadata or {} + self.message: Message | None = message - def end(self, end_time: Optional[float] = None) -> None: + def end(self, end_time: float | None = None) -> None: """Mark the trace as complete with the given or current timestamp. Args: @@ -67,7 +68,7 @@ def add_child(self, child: "Trace") -> None: """ self.children.append(child) - def duration(self) -> Optional[float]: + def duration(self) -> float | None: """Calculate the duration of this trace. Returns: @@ -83,7 +84,7 @@ def add_message(self, message: Message) -> None: """ self.message = message - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert the trace to a dictionary representation. Returns: @@ -127,7 +128,7 @@ def add_call( duration: float, success: bool, metrics_client: "MetricsClient", - attributes: Optional[Dict[str, Any]] = None, + attributes: dict[str, Any] | None = None, ) -> None: """Record a new tool call with its outcome. @@ -151,6 +152,34 @@ def add_call( metrics_client.tool_error_count.add(1, attributes=attributes) +@dataclass +class EventLoopCycleMetric: + """Aggregated metrics for a single event loop cycle. + + Attributes: + event_loop_cycle_id: Current eventLoop cycle id. + usage: Total token usage for the entire cycle (succeeded model invocation, excluding tool invocations). + """ + + event_loop_cycle_id: str + usage: Usage + + +@dataclass +class AgentInvocation: + """Metrics for a single agent invocation. + + AgentInvocation contains all the event loop cycles and accumulated token usage for that invocation. + + Attributes: + cycles: List of event loop cycles that occurred during this invocation. + usage: Accumulated token usage for this invocation across all cycles. + """ + + cycles: list[EventLoopCycleMetric] = field(default_factory=list) + usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) + + @dataclass class EventLoopMetrics: """Aggregated metrics for an event loop's execution. @@ -159,31 +188,74 @@ class EventLoopMetrics: cycle_count: Number of event loop cycles executed. tool_metrics: Metrics for each tool used, keyed by tool name. cycle_durations: List of durations for each cycle in seconds. + agent_invocations: Agent invocation metrics containing cycles and usage data. traces: List of execution traces. - accumulated_usage: Accumulated token usage across all model invocations. + accumulated_usage: Accumulated token usage across all model invocations (across all requests). accumulated_metrics: Accumulated performance metrics across all model invocations. """ cycle_count: int = 0 - tool_metrics: Dict[str, ToolMetrics] = field(default_factory=dict) - cycle_durations: List[float] = field(default_factory=list) - traces: List[Trace] = field(default_factory=list) + tool_metrics: dict[str, ToolMetrics] = field(default_factory=dict) + cycle_durations: list[float] = field(default_factory=list) + agent_invocations: list[AgentInvocation] = field(default_factory=list) + traces: list[Trace] = field(default_factory=list) accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) + @property + def latest_context_size(self) -> int | None: + """Most recent context size from the last LLM call. + + This represents the current context size as reported by the model. + + Returns: + The input token count from the most recent cycle, or None if no data is available. + """ + if self.agent_invocations and self.agent_invocations[-1].cycles: + return self.agent_invocations[-1].cycles[-1].usage.get("inputTokens") + return None + + @property + def projected_context_size(self) -> int | None: + """Projected context size for the next model call. + + Computed as inputTokens + outputTokens from the most recent cycle's usage, + representing the approximate input token count for the next model call + (prior input + generated output that is now part of the conversation). + + Returns: + The projected token count, or None if no data is available. + """ + if self.agent_invocations and self.agent_invocations[-1].cycles: + usage = self.agent_invocations[-1].cycles[-1].usage + input_tokens = usage.get("inputTokens") + output_tokens = usage.get("outputTokens") + if input_tokens is not None and output_tokens is not None: + return input_tokens + output_tokens + return None + @property def _metrics_client(self) -> "MetricsClient": """Get the singleton MetricsClient instance.""" return MetricsClient() + @property + def latest_agent_invocation(self) -> AgentInvocation | None: + """Get the most recent agent invocation. + + Returns: + The most recent AgentInvocation, or None if no invocations exist. + """ + return self.agent_invocations[-1] if self.agent_invocations else None + def start_cycle( self, - attributes: Optional[Dict[str, Any]] = None, - ) -> Tuple[float, Trace]: + attributes: dict[str, Any], + ) -> tuple[float, Trace]: """Start a new event loop cycle and create a trace for it. Args: - attributes: attributes of the metrics. + attributes: attributes of the metrics, including event_loop_cycle_id. Returns: A tuple containing the start time and the cycle trace object. @@ -194,9 +266,17 @@ def start_cycle( start_time = time.time() cycle_trace = Trace(f"Cycle {self.cycle_count}", start_time=start_time) self.traces.append(cycle_trace) + + self.agent_invocations[-1].cycles.append( + EventLoopCycleMetric( + event_loop_cycle_id=attributes["event_loop_cycle_id"], + usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + ) + ) + return start_time, cycle_trace - def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: Optional[Dict[str, Any]] = None) -> None: + def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: dict[str, Any] | None = None) -> None: """End the current event loop cycle and record its duration. Args: @@ -252,32 +332,53 @@ def add_tool_usage( ) tool_trace.end() + def _accumulate_usage(self, target: Usage, source: Usage) -> None: + """Helper method to accumulate usage from source to target. + + Args: + target: The Usage object to accumulate into. + source: The Usage object to accumulate from. + """ + target["inputTokens"] += source["inputTokens"] + target["outputTokens"] += source["outputTokens"] + target["totalTokens"] += source["totalTokens"] + + if "cacheReadInputTokens" in source: + target["cacheReadInputTokens"] = target.get("cacheReadInputTokens", 0) + source["cacheReadInputTokens"] + + if "cacheWriteInputTokens" in source: + target["cacheWriteInputTokens"] = target.get("cacheWriteInputTokens", 0) + source["cacheWriteInputTokens"] + def update_usage(self, usage: Usage) -> None: """Update the accumulated token usage with new usage data. Args: usage: The usage data to add to the accumulated totals. """ + # Record metrics to OpenTelemetry self._metrics_client.event_loop_input_tokens.record(usage["inputTokens"]) self._metrics_client.event_loop_output_tokens.record(usage["outputTokens"]) - self.accumulated_usage["inputTokens"] += usage["inputTokens"] - self.accumulated_usage["outputTokens"] += usage["outputTokens"] - self.accumulated_usage["totalTokens"] += usage["totalTokens"] - # Handle optional cached token metrics + # Handle optional cached token metrics for OpenTelemetry if "cacheReadInputTokens" in usage: - cache_read_tokens = usage["cacheReadInputTokens"] - self._metrics_client.event_loop_cache_read_input_tokens.record(cache_read_tokens) - self.accumulated_usage["cacheReadInputTokens"] = ( - self.accumulated_usage.get("cacheReadInputTokens", 0) + cache_read_tokens - ) - + self._metrics_client.event_loop_cache_read_input_tokens.record(usage["cacheReadInputTokens"]) if "cacheWriteInputTokens" in usage: - cache_write_tokens = usage["cacheWriteInputTokens"] - self._metrics_client.event_loop_cache_write_input_tokens.record(cache_write_tokens) - self.accumulated_usage["cacheWriteInputTokens"] = ( - self.accumulated_usage.get("cacheWriteInputTokens", 0) + cache_write_tokens - ) + self._metrics_client.event_loop_cache_write_input_tokens.record(usage["cacheWriteInputTokens"]) + + self._accumulate_usage(self.accumulated_usage, usage) + self._accumulate_usage(self.agent_invocations[-1].usage, usage) + + if self.agent_invocations[-1].cycles: + current_cycle = self.agent_invocations[-1].cycles[-1] + self._accumulate_usage(current_cycle.usage, usage) + + def reset_usage_metrics(self) -> None: + """Start a new agent invocation by creating a new AgentInvocation. + + This should be called at the start of a new request to begin tracking + a new agent invocation with fresh usage and cycle data. + """ + self.agent_invocations.append(AgentInvocation()) def update_metrics(self, metrics: Metrics) -> None: """Update the accumulated performance metrics with new metrics data. @@ -290,7 +391,7 @@ def update_metrics(self, metrics: Metrics) -> None: self._metrics_client.model_time_to_first_token.record(metrics["timeToFirstByteMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] - def get_summary(self) -> Dict[str, Any]: + def get_summary(self) -> dict[str, Any]: """Generate a comprehensive summary of all collected metrics. Returns: @@ -322,11 +423,21 @@ def get_summary(self) -> Dict[str, Any]: "traces": [trace.to_dict() for trace in self.traces], "accumulated_usage": self.accumulated_usage, "accumulated_metrics": self.accumulated_metrics, + "agent_invocations": [ + { + "usage": invocation.usage, + "cycles": [ + {"event_loop_cycle_id": cycle.event_loop_cycle_id, "usage": cycle.usage} + for cycle in invocation.cycles + ], + } + for invocation in self.agent_invocations + ], } return summary -def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_names: Set[str]) -> Iterable[str]: +def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_names: set[str]) -> Iterable[str]: """Convert event loop metrics to a series of formatted text lines. Args: @@ -387,7 +498,7 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name yield from _trace_to_lines(trace.to_dict(), allowed_names=allowed_names, indent=1) -def _trace_to_lines(trace: Dict, allowed_names: Set[str], indent: int) -> Iterable[str]: +def _trace_to_lines(trace: dict, allowed_names: set[str], indent: int) -> Iterable[str]: """Convert a trace to a series of formatted text lines. Args: @@ -419,7 +530,7 @@ def _trace_to_lines(trace: Dict, allowed_names: Set[str], indent: int) -> Iterab yield from _trace_to_lines(child, allowed_names, indent + 1) -def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: Optional[Set[str]] = None) -> str: +def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: set[str] | None = None) -> str: """Convert event loop metrics to a human-readable string representation. Args: diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 2f42d9988..648a65d27 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -7,8 +7,9 @@ import json import logging import os +from collections.abc import Mapping from datetime import date, datetime, timezone -from typing import Any, Dict, Mapping, Optional, cast +from typing import Any, cast import opentelemetry.trace as trace_api from opentelemetry.instrumentation.threading import ThreadingInstrumentor @@ -82,14 +83,14 @@ class Tracer: When the OTEL_EXPORTER_OTLP_ENDPOINT environment variable is set, traces are sent to the OTLP endpoint. - Both attributes are controlled by including "gen_ai_latest_experimental" or "gen_ai_tool_definitions", - respectively, in the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. + Both attributes are controlled by including "gen_ai_latest_experimental", "gen_ai_tool_definitions", + or "gen_ai_use_latest_invocation_tokens", respectively, in the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. """ def __init__(self) -> None: """Initialize the tracer.""" self.service_name = __name__ - self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer_provider: trace_api.TracerProvider | None = None self.tracer_provider = trace_api.get_tracer_provider() self.tracer = self.tracer_provider.get_tracer(self.service_name) ThreadingInstrumentor().instrument() @@ -99,6 +100,7 @@ def __init__(self) -> None: ## To-do: should not set below attributes directly, use env var instead self.use_latest_genai_conventions = "gen_ai_latest_experimental" in opt_in_values self._include_tool_definitions = "gen_ai_tool_definitions" in opt_in_values + self._use_latest_invocation_tokens = "gen_ai_use_latest_invocation_tokens" in opt_in_values def _parse_semconv_opt_in(self) -> set[str]: """Parse the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. @@ -109,11 +111,23 @@ def _parse_semconv_opt_in(self) -> set[str]: opt_in_env = os.getenv("OTEL_SEMCONV_STABILITY_OPT_IN", "") return {value.strip() for value in opt_in_env.split(",")} + @property + def is_langfuse(self) -> bool: + """Check if Langfuse is configured as the OTLP endpoint. + + Returns: + True if Langfuse is the OTLP endpoint, False otherwise. + """ + return any( + "langfuse" in os.getenv(var, "") + for var in ("OTEL_EXPORTER_OTLP_ENDPOINT", "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "LANGFUSE_BASE_URL") + ) + def _start_span( self, span_name: str, - parent_span: Optional[Span] = None, - attributes: Optional[Dict[str, AttributeValue]] = None, + parent_span: Span | None = None, + attributes: dict[str, AttributeValue] | None = None, span_kind: trace_api.SpanKind = trace_api.SpanKind.INTERNAL, ) -> Span: """Generic helper method to start a span with common attributes. @@ -141,25 +155,12 @@ def _start_span( # Add all provided attributes if attributes: - self._set_attributes(span, attributes) + span.set_attributes(attributes) return span - def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: - """Set attributes on a span, handling different value types appropriately. - - Args: - span: The span to set attributes on - attributes: Dictionary of attributes to set - """ - if not span: - return - - for key, value in attributes.items(): - span.set_attribute(key, value) - def _add_optional_usage_and_metrics_attributes( - self, attributes: Dict[str, AttributeValue], usage: Usage, metrics: Metrics + self, attributes: dict[str, AttributeValue], usage: Usage, metrics: Metrics ) -> None: """Add optional usage and metrics attributes if they have values. @@ -183,8 +184,9 @@ def _add_optional_usage_and_metrics_attributes( def _end_span( self, span: Span, - attributes: Optional[Dict[str, AttributeValue]] = None, - error: Optional[Exception] = None, + attributes: dict[str, AttributeValue] | None = None, + error: Exception | None = None, + error_message: str | None = None, ) -> None: """Generic helper method to end a span. @@ -192,8 +194,9 @@ def _end_span( span: The span to end attributes: Optional attributes to set before ending the span error: Optional exception if an error occurred + error_message: Optional error message to set in the span status """ - if not span: + if not span or not span.is_recording(): return try: @@ -202,26 +205,23 @@ def _end_span( # Add any additional attributes if attributes: - self._set_attributes(span, attributes) + span.set_attributes(attributes) # Handle error if present if error: - span.set_status(StatusCode.ERROR, str(error)) + status_description = error_message or str(error) or type(error).__name__ + span.set_status(StatusCode.ERROR, status_description) span.record_exception(error) + elif error_message: + span.set_status(StatusCode.ERROR, error_message) else: span.set_status(StatusCode.OK) except Exception as e: logger.warning("error=<%s> | error while ending span", e, exc_info=True) finally: span.end() - # Force flush to ensure spans are exported - if self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): - try: - self.tracer_provider.force_flush() - except Exception as e: - logger.warning("error=<%s> | failed to force flush tracer provider", e) - - def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: + + def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: """End a span with error status. Args: @@ -229,23 +229,30 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Optiona error_message: Error message to set in the span status. exception: Optional exception to record in the span. """ - if not span: + if not span or not span.is_recording(): return error = exception or Exception(error_message) - self._end_span(span, error=error) + self._end_span(span, error=error, error_message=error_message) - def _add_event(self, span: Optional[Span], event_name: str, event_attributes: Attributes) -> None: + def _add_event( + self, span: Span | None, event_name: str, event_attributes: Attributes, to_span_attributes: bool = False + ) -> None: """Add an event with attributes to a span. Args: span: The span to add the event to event_name: Name of the event event_attributes: Dictionary of attributes to set on the event + to_span_attributes: Add the attributes to span attributes """ if not span: return + # Add to span attribute since some backend can't read the events + if to_span_attributes and event_attributes: + span.set_attributes(event_attributes) + span.add_event(event_name, attributes=event_attributes) def _get_event_name_for_message(self, message: Message) -> str: @@ -275,9 +282,11 @@ def _get_event_name_for_message(self, message: Message) -> str: def start_model_invoke_span( self, messages: Messages, - parent_span: Optional[Span] = None, - model_id: Optional[str] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + model_id: str | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, + system_prompt: str | None = None, + system_prompt_content: list | None = None, **kwargs: Any, ) -> Span: """Start a new span for a model invocation. @@ -287,12 +296,14 @@ def start_model_invoke_span( parent_span: Optional parent span to link this span to. model_id: Optional identifier for the model being invoked. custom_trace_attributes: Optional mapping of custom trace attributes to include in the span. + system_prompt: Optional system prompt string provided to the model. + system_prompt_content: Optional list of system prompt content blocks. **kwargs: Additional attributes to add to the span. Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat") if custom_trace_attributes: attributes.update(custom_trace_attributes) @@ -304,6 +315,7 @@ def start_model_invoke_span( attributes.update({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool))}) span = self._start_span("chat", parent_span, attributes=attributes, span_kind=trace_api.SpanKind.INTERNAL) + self._add_system_prompt_event(span, system_prompt, system_prompt_content) self._add_event_messages(span, messages) return span @@ -315,7 +327,6 @@ def end_model_invoke_span( usage: Usage, metrics: Metrics, stop_reason: StopReason, - error: Optional[Exception] = None, ) -> None: """End a model invocation span with results and metrics. @@ -324,10 +335,12 @@ def end_model_invoke_span( message: The message response from the model. usage: Token usage information from the model call. metrics: Metrics from the model call. - stop_reason (StopReason): The reason the model stopped generating. - error: Optional exception if the model call failed. + stop_reason: The reason the model stopped generating. """ - attributes: Dict[str, AttributeValue] = { + if not span or not span.is_recording(): + return + + attributes: dict[str, AttributeValue] = { "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.input_tokens": usage["inputTokens"], "gen_ai.usage.completion_tokens": usage["outputTokens"], @@ -353,6 +366,7 @@ def end_model_invoke_span( ] ), }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -361,13 +375,13 @@ def end_model_invoke_span( event_attributes={"finish_reason": str(stop_reason), "message": serialize(message["content"])}, ) - self._end_span(span, attributes, error) + self._end_span(span, attributes) def start_tool_call_span( self, tool: ToolUse, - parent_span: Optional[Span] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, ) -> Span: """Start a new span for a tool call. @@ -381,7 +395,7 @@ def start_tool_call_span( Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_tool") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_tool") attributes.update( { "gen_ai.tool.name": tool["name"], @@ -418,6 +432,7 @@ def start_tool_call_span( ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -432,9 +447,7 @@ def start_tool_call_span( return span - def end_tool_call_span( - self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None - ) -> None: + def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None: """End a tool call span with results. Args: @@ -442,16 +455,14 @@ def end_tool_call_span( tool_result: The result from the tool execution. error: Optional exception if the tool call failed. """ - attributes: Dict[str, AttributeValue] = {} + attributes: dict[str, AttributeValue] = {} + status: str | None = None + content: list[Any] = [] + if tool_result is not None: status = tool_result.get("status") - status_str = str(status) if status is not None else "" - - attributes.update( - { - "gen_ai.tool.status": status_str, - } - ) + content = tool_result.get("content", []) + attributes["gen_ai.tool.status"] = str(status) if status is not None else "" if self.use_latest_genai_conventions: self._add_event( @@ -466,34 +477,39 @@ def end_tool_call_span( { "type": "tool_call_response", "id": tool_result.get("toolUseId", ""), - "response": tool_result.get("content"), + "response": content, } ], } ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( span, "gen_ai.choice", event_attributes={ - "message": serialize(tool_result.get("content")), + "message": serialize(content), "id": tool_result.get("toolUseId", ""), }, ) - self._end_span(span, attributes, error) + if error is None and status == "error": + error_message = next((b["text"] for b in content if "text" in b), "tool returned error status") + self._end_span(span, attributes, error_message=error_message) + else: + self._end_span(span, attributes, error) def start_event_loop_cycle_span( self, invocation_state: Any, messages: Messages, - parent_span: Optional[Span] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, - ) -> Optional[Span]: + ) -> Span: """Start a new span for an event loop cycle. Args: @@ -509,9 +525,8 @@ def start_event_loop_cycle_span( event_loop_cycle_id = str(invocation_state.get("event_loop_cycle_id")) parent_span = parent_span if parent_span else invocation_state.get("event_loop_parent_span") - attributes: Dict[str, AttributeValue] = { - "event_loop.cycle_id": event_loop_cycle_id, - } + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_event_loop_cycle") + attributes["event_loop.cycle_id"] = event_loop_cycle_id if custom_trace_attributes: attributes.update(custom_trace_attributes) @@ -532,8 +547,7 @@ def end_event_loop_cycle_span( self, span: Span, message: Message, - tool_result_message: Optional[Message] = None, - error: Optional[Exception] = None, + tool_result_message: Message | None = None, ) -> None: """End an event loop cycle span with results. @@ -541,10 +555,11 @@ def end_event_loop_cycle_span( span: The span to end. message: The message response from this cycle. tool_result_message: Optional tool result message if a tool was called. - error: Optional exception if the cycle failed. """ - attributes: Dict[str, AttributeValue] = {} - event_attributes: Dict[str, AttributeValue] = {"message": serialize(message["content"])} + if not span or not span.is_recording(): + return + + event_attributes: dict[str, AttributeValue] = {"message": serialize(message["content"])} if tool_result_message: event_attributes["tool.result"] = serialize(tool_result_message["content"]) @@ -563,19 +578,21 @@ def end_event_loop_cycle_span( ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event(span, "gen_ai.choice", event_attributes=event_attributes) - self._end_span(span, attributes, error) + + self._end_span(span) def start_agent_span( self, messages: Messages, agent_name: str, - model_id: Optional[str] = None, - tools: Optional[list] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, - tools_config: Optional[dict] = None, + model_id: str | None = None, + tools: list | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, + tools_config: dict | None = None, **kwargs: Any, ) -> Span: """Start a new span for an agent invocation. @@ -592,7 +609,7 @@ def start_agent_span( Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="invoke_agent") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="invoke_agent") attributes.update( { "gen_ai.agent.name": agent_name, @@ -630,8 +647,8 @@ def start_agent_span( def end_agent_span( self, span: Span, - response: Optional[AgentResult] = None, - error: Optional[Exception] = None, + response: AgentResult | None = None, + error: Exception | None = None, ) -> None: """End an agent span with results and metrics. @@ -640,7 +657,7 @@ def end_agent_span( response: The response from the agent. error: Any error that occurred. """ - attributes: Dict[str, AttributeValue] = {} + attributes: dict[str, AttributeValue] = {} if response: if self.use_latest_genai_conventions: @@ -658,6 +675,7 @@ def end_agent_span( ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -667,16 +685,28 @@ def end_agent_span( ) if hasattr(response, "metrics") and hasattr(response.metrics, "accumulated_usage"): - accumulated_usage = response.metrics.accumulated_usage + if self.is_langfuse: + attributes.update({"langfuse.observation.type": "span"}) + if self._use_latest_invocation_tokens: + latest_invocation = response.metrics.latest_agent_invocation + if latest_invocation is None: + logger.warning( + "latest_agent_invocation is None despite _use_latest_invocation_tokens being set" + ) + usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) + else: + usage = latest_invocation.usage + else: + usage = response.metrics.accumulated_usage attributes.update( { - "gen_ai.usage.prompt_tokens": accumulated_usage["inputTokens"], - "gen_ai.usage.completion_tokens": accumulated_usage["outputTokens"], - "gen_ai.usage.input_tokens": accumulated_usage["inputTokens"], - "gen_ai.usage.output_tokens": accumulated_usage["outputTokens"], - "gen_ai.usage.total_tokens": accumulated_usage["totalTokens"], - "gen_ai.usage.cache_read_input_tokens": accumulated_usage.get("cacheReadInputTokens", 0), - "gen_ai.usage.cache_write_input_tokens": accumulated_usage.get("cacheWriteInputTokens", 0), + "gen_ai.usage.prompt_tokens": usage["inputTokens"], + "gen_ai.usage.completion_tokens": usage["outputTokens"], + "gen_ai.usage.input_tokens": usage["inputTokens"], + "gen_ai.usage.output_tokens": usage["outputTokens"], + "gen_ai.usage.total_tokens": usage["totalTokens"], + "gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0), + "gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0), } ) @@ -698,11 +728,11 @@ def start_multiagent_span( self, task: MultiAgentInput, instance: str, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> Span: """Start a new span for swarm invocation.""" operation = f"invoke_{instance}" - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation) + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation) attributes.update( { "gen_ai.agent.name": instance, @@ -724,6 +754,7 @@ def start_multiagent_span( span, "gen_ai.client.inference.operation.details", {"gen_ai.input.messages": serialize([{"role": "user", "parts": parts}])}, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -737,7 +768,7 @@ def start_multiagent_span( def end_swarm_span( self, span: Span, - result: Optional[str] = None, + result: str | None = None, ) -> None: """End a swarm span with results.""" if result: @@ -755,6 +786,7 @@ def end_swarm_span( ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -766,7 +798,7 @@ def end_swarm_span( def _get_common_attributes( self, operation_name: str, - ) -> Dict[str, AttributeValue]: + ) -> dict[str, AttributeValue]: """Returns a dictionary of common attributes based on the convention version used. Args: @@ -790,6 +822,46 @@ def _get_common_attributes( ) return dict(common_attributes) + def _add_system_prompt_event( + self, + span: Span, + system_prompt: str | None = None, + system_prompt_content: list | None = None, + ) -> None: + """Emit system prompt as a span event per OTel GenAI semantic conventions. + + In legacy mode (v1.36), emits a ``gen_ai.system.message`` event. + In latest experimental mode, emits ``gen_ai.system_instructions`` on the + ``gen_ai.client.inference.operation.details`` event, since Strands passes + system instructions separately from chat history. + + Args: + span: The span to add the event to. + system_prompt: Optional system prompt string. + system_prompt_content: Optional list of system prompt content blocks. + """ + if system_prompt is None and system_prompt_content is None: + return + + content_blocks: list[ContentBlock] = ( + system_prompt_content if system_prompt_content else [{"text": system_prompt or ""}] + ) + + if self.use_latest_genai_conventions: + parts = self._map_content_blocks_to_otel_parts(content_blocks) + self._add_event( + span, + "gen_ai.client.inference.operation.details", + {"gen_ai.system_instructions": serialize(parts)}, + to_span_attributes=self.is_langfuse, + ) + else: + self._add_event( + span, + "gen_ai.system.message", + {"content": serialize(content_blocks)}, + ) + def _add_event_messages(self, span: Span, messages: Messages) -> None: """Adds messages as event to the provided span based on the current GenAI conventions. @@ -804,7 +876,10 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None: {"role": message["role"], "parts": self._map_content_blocks_to_otel_parts(message["content"])} ) self._add_event( - span, "gen_ai.client.inference.operation.details", {"gen_ai.input.messages": serialize(input_messages)} + span, + "gen_ai.client.inference.operation.details", + {"gen_ai.input.messages": serialize(input_messages)}, + to_span_attributes=self.is_langfuse, ) else: for message in messages: diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index c61f79748..ada49369d 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -5,6 +5,7 @@ from .decorator import tool from .structured_output import convert_pydantic_to_tool_spec +from .tool_provider import ToolProvider from .tools import InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec __all__ = [ @@ -14,4 +15,5 @@ "normalize_schema", "normalize_tool_spec", "convert_pydantic_to_tool_spec", + "ToolProvider", ] diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 4a74dec18..0b5408f35 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -9,12 +9,15 @@ import json import random -from typing import TYPE_CHECKING, Any, Callable +import weakref +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from .._async import run_async from ..tools.executors._executor import ToolExecutor from ..types._events import ToolInterruptEvent from ..types.content import ContentBlock, Message +from ..types.exceptions import ConcurrencyException from ..types.tools import ToolResult, ToolUse if TYPE_CHECKING: @@ -33,7 +36,15 @@ def __init__(self, agent: "Agent | BidiAgent") -> None: """ # WARNING: Do not add any other member variables or methods as this could result in a name conflict with # agent tools and thus break their execution. - self._agent = agent + self._agent_ref = weakref.ref(agent) + + @property + def _agent(self) -> "Agent | BidiAgent": + """Return the agent, raising ReferenceError if it has been garbage collected.""" + agent = self._agent_ref() + if agent is None: + raise ReferenceError("Agent has been garbage collected") + return agent def __getattr__(self, name: str) -> Callable[..., Any]: """Call tool as a function. @@ -73,44 +84,64 @@ def caller( if self._agent._interrupt_state.activated: raise RuntimeError("cannot directly call tool during interrupt") - normalized_name = self._find_normalized_tool_name(name) + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call - # Create unique tool ID and set up the tool request - tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" - tool_use: ToolUse = { - "toolUseId": tool_id, - "name": normalized_name, - "input": kwargs.copy(), - } - tool_results: list[ToolResult] = [] - invocation_state = kwargs + should_lock = should_record_direct_tool_call - async def acall() -> ToolResult: - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - if isinstance(event, ToolInterruptEvent): - self._agent._interrupt_state.deactivate() - raise RuntimeError("cannot raise interrupt in direct tool call") + from ..agent import Agent # Locally imported to avoid circular reference - tool_result = tool_results[0] + acquired_lock = ( + should_lock + and isinstance(self._agent, Agent) + and self._agent._invocation_lock.acquire_lock(blocking=False) + ) + if should_lock and not acquired_lock: + raise ConcurrencyException( + "Direct tool call cannot be made while the agent is in the middle of an invocation. " + "Set record_direct_tool_call=False to allow direct tool calls during agent invocation." + ) - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call + try: + normalized_name = self._find_normalized_tool_name(name) - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - await self._record_tool_execution(tool_use, tool_result, user_message_override) + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs - return tool_result + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") + + tool_result = tool_results[0] + + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + await self._record_tool_execution(tool_use, tool_result, user_message_override) - tool_result = run_async(acall) + return tool_result - # Apply conversation management if agent supports it (traditional agents) - if hasattr(self._agent, "conversation_manager"): - self._agent.conversation_manager.apply_management(self._agent) + tool_result = run_async(acall) + + # TODO: https://github.com/strands-agents/sdk-python/issues/1311 + if isinstance(self._agent, Agent): + self._agent.conversation_manager.apply_management(self._agent) + + return tool_result - return tool_result + finally: + if acquired_lock and isinstance(self._agent, Agent): + self._agent._invocation_lock.release() return caller @@ -118,7 +149,7 @@ def _find_normalized_tool_name(self, name: str) -> str: """Lookup the tool represented by name, replacing characters with underscores as necessary.""" tool_registry = self._agent.tool_registry.registry - if tool_registry.get(name, None): + if tool_registry.get(name): return name # If the desired name contains underscores, it might be a placeholder for characters that can't be diff --git a/src/strands/tools/_tool_helpers.py b/src/strands/tools/_tool_helpers.py index d023caeec..3b62337d3 100644 --- a/src/strands/tools/_tool_helpers.py +++ b/src/strands/tools/_tool_helpers.py @@ -6,14 +6,14 @@ # https://github.com/strands-agents/sdk-python/issues/998 @tool(name="noop", description="This is a fake tool that MUST be completely ignored.") -def noop_tool() -> None: +def noop_tool() -> str: """No-op tool to satisfy tool spec requirement when tool messages are present. Some model providers (e.g., Bedrock) will return an error response if tool uses and tool results are present in messages without any tool specs configured. Consequently, if the summarization agent has no registered tools, summarization will fail. As a workaround, we register the no-op tool. """ - pass + return "You MUST NOT use this tool. Respond DIRECTLY to the user." def generate_missing_tool_result_content(tool_use_ids: list[str]) -> list[ContentBlock]: diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 8dc933f51..9207df9b8 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -43,17 +43,15 @@ def my_tool(param1: str, param2: int = 42) -> dict: import asyncio import functools import inspect +import json import logging +from collections.abc import Callable from typing import ( Annotated, Any, - Callable, Generic, - Optional, ParamSpec, - Type, TypeVar, - Union, cast, get_args, get_origin, @@ -64,6 +62,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: import docstring_parser from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo +from pydantic_core import PydanticSerializationError from typing_extensions import override from ..interrupt import InterruptException @@ -101,7 +100,7 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - """ self.func = func self.signature = inspect.signature(func) - self.type_hints = get_type_hints(func) + self.type_hints = get_type_hints(func, include_extras=True) self._context_param = context_param self._validate_signature() @@ -183,7 +182,7 @@ def _validate_signature(self) -> None: # Found the parameter, no need to check further break - def _create_input_model(self) -> Type[BaseModel]: + def _create_input_model(self) -> type[BaseModel]: """Create a Pydantic model from function signature for input validation. This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can @@ -201,9 +200,17 @@ def _create_input_model(self) -> Type[BaseModel]: if self._is_special_parameter(name): continue - # Use param.annotation directly to get the raw type hint. Using get_type_hints() - # can cause inconsistent behavior across Python versions for complex Annotated types. - param_type = param.annotation + # Handle PEP 563 (from __future__ import annotations): + # - When PEP 563 is active, param.annotation is a string literal that needs resolution + # - When PEP 563 is not active, param.annotation is the actual type object (may include Annotated) + # We check if param.annotation is a string to determine if we need type hint resolution. + # This preserves Annotated metadata correctly in both cases and is consistent across Python versions. + if isinstance(param.annotation, str): + # PEP 563 active: resolve string annotation + param_type = self.type_hints.get(name, param.annotation) + else: + # PEP 563 not active: use the actual type object directly + param_type = param.annotation if param_type is inspect.Parameter.empty: param_type = Any default = ... if param.default is inspect.Parameter.empty else param.default @@ -321,13 +328,20 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None: del schema[key] # Process properties to clean up anyOf and similar structures + required_fields = schema.get("required", []) if "properties" in schema: - for _prop_name, prop_schema in schema["properties"].items(): + for prop_name, prop_schema in schema["properties"].items(): # Handle anyOf constructs (common for Optional types) if "anyOf" in prop_schema: any_of = prop_schema["anyOf"] # Handle Optional[Type] case (represented as anyOf[Type, null]) - if len(any_of) == 2 and any(item.get("type") == "null" for item in any_of): + # Only simplify when the field is not required; required nullable + # fields need anyOf preserved so the model can pass null. + if ( + prop_name not in required_fields + and len(any_of) == 2 + and any(item.get("type") == "null" for item in any_of) + ): # Find the non-null type for item in any_of: if item.get("type") != "null": @@ -463,7 +477,7 @@ def __init__( functools.update_wrapper(wrapper=self, wrapped=self._tool_func) - def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]": + def __get__(self, instance: Any, obj_type: type | None = None) -> "DecoratedFunctionTool[P, R]": """Descriptor protocol implementation for proper method binding. This method enables the decorated function to work correctly when used as a class method. @@ -529,6 +543,31 @@ def tool_spec(self) -> ToolSpec: """ return self._tool_spec + @tool_spec.setter + def tool_spec(self, value: ToolSpec) -> None: + """Set the tool specification. + + This allows runtime modification of the tool's schema, enabling dynamic + tool configurations based on feature flags or other runtime conditions. + + Args: + value: The new tool specification. + + Raises: + ValueError: If the spec fails structural validation (wrong name or + missing required field). + """ + if value.get("name") != self._tool_name: + raise ValueError( + f"cannot change tool name via tool_spec (expected '{self._tool_name}', got '{value.get('name')}')" + ) + + for field in ("description", "inputSchema"): + if field not in value: + raise ValueError(f"tool_spec must contain '{field}'") + + self._tool_spec = value + @property def tool_type(self) -> str: """Get the type of the tool. @@ -608,6 +647,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw "status": "error", "content": [{"text": f"Error: {error_msg}"}], }, + exception=e, ) except Exception as e: # Return error result with exception details for any other error @@ -620,23 +660,38 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw "status": "error", "content": [{"text": f"Error: {error_type} - {error_msg}"}], }, + exception=e, ) - def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: + def _wrap_tool_result(self, tool_use_d: str, result: Any, exception: Exception | None = None) -> ToolResultEvent: # FORMAT THE RESULT for Strands Agent if isinstance(result, dict) and "status" in result and "content" in result: # Result is already in the expected format, just add toolUseId result["toolUseId"] = tool_use_d - return ToolResultEvent(cast(ToolResult, result)) + return ToolResultEvent(cast(ToolResult, result), exception=exception) else: # Wrap any other return value in the standard format - # Always include at least one content item for consistency + # Serialize to JSON for consistent, parseable output (except strings) + if isinstance(result, str): + text = result + elif isinstance(result, BaseModel): + try: + text = result.model_dump_json() + except PydanticSerializationError: + text = str(result) + else: + try: + text = json.dumps(result) + except (TypeError, ValueError): + text = str(result) + return ToolResultEvent( { "toolUseId": tool_use_d, "status": "success", - "content": [{"text": str(result)}], - } + "content": [{"text": text}], + }, + exception=exception, ) @property @@ -666,20 +721,20 @@ def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ... # Handle @decorator() @overload def tool( - description: Optional[str] = None, - inputSchema: Optional[JSONSchema] = None, - name: Optional[str] = None, + description: str | None = None, + inputSchema: JSONSchema | None = None, + name: str | None = None, context: bool | str = False, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system def tool( # type: ignore - func: Optional[Callable[P, R]] = None, - description: Optional[str] = None, - inputSchema: Optional[JSONSchema] = None, - name: Optional[str] = None, + func: Callable[P, R] | None = None, + description: str | None = None, + inputSchema: JSONSchema | None = None, + name: str | None = None, context: bool | str = False, -) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]: +) -> DecoratedFunctionTool[P, R] | Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: """Decorator that transforms a Python function into a Strands tool. This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool. diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index a4f9e7e1f..3993f332b 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -7,7 +7,8 @@ import abc import logging import time -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any, cast from opentelemetry import trace as trace_api @@ -49,16 +50,19 @@ async def _invoke_before_tool_call_hook( invocation_state: dict[str, Any], ) -> tuple[BeforeToolCallEvent | BidiBeforeToolCallEvent, list[Interrupt]]: """Invoke the appropriate before tool call hook based on agent type.""" - event_cls = BeforeToolCallEvent if ToolExecutor._is_agent(agent) else BidiBeforeToolCallEvent - return await agent.hooks.invoke_callbacks_async( - event_cls( - agent=agent, - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, - ) + kwargs = { + "selected_tool": tool_func, + "tool_use": tool_use, + "invocation_state": invocation_state, + } + event = ( + BeforeToolCallEvent(agent=cast("Agent", agent), **kwargs) + if ToolExecutor._is_agent(agent) + else BidiBeforeToolCallEvent(agent=cast("BidiAgent", agent), **kwargs) ) + return await agent.hooks.invoke_callbacks_async(event) + @staticmethod async def _invoke_after_tool_call_hook( agent: "Agent | BidiAgent", @@ -70,19 +74,22 @@ async def _invoke_after_tool_call_hook( cancel_message: str | None = None, ) -> tuple[AfterToolCallEvent | BidiAfterToolCallEvent, list[Interrupt]]: """Invoke the appropriate after tool call hook based on agent type.""" - event_cls = AfterToolCallEvent if ToolExecutor._is_agent(agent) else BidiAfterToolCallEvent - return await agent.hooks.invoke_callbacks_async( - event_cls( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - exception=exception, - cancel_message=cancel_message, - ) + kwargs = { + "selected_tool": selected_tool, + "tool_use": tool_use, + "invocation_state": invocation_state, + "result": result, + "exception": exception, + "cancel_message": cancel_message, + } + event = ( + AfterToolCallEvent(agent=cast("Agent", agent), **kwargs) + if ToolExecutor._is_agent(agent) + else BidiAfterToolCallEvent(agent=cast("BidiAgent", agent), **kwargs) ) + return await agent.hooks.invoke_callbacks_async(event) + @staticmethod async def _stream( agent: "Agent | BidiAgent", @@ -141,113 +148,148 @@ async def _stream( } ) - before_event, interrupts = await ToolExecutor._invoke_before_tool_call_hook( - agent, tool_func, tool_use, invocation_state - ) - - if interrupts: - yield ToolInterruptEvent(tool_use, interrupts) - return - - if before_event.cancel_tool: - cancel_message = ( - before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" + # Retry loop for tool execution - hooks can set after_event.retry = True to retry + while True: + before_event, interrupts = await ToolExecutor._invoke_before_tool_call_hook( + agent, tool_func, tool_use, invocation_state ) - yield ToolCancelEvent(tool_use, cancel_message) - cancel_result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": cancel_message}], - } + if interrupts: + yield ToolInterruptEvent(tool_use, interrupts) + return - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message - ) - yield ToolResultEvent(after_event.result) - tool_results.append(after_event.result) - return - - try: - selected_tool = before_event.selected_tool - tool_use = before_event.tool_use - invocation_state = before_event.invocation_state - - if not selected_tool: - if tool_func == selected_tool: - logger.error( - "tool_name=<%s>, available_tools=<%s> | tool not found in registry", - tool_name, - list(agent.tool_registry.registry.keys()), - ) - else: - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", - tool_name, - str(tool_use.get("toolUseId")), - ) + if before_event.cancel_tool: + cancel_message = ( + before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" + ) + yield ToolCancelEvent(tool_use, cancel_message) - result: ToolResult = { + cancel_result: ToolResult = { "toolUseId": str(tool_use.get("toolUseId")), "status": "error", - "content": [{"text": f"Unknown tool: {tool_name}"}], + "content": [{"text": cancel_message}], } after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result + agent, + None, + tool_use, + invocation_state, + cancel_result, + cancel_message=cancel_message, ) yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) return - if structured_output_context.is_enabled: - kwargs["structured_output_context"] = structured_output_context - async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): - # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() - # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. - # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent - # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in - # ToolStreamEvent and the last event is just the result. - - if isinstance(event, ToolInterruptEvent): - yield event - return - - if isinstance(event, ToolResultEvent): - # below the last "event" must point to the tool_result - event = event.tool_result - break - if isinstance(event, ToolStreamEvent): - yield event - else: - yield ToolStreamEvent(tool_use, event) + try: + selected_tool = before_event.selected_tool + tool_use = before_event.tool_use + invocation_state = before_event.invocation_state + + if not selected_tool: + if tool_func == selected_tool: + logger.error( + "tool_name=<%s>, available_tools=<%s> | tool not found in registry", + tool_name, + list(agent.tool_registry.registry.keys()), + ) + else: + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", + tool_name, + str(tool_use.get("toolUseId")), + ) + + result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Unknown tool: {tool_name}"}], + } + + unknown_tool_error = Exception(f"Unknown tool: {tool_name}") + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result, exception=unknown_tool_error + ) + # Check if retry requested for unknown tool error + # Use getattr because BidiAfterToolCallEvent doesn't have retry attribute + if getattr(after_event, "retry", False): + logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name) + continue + yield ToolResultEvent(after_event.result, exception=after_event.exception) + tool_results.append(after_event.result) + return + if structured_output_context.is_enabled: + kwargs["structured_output_context"] = structured_output_context + + exception: Exception | None = None + + async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): + # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() + # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. + # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent + # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in + # ToolStreamEvent and the last event is just the result. + + if isinstance(event, ToolInterruptEvent): + # Register any interrupts not already in the agent's state. + # For normal hooks this is a no-op (already registered by _Interruptible.interrupt()). + # For sub-agent interrupts propagated via _AgentAsTool, this is where they get + # registered so that _interrupt_state.resume() can locate them by ID. + for interrupt in event.interrupts: + agent._interrupt_state.interrupts.setdefault(interrupt.id, interrupt) + yield event + return + + if isinstance(event, ToolResultEvent): + # Preserve exception from decorated tools before extracting tool_result + exception = event.exception + # below the last "event" must point to the tool_result + event = event.tool_result + break + + if isinstance(event, ToolStreamEvent): + yield event + else: + yield ToolStreamEvent(tool_use, event) + + result = cast(ToolResult, event) - result = cast(ToolResult, event) + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result, exception=exception + ) - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result - ) + # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) + if getattr(after_event, "retry", False): + logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name) + continue - yield ToolResultEvent(after_event.result) - tool_results.append(after_event.result) + yield ToolResultEvent(after_event.result, exception=after_event.exception) + tool_results.append(after_event.result) + return - except Exception as e: - logger.exception("tool_name=<%s> | failed to process tool", tool_name) - error_result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": f"Error: {str(e)}"}], - } + except Exception as e: + logger.exception("tool_name=<%s> | failed to process tool", tool_name) + error_result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Error: {str(e)}"}], + } - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, error_result, exception=e - ) - yield ToolResultEvent(after_event.result) - tool_results.append(after_event.result) + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, error_result, exception=e + ) + # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) + if getattr(after_event, "retry", False): + logger.debug("tool_name=<%s> | retry requested after exception, retrying tool call", tool_name) + continue + yield ToolResultEvent(after_event.result, exception=after_event.exception) + tool_results.append(after_event.result) + return @staticmethod async def _stream_with_trace( - agent: "Agent | BidiAgent", + agent: "Agent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -259,7 +301,7 @@ async def _stream_with_trace( """Execute tool with tracing and metrics collection. Args: - agent: The agent (Agent or BidiAgent) for which the tool is being executed. + agent: The agent for which the tool is being executed. tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -302,13 +344,13 @@ async def _stream_with_trace( agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) cycle_trace.add_child(tool_trace) - tracer.end_tool_call_span(tool_call_span, result) + tracer.end_tool_call_span(tool_call_span, result, error=result_event.exception) @abc.abstractmethod # pragma: no cover def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -319,7 +361,7 @@ def _execute( """Execute the given tools according to this executor's strategy. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index da5c1ff10..835e5abff 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -1,7 +1,8 @@ """Concurrent tool executor implementation.""" import asyncio -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -12,7 +13,6 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -22,7 +22,7 @@ class ConcurrentToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -33,7 +33,7 @@ async def _execute( """Execute tools concurrently. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -48,38 +48,47 @@ async def _execute( task_events = [asyncio.Event() for _ in tool_uses] stop_event = object() - tasks = [ - asyncio.create_task( - self._task( - agent, - tool_use, - tool_results, - cycle_trace, - cycle_span, - invocation_state, - task_id, - task_queue, - task_events[task_id], - stop_event, - structured_output_context, + tasks = [] + try: + for task_id, tool_use in enumerate(tool_uses): + tasks.append( + asyncio.create_task( + self._task( + agent, + tool_use, + tool_results, + cycle_trace, + cycle_span, + invocation_state, + task_id, + task_queue, + task_events[task_id], + stop_event, + structured_output_context, + ) + ) ) - ) - for task_id, tool_use in enumerate(tool_uses) - ] - task_count = len(tasks) - while task_count: - task_id, event = await task_queue.get() - if event is stop_event: - task_count -= 1 - continue + task_count = len(tasks) + while task_count: + task_id, event = await task_queue.get() + if event is stop_event: + task_count -= 1 + continue - yield event - task_events[task_id].set() + if isinstance(event, Exception): + raise event + + yield event + task_events[task_id].set() + finally: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) async def _task( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -94,7 +103,7 @@ async def _task( """Execute a single tool and put results in the task queue. Args: - agent: The agent (Agent or BidiAgent) executing the tool. + agent: The agent executing the tool. tool_use: Tool use metadata and inputs. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -115,5 +124,8 @@ async def _task( await task_event.wait() task_event.clear() + except Exception as e: + task_queue.put_nowait((task_id, e)) + finally: task_queue.put_nowait((task_id, stop_event)) diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 6163fc195..dc5b9a5d9 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -1,6 +1,7 @@ """Sequential tool executor implementation.""" -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -11,7 +12,6 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -21,7 +21,7 @@ class SequentialToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -34,7 +34,7 @@ async def _execute( Breaks early if an interrupt is raised by the user. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 6f745b728..2115cdee8 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -9,7 +9,7 @@ from pathlib import Path from posixpath import expanduser from types import ModuleType -from typing import List, cast +from typing import cast from ..types.tools import AgentTool from .decorator import DecoratedFunctionTool @@ -20,7 +20,7 @@ _TOOL_MODULE_PREFIX = "_strands_tool_" -def load_tool_from_string(tool_string: str) -> List[AgentTool]: +def load_tool_from_string(tool_string: str) -> list[AgentTool]: """Load tools follows strands supported input string formats. This function can load a tool based on a string in the following ways: @@ -42,7 +42,7 @@ def load_tool_from_string(tool_string: str) -> List[AgentTool]: return load_tools_from_module_path(tool_string) -def load_tools_from_file_path(tool_path: str) -> List[AgentTool]: +def load_tools_from_file_path(tool_path: str) -> list[AgentTool]: """Load module from specified path, and then load tools from that module. This function attempts to load the passed in path as a python module, and if it succeeds, @@ -116,7 +116,7 @@ def load_tools_from_module(module: ModuleType, module_name: str) -> list[AgentTo # Try and see if any of the attributes in the module are function-based tools decorated with @tool # This means that there may be more than one tool available in this module, so we load them all - function_tools: List[AgentTool] = [] + function_tools: list[AgentTool] = [] # Function tools will appear as attributes in the module for attr_name in dir(module): attr = getattr(module, attr_name) @@ -153,7 +153,7 @@ class ToolLoader: """Handles loading of tools from different sources.""" @staticmethod - def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: + def load_python_tools(tool_path: str, tool_name: str) -> list[AgentTool]: """DEPRECATED: Load a Python tool module and return all discovered function-based tools as a list. This method always returns a list of AgentTool (possibly length 1). It is the @@ -206,7 +206,7 @@ def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: spec.loader.exec_module(module) # Collect function-based tools decorated with @tool - function_tools: List[AgentTool] = [] + function_tools: list[AgentTool] = [] for attr_name in dir(module): attr = getattr(module, attr_name) if isinstance(attr, DecoratedFunctionTool): diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index cfa841c46..8d2c1daa2 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -8,6 +8,7 @@ from .mcp_agent_tool import MCPAgentTool from .mcp_client import MCPClient, ToolFilters +from .mcp_tasks import TasksConfig from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "TasksConfig", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index bb5dca19c..270012fde 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -9,32 +9,49 @@ import asyncio import base64 +import contextvars +import json import logging +import sys import threading import uuid from asyncio import AbstractEventLoop +from collections.abc import Callable, Coroutine, Sequence from concurrent import futures from datetime import timedelta +from re import Pattern from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast +from typing import Any, TypeVar, cast import anyio from mcp import ClientSession, ListToolsResult from mcp.client.session import ElicitationFnT -from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents +from mcp.shared.exceptions import McpError +from mcp.types import ( + BlobResourceContents, + ElicitationRequiredErrorData, + GetPromptResult, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + TextResourceContents, +) from mcp.types import CallToolResult as MCPCallToolResult from mcp.types import EmbeddedResource as MCPEmbeddedResource from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent +from pydantic import AnyUrl from typing_extensions import Protocol, TypedDict -from ...experimental.tools import ToolProvider from ...types import PaginatedList from ...types.exceptions import MCPClientInitializationError, ToolProviderException from ...types.media import ImageFormat from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus +from ..tool_provider import ToolProvider from .mcp_agent_tool import MCPAgentTool from .mcp_instrumentation import mcp_instrumentation +from .mcp_tasks import DEFAULT_TASK_CONFIG, DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL, TasksConfig from .mcp_types import MCPToolResult, MCPTransport logger = logging.getLogger(__name__) @@ -61,7 +78,7 @@ class ToolFilters(TypedDict, total=False): rejected: list[_ToolMatcher] -MIME_TO_FORMAT: Dict[str, ImageFormat] = { +MIME_TO_FORMAT: dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", "image/png": "png", @@ -94,10 +111,6 @@ class MCPClient(ToolProvider): The connection runs in a background thread to avoid blocking the main application thread while maintaining communication with the MCP service. When structured content is available from MCP tools, it will be returned as the last item in the content array of the ToolResult. - - Warning: - This class implements the experimental ToolProvider interface and its methods - are subject to change. """ def __init__( @@ -107,7 +120,8 @@ def __init__( startup_timeout: int = 30, tool_filters: ToolFilters | None = None, prefix: str | None = None, - elicitation_callback: Optional[ElicitationFnT] = None, + elicitation_callback: ElicitationFnT | None = None, + tasks_config: TasksConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -118,6 +132,9 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. + tasks_config: Configuration for MCP task-augmented execution for long-running tools. + If provided (not None), enables task-augmented execution for tools that support it. + See TasksConfig for details. This feature is experimental and subject to change. """ self._startup_timeout = startup_timeout self._tool_filters = tool_filters @@ -140,8 +157,19 @@ def __init__( self._background_thread_event_loop: AbstractEventLoop | None = None self._loaded_tools: list[MCPAgentTool] | None = None self._tool_provider_started = False + self.server_instructions: str | None = None self._consumers: set[Any] = set() + # Task support configuration and caching + self._tasks_config = tasks_config + self._server_task_capable: bool | None = None + + # Conditionally set up the task support cache (old SDK versions don't expose TaskExecutionMode) + if self._is_tasks_enabled(): + from mcp.types import TaskExecutionMode + + self._tool_task_support_cache: dict[str, TaskExecutionMode] = {} + def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -150,7 +178,12 @@ def __enter__(self) -> "MCPClient": """ return self.start() - def __exit__(self, exc_type: BaseException, exc_val: BaseException, exc_tb: TracebackType) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Context manager exit point that cleans up resources.""" self.stop(exc_type, exc_val, exc_tb) @@ -170,7 +203,11 @@ def start(self) -> "MCPClient": raise MCPClientInitializationError("the client session is currently running") self._log_debug_with_thread("entering MCPClient context") - self._background_thread = threading.Thread(target=self._background_task, args=[], daemon=True) + # Copy context vars to propagate to the background thread + # This ensures that context set in the main thread is accessible in the background thread + # See: https://github.com/strands-agents/sdk-python/issues/1440 + ctx = contextvars.copy_context() + self._background_thread = threading.Thread(target=ctx.run, args=(self._background_task,), daemon=True) self._background_thread.start() self._log_debug_with_thread("background thread started, waiting for ready event") try: @@ -188,10 +225,10 @@ def start(self) -> "MCPClient": logger.exception("client failed to initialize") # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit self.stop(None, None, None) - raise MCPClientInitializationError("the client initialization failed") from e + raise MCPClientInitializationError(f"the client initialization failed: {e}") from e return self - # ToolProvider interface methods (experimental, as ToolProvider is experimental) + # ToolProvider interface methods async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: """Load and return tools from the MCP server. @@ -287,7 +324,10 @@ def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: # MCP-specific methods def stop( - self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources. @@ -314,6 +354,15 @@ def stop( """ self._log_debug_with_thread("exiting MCPClient context") + # Skip cleanup during interpreter finalization. On Python 3.14+, joining a + # non-daemon thread at shutdown raises PythonFinalizationError; even though + # our background thread is a daemon and will be reclaimed automatically, + # the join call itself produces noisy tracebacks on stderr when the GC + # reaches Agent.__del__ during finalization. See issue #2143. + if sys.is_finalizing(): + self._log_debug_with_thread("interpreter is finalizing, skipping MCPClient cleanup") + return + # Only try to signal close future if we have a background thread if self._background_thread is not None: # Signal close future if event loop exists @@ -330,6 +379,9 @@ async def _set_close_event() -> None: self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() + if self._background_thread_event_loop is not None: + self._background_thread_event_loop.close() + self._log_debug_with_thread("background thread is closed, MCPClient context exited") # Reset fields to allow instance reuse @@ -341,6 +393,8 @@ async def _set_close_event() -> None: self._loaded_tools = None self._tool_provider_started = False self._consumers = set() + self._server_task_capable = None + self._tool_task_support_cache = {} if self._close_exception: exception = self._close_exception @@ -383,6 +437,13 @@ async def _list_tools_async() -> ListToolsResult: mcp_tools = [] for tool in list_tools_response.tools: + if self._is_tasks_enabled(): + # Cache taskSupport for task-augmented execution decisions + task_support = None + if tool.execution is not None and tool.execution.taskSupport is not None: + task_support = tool.execution.taskSupport + self._tool_task_support_cache[tool.name] = task_support or "forbidden" + # Apply prefix if specified if effective_prefix: prefixed_name = f"{effective_prefix}_{tool.name}" @@ -398,7 +459,7 @@ async def _list_tools_async() -> ListToolsResult: self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) - def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult: + def list_prompts_sync(self, pagination_token: str | None = None) -> ListPromptsResult: """Synchronously retrieves the list of available prompts from the MCP server. This method calls the asynchronous list_prompts method on the MCP session @@ -446,25 +507,145 @@ async def _get_prompt_async() -> GetPromptResult: return get_prompt_result + def list_resources_sync(self, pagination_token: str | None = None) -> ListResourcesResult: + """Synchronously retrieves the list of available resources from the MCP server. + + This method calls the asynchronous list_resources method on the MCP session + and returns the raw ListResourcesResult with pagination support. + + Args: + pagination_token: Optional token for pagination + + Returns: + ListResourcesResult: The raw MCP response containing resources and pagination info + """ + self._log_debug_with_thread("listing MCP resources synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_resources_async() -> ListResourcesResult: + return await cast(ClientSession, self._background_thread_session).list_resources(cursor=pagination_token) + + list_resources_result: ListResourcesResult = self._invoke_on_background_thread(_list_resources_async()).result() + self._log_debug_with_thread("received %d resources from MCP server", len(list_resources_result.resources)) + + return list_resources_result + + def read_resource_sync(self, uri: AnyUrl | str) -> ReadResourceResult: + """Synchronously reads a resource from the MCP server. + + Args: + uri: The URI of the resource to read + + Returns: + ReadResourceResult: The resource content from the MCP server + """ + self._log_debug_with_thread("reading MCP resource synchronously: %s", uri) + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _read_resource_async() -> ReadResourceResult: + # Convert string to AnyUrl if needed + resource_uri = AnyUrl(uri) if isinstance(uri, str) else uri + return await cast(ClientSession, self._background_thread_session).read_resource(resource_uri) + + read_resource_result: ReadResourceResult = self._invoke_on_background_thread(_read_resource_async()).result() + self._log_debug_with_thread("received resource content from MCP server") + + return read_resource_result + + def list_resource_templates_sync(self, pagination_token: str | None = None) -> ListResourceTemplatesResult: + """Synchronously retrieves the list of available resource templates from the MCP server. + + Resource templates define URI patterns that can be used to access resources dynamically. + + Args: + pagination_token: Optional token for pagination + + Returns: + ListResourceTemplatesResult: The raw MCP response containing resource templates and pagination info + """ + self._log_debug_with_thread("listing MCP resource templates synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_resource_templates_async() -> ListResourceTemplatesResult: + return await cast(ClientSession, self._background_thread_session).list_resource_templates( + cursor=pagination_token + ) + + list_resource_templates_result: ListResourceTemplatesResult = self._invoke_on_background_thread( + _list_resource_templates_async() + ).result() + self._log_debug_with_thread( + "received %d resource templates from MCP server", len(list_resource_templates_result.resourceTemplates) + ) + + return list_resource_templates_result + + def _create_call_tool_coroutine( + self, + name: str, + arguments: dict[str, Any] | None, + read_timeout_seconds: timedelta | None, + meta: dict[str, Any] | None = None, + ) -> Coroutine[Any, Any, MCPCallToolResult]: + """Create the appropriate coroutine for calling a tool. + + This method encapsulates the decision logic for whether to use task-augmented + execution or direct call_tool, returning the appropriate coroutine. + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + read_timeout_seconds: Optional timeout for the tool call. + meta: Optional metadata to pass to the tool call per MCP spec (_meta). + + Returns: + A coroutine that will execute the tool call. + """ + use_task = self._should_use_task(name) + + if use_task: + self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) + + async def _call_as_task() -> MCPCallToolResult: + # When task-augmented execution is used, use the read_timeout_seconds parameter + # (which is a timedelta) for the polling timeout. + return await self._call_tool_as_task_and_poll_async( + name, arguments, poll_timeout=read_timeout_seconds, meta=meta + ) + + return _call_as_task() + else: + self._log_debug_with_thread("tool=<%s> | using direct call_tool", name) + + async def _call_tool_direct() -> MCPCallToolResult: + return await cast(ClientSession, self._background_thread_session).call_tool( + name, arguments, read_timeout_seconds, meta=meta + ) + + return _call_tool_direct() + def call_tool_sync( self, tool_use_id: str, name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + meta: dict[str, Any] | None = None, ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. - This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. If the MCP tool returns - structured content, it will be included as the last item in the content array - of the returned ToolResult. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. Args: tool_use_id: Unique identifier for this tool use name: Name of the tool to call arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call + meta: Optional metadata to pass to the tool call per MCP spec (_meta) Returns: MCPToolResult: The result of the tool call @@ -473,13 +654,9 @@ def call_tool_sync( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result() + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: logger.exception("tool execution failed") @@ -491,17 +668,19 @@ async def call_tool_async( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + meta: dict[str, Any] | None = None, ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. - This method calls the asynchronous call_tool method on the MCP session - and converts the result to the MCPToolResult format. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. Args: tool_use_id: Unique identifier for this tool use name: Name of the tool to call arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call + meta: Optional metadata to pass to the tool call per MCP spec (_meta) Returns: MCPToolResult: The result of the tool call @@ -510,13 +689,9 @@ async def call_tool_async( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - future = self._invoke_on_background_thread(_call_tool_async()) + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -524,7 +699,31 @@ async def _call_tool_async() -> MCPCallToolResult: return self._handle_tool_execution_error(tool_use_id, e) def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult: - """Create error ToolResult with consistent logging.""" + """Create error ToolResult with consistent logging and elicitation callback support. + + Args: + tool_use_id: Unique identifier for this tool use. + exception: The exception that occurred during tool execution. + + Returns: + MCPToolResult: Error result containing either the elicitation data or the + original exception message. + """ + if isinstance(exception, McpError) and exception.error.code == -32042: + try: + error_data = ElicitationRequiredErrorData.model_validate(exception.error.data) + elicitations = [e.model_dump(exclude_none=True) for e in error_data.elicitations] + + return MCPToolResult( + status="error", + toolUseId=tool_use_id, + content=[ + {"text": (f"MCP Elicitation required: [{str(exception)}] with data {json.dumps(elicitations)}")} + ], + ) + except Exception: + logger.debug("Failed to parse ElicitationRequiredErrorData from -32042 error", exc_info=True) + return MCPToolResult( status="error", toolUseId=tool_use_id, @@ -563,6 +762,10 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes if call_tool_result.structuredContent: result["structuredContent"] = call_tool_result.structuredContent + if call_tool_result.meta: + result["metadata"] = call_tool_result.meta + if call_tool_result.isError is not None: + result["isError"] = call_tool_result.isError return result @@ -587,11 +790,28 @@ async def _async_background_thread(self) -> None: elicitation_callback=self._elicitation_callback, ) as session: self._log_debug_with_thread("initializing MCP session") - await session.initialize() + init_result = await session.initialize() self._log_debug_with_thread("session initialized successfully") + # Store server instructions from InitializeResult for Host applications + self.server_instructions = init_result.instructions # Store the session for use while we await the close event self._background_thread_session = session + + # Cache server task capability immediately after initialization + # Capabilities are exchanged during session.initialize(), so this is available now + caps = session.get_server_capabilities() + self._server_task_capable = ( + caps is not None + and caps.tasks is not None + and caps.tasks.requests is not None + and caps.tasks.requests.tools is not None + and caps.tasks.requests.tools.call is not None + ) + self._log_debug_with_thread( + "server_task_capable=<%s> | cached server task capability", self._server_task_capable + ) + # Signal that the session has been created and is ready for use self._init_future.set_result(None) @@ -623,7 +843,7 @@ async def _handle_error_message(self, message: Exception | Any) -> None: if isinstance(message, Exception): error_msg = str(message).lower() if any(pattern in error_msg for pattern in _NON_FATAL_ERROR_PATTERNS): - self._log_debug_with_thread("ignoring non-fatal MCP session error", message) + self._log_debug_with_thread("ignoring non-fatal MCP session error: %s", message) else: raise message await anyio.lowlevel.checkpoint() @@ -637,6 +857,9 @@ def _background_task(self) -> None: This allows for a long-running event loop. """ self._log_debug_with_thread("setting up background task event loop") + # Clear any running-loop state leaked by OpenTelemetry's ThreadingInstrumentor, which wraps Thread.run() + # and can propagate the parent thread's event loop reference, causing run_until_complete() to fail. + asyncio._set_running_loop(None) self._background_thread_event_loop = asyncio.new_event_loop() asyncio.set_event_loop(self._background_thread_event_loop) self._background_thread_event_loop.run_until_complete(self._async_background_thread()) @@ -644,7 +867,7 @@ def _background_task(self) -> None: def _map_mcp_content_to_tool_result_content( self, content: MCPTextContent | MCPImageContent | MCPEmbeddedResource | Any, - ) -> Union[ToolResultContent, None]: + ) -> ToolResultContent | None: """Maps MCP content types to tool result content types. This method converts MCP-specific content types to the generic @@ -764,7 +987,7 @@ def _should_include_tool(self, tool: MCPAgentTool) -> bool: """Check if a tool should be included based on constructor filters.""" return self._should_include_tool_with_filters(tool, self._tool_filters) - def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: Optional[ToolFilters]) -> bool: + def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: ToolFilters | None) -> bool: """Check if a tool should be included based on provided filters.""" if not filters: return True @@ -796,4 +1019,216 @@ def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolMatcher]) -> return False def _is_session_active(self) -> bool: - return self._background_thread is not None and self._background_thread.is_alive() + if self._background_thread is None or not self._background_thread.is_alive(): + return False + + if self._close_future is not None and self._close_future.done(): + return False + + return True + + def _is_tasks_enabled(self) -> bool: + """Check if tasks feature is enabled. + + Tasks are enabled if tasks config is defined and not None. + + Returns: + True if task-augmented execution is enabled, False otherwise. + """ + return self._tasks_config is not None + + def _get_task_config(self) -> TasksConfig: + """Returns the task execution configuration, configured with defaults if not specified.""" + task_config = self._tasks_config or DEFAULT_TASK_CONFIG + return TasksConfig( + ttl=task_config.get("ttl", DEFAULT_TASK_TTL), + poll_timeout=task_config.get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT), + ) + + def _has_server_task_support(self) -> bool: + """Check if the MCP server supports task-augmented tool calls. + + Returns the capability value that was cached immediately after session initialization. + Server capabilities are exchanged during the MCP handshake, so this is available + as soon as start() completes. + + Returns: + True if server supports task-augmented tool calls, False otherwise. + """ + return self._server_task_capable or False + + def _should_use_task(self, tool_name: str) -> bool: + """Determine if task-augmented execution should be used for a tool. + + Task-augmented execution requires: + 1. tasks config is enabled (opt-in check) + 2. Server supports tasks (capability check) + 3. Tool taskSupport is 'required' or 'optional' + + Args: + tool_name: Name of the tool to check. + + Returns: + True if task-augmented execution should be used, False otherwise. + """ + # Opt-in check: tasks must be explicitly enabled via tasks config + if not self._is_tasks_enabled(): + return False + + # Local import to avoid errors on old SDK versions that don't support Tasks + from mcp.types import TASK_OPTIONAL, TASK_REQUIRED + + # Server capability check (per MCP spec) + if not self._has_server_task_support(): + return False + + # Tool-level capability check (cached during list_tools_sync) + task_support = self._tool_task_support_cache.get(tool_name) + + # Use tasks for TASK_REQUIRED or TASK_OPTIONAL when server supports + if task_support == TASK_REQUIRED or task_support == TASK_OPTIONAL: + return True + + # Default: 'forbidden', None, or unknown -> don't use tasks + return False + + def _create_task_error_result(self, message: str) -> MCPCallToolResult: + """Create an error MCPCallToolResult with consistent formatting. + + This helper reduces duplication in task error handling paths. + + Args: + message: The error message to include in the result. + + Returns: + MCPCallToolResult with isError=True and the message as text content. + """ + return MCPCallToolResult( + isError=True, + content=[MCPTextContent(type="text", text=message)], + ) + + # ================================================================================== + # Task-Augmented Tool Execution + # ================================================================================== + # + # The MCP spec defines task-augmented execution for long-running tools. The flow is: + # + # 1. Check server capability (tasks.requests.tools.call) and tool setting (taskSupport) + # 2. If using tasks: call_tool_as_task() -> poll_task() -> get_task_result() + # 3. If not using tasks: call_tool() directly + # + # See: https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks + # ================================================================================== + + async def _call_tool_as_task_and_poll_async( + self, + name: str, + arguments: dict[str, Any] | None = None, + ttl: timedelta | None = None, + poll_timeout: timedelta | None = None, + meta: dict[str, Any] | None = None, + ) -> MCPCallToolResult: + """Call a tool using task-augmented execution and poll until completion. + + This method implements the MCP task workflow: + 1. Creates a task via call_tool_as_task + 2. Polls using poll_task until terminal status (with timeout protection) + 3. Gets the final result using get_task_result + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + ttl: Task time-to-live. Uses configured value if not specified. + poll_timeout: Timeout for polling. Uses configured value if not specified. + meta: Optional metadata to pass to the tool call per MCP spec (_meta). + + Returns: + MCPCallToolResult: The final tool result after task completion. + """ + # Local import to avoid errors on old SDK versions that don't support Tasks + from mcp.types import TASK_STATUS_CANCELLED, TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, GetTaskResult + + session = cast(ClientSession, self._background_thread_session) + + # Precedence: arg > config > default + timeout = poll_timeout or self._get_task_config().get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT) + ttl = ttl or self._get_task_config().get("ttl", DEFAULT_TASK_TTL) + ttl_ms = int(ttl.total_seconds() * 1000) + + # Step 1: Create the task + self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl_ms) + create_result = await session.experimental.call_tool_as_task( + name=name, + arguments=arguments, + ttl=ttl_ms, + meta=meta, + ) + task_id = create_result.task.taskId + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id) + + # Step 2: Poll until terminal status (with timeout protection) + # Note: Using asyncio.wait_for() instead of asyncio.timeout() for Python 3.10 compatibility + async def _poll_until_terminal() -> GetTaskResult | None: + """Inner function to poll task status until terminal state.""" + final = None + async for task in session.experimental.poll_task(task_id): + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | task status update", + name, + task_id, + task.status, + ) + final = task + return final + + try: + final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout.total_seconds()) + except asyncio.TimeoutError: + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, timeout_seconds=<%s> | task polling timed out", + name, + task_id, + timeout.total_seconds(), + ) + return self._create_task_error_result( + f"Task {task_id} polling timed out after {timeout.total_seconds()} seconds" + ) + + # Step 3: Handle terminal status + if final_status is None: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | polling completed without status", name, task_id) + return self._create_task_error_result(f"Task {task_id} polling completed without status") + + if final_status.status == TASK_STATUS_FAILED: + error_msg = final_status.statusMessage or "Task failed" + self._log_debug_with_thread("tool=<%s>, task_id=<%s>, error=<%s> | task failed", name, task_id, error_msg) + return self._create_task_error_result(error_msg) + + if final_status.status == TASK_STATUS_CANCELLED: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task was cancelled", name, task_id) + return self._create_task_error_result("Task was cancelled") + + # Step 4: Get the actual result for completed tasks (with error handling for race conditions) + if final_status.status == TASK_STATUS_COMPLETED: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task completed, fetching result", name, task_id) + try: + result = await session.experimental.get_task_result(task_id, MCPCallToolResult) + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task result retrieved", name, task_id) + return result + except Exception as e: + # Handle race condition: task completed but result retrieval failed + # (e.g., result expired, network error, server restarted) + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, error=<%s> | failed to retrieve task result", name, task_id, str(e) + ) + return self._create_task_error_result(f"Task completed but result retrieval failed: {str(e)}") + + # Unexpected status - return as error + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | unexpected task status", + name, + task_id, + final_status.status, + ) + return self._create_task_error_result(f"Unexpected task status: {final_status.status}") diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py index f8ab3bc80..5e64cc3d5 100644 --- a/src/strands/tools/mcp/mcp_instrumentation.py +++ b/src/strands/tools/mcp/mcp_instrumentation.py @@ -9,9 +9,10 @@ Related issue: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/246 """ +from collections.abc import AsyncGenerator, Callable from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from dataclasses import dataclass -from typing import Any, AsyncGenerator, Callable, Tuple +from typing import Any from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest @@ -89,9 +90,10 @@ def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwar if hasattr(request.root, "params") and request.root.params: # Handle Pydantic models if hasattr(request.root.params, "model_dump") and hasattr(request.root.params, "model_validate"): - params_dict = request.root.params.model_dump() + params_dict = request.root.params.model_dump(by_alias=True) # Add _meta with tracing context - meta = params_dict.setdefault("_meta", {}) + meta = params_dict.get("_meta") if params_dict.get("_meta") is not None else {} + params_dict["_meta"] = meta propagate.get_global_textmap().inject(meta) # Recreate the Pydantic model with the updated data @@ -129,7 +131,7 @@ def transport_wrapper() -> Callable[ @asynccontextmanager async def traced_method( wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any - ) -> AsyncGenerator[Tuple[Any, Any], None]: + ) -> AsyncGenerator[tuple[Any, Any], None]: async with wrapped(*args, **kwargs) as result: try: read_stream, write_stream = result @@ -139,7 +141,7 @@ async def traced_method( return traced_method - def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any]], None]: + def session_init_wrapper() -> Callable[[Any, Any, tuple[Any, ...], dict[str, Any]], None]: """Create a wrapper for MCP session initialization. Wraps session message streams to enable bidirectional context flow. @@ -151,7 +153,7 @@ def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any """ def traced_method( - wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: dict[str, Any] + wrapped: Callable[..., Any], instance: Any, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> None: wrapped(*args, **kwargs) reader = getattr(instance, "_incoming_message_stream_reader", None) diff --git a/src/strands/tools/mcp/mcp_tasks.py b/src/strands/tools/mcp/mcp_tasks.py new file mode 100644 index 000000000..36537f7df --- /dev/null +++ b/src/strands/tools/mcp/mcp_tasks.py @@ -0,0 +1,33 @@ +"""Task-augmented tool execution configuration for MCP. + +This module provides configuration types and defaults for the experimental MCP Tasks feature. +""" + +from datetime import timedelta + +from typing_extensions import TypedDict + + +class TasksConfig(TypedDict, total=False): + """Configuration for MCP Tasks (task-augmented tool execution). + + When enabled, supported tool calls use the MCP task workflow: + create task -> poll for completion -> get result. + + Warning: + This is an experimental feature in the 2025-11-25 MCP specification and + both the specification and the Strands Agents implementation of this + feature are subject to change. + + Attributes: + ttl: Task time-to-live. Defaults to 1 minute. + poll_timeout: Timeout for polling task completion. Defaults to 5 minutes. + """ + + ttl: timedelta + poll_timeout: timedelta + + +DEFAULT_TASK_TTL = timedelta(minutes=1) +DEFAULT_TASK_POLL_TIMEOUT = timedelta(minutes=5) +DEFAULT_TASK_CONFIG = TasksConfig(ttl=DEFAULT_TASK_TTL, poll_timeout=DEFAULT_TASK_POLL_TIMEOUT) diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 66eda08ae..09feb624f 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -1,7 +1,7 @@ """Type definitions for MCP integration.""" from contextlib import AbstractAsyncContextManager -from typing import Any, Dict +from typing import Any from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.streamable_http import GetSessionIdCallback @@ -58,6 +58,16 @@ class MCPToolResult(ToolResult): structuredContent: Optional JSON object containing structured data returned by the MCP tool. This allows MCP tools to return complex data structures that can be processed programmatically by agents or other tools. + metadata: Optional arbitrary metadata returned by the MCP tool. This field allows + MCP servers to attach custom metadata to tool results (e.g., token usage, + performance metrics, or business-specific tracking information). + isError: Whether the MCP tool reported an application-level error via + ``CallToolResult.isError``. ``True`` means the tool executed but its logic + returned a failure. Absent when the tool succeeded or when the error was a + protocol/client exception rather than a tool-reported failure, letting + callers distinguish application errors from transport/protocol errors. """ - structuredContent: NotRequired[Dict[str, Any]] + structuredContent: NotRequired[dict[str, Any]] + metadata: NotRequired[dict[str, Any]] + isError: NotRequired[bool] diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 91f0bf870..9a0f0f722 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -10,19 +10,21 @@ import sys import uuid import warnings +from collections.abc import Iterable, Sequence from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import Any, cast -from typing_extensions import TypedDict, cast +from typing_extensions import TypedDict from .._async import run_async -from ..experimental.tools import ToolProvider +from ..agent.base import AgentBase from ..tools.decorator import DecoratedFunctionTool from ..types.tools import AgentTool, ToolSpec +from . import ToolProvider from .loader import load_tool_from_string, load_tools_from_module -from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec +from .tools import _COMPOSITION_KEYWORDS, PythonAgentTool, normalize_schema, normalize_tool_spec logger = logging.getLogger(__name__) @@ -35,13 +37,13 @@ class ToolRegistry: def __init__(self) -> None: """Initialize the tool registry.""" - self.registry: Dict[str, AgentTool] = {} - self.dynamic_tools: Dict[str, AgentTool] = {} - self.tool_config: Optional[Dict[str, Any]] = None - self._tool_providers: List[ToolProvider] = [] + self.registry: dict[str, AgentTool] = {} + self.dynamic_tools: dict[str, AgentTool] = {} + self.tool_config: dict[str, Any] | None = None + self._tool_providers: list[ToolProvider] = [] self._registry_id = str(uuid.uuid4()) - def process_tools(self, tools: List[Any]) -> List[str]: + def process_tools(self, tools: list[Any]) -> list[str]: """Process tools list. Process list of tools that can contain local file path string, module import path string, @@ -61,6 +63,7 @@ def process_tools(self, tools: List[Any]) -> List[str]: 3. A module for a module based tool 4. Instances of AgentTool (@tool decorated functions) 5. Dictionaries with name/path keys (deprecated) + 6. Agent instances with an ``as_tool()`` method (auto-wrapped) Returns: @@ -139,6 +142,12 @@ async def get_tools() -> Sequence[AgentTool]: for provider_tool in provider_tools: self.register_tool(provider_tool) tool_names.append(provider_tool.tool_name) + # Agent instances - auto-wrap with .as_tool() for convenience + elif isinstance(tool, AgentBase) and hasattr(tool, "as_tool") and callable(tool.as_tool): + wrapped_tool = tool.as_tool() + self.register_tool(wrapped_tool) + tool_names.append(wrapped_tool.tool_name) + else: logger.warning("tool=<%s> | unrecognized tool specification", tool) @@ -186,7 +195,7 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: logger.exception("tool_name=<%s> | failed to load tool", tool_name) raise ValueError(f"Failed to load tool {tool_name}: {exception_str}") from e - def get_all_tools_config(self) -> Dict[str, Any]: + def get_all_tools_config(self) -> dict[str, Any]: """Dynamically generate tool configuration by combining built-in and dynamic tools. Returns: @@ -279,7 +288,33 @@ def register_tool(self, tool: AgentTool) -> None: list(self.dynamic_tools.keys()), ) - def get_tools_dirs(self) -> List[Path]: + def replace(self, new_tool: AgentTool) -> None: + """Replace an existing tool with a new implementation. + + This performs a swap of the tool implementation in the registry. + The replacement takes effect on the next agent invocation. + + Args: + new_tool: New tool implementation. Its name must match the tool being replaced. + + Raises: + ValueError: If the tool doesn't exist. + """ + tool_name = new_tool.tool_name + + if tool_name not in self.registry: + raise ValueError(f"Cannot replace tool '{tool_name}' - tool does not exist") + + # Update main registry + self.registry[tool_name] = new_tool + + # Update dynamic_tools to match new tool's dynamic status + if new_tool.is_dynamic: + self.dynamic_tools[tool_name] = new_tool + elif tool_name in self.dynamic_tools: + del self.dynamic_tools[tool_name] + + def get_tools_dirs(self) -> list[Path]: """Get all tool directory paths. Returns: @@ -299,7 +334,7 @@ def get_tools_dirs(self) -> List[Path]: return tool_dirs - def discover_tool_modules(self) -> Dict[str, Path]: + def discover_tool_modules(self) -> dict[str, Path]: """Discover available tool modules in all tools directories. Returns: @@ -542,7 +577,7 @@ def get_all_tool_specs(self) -> list[ToolSpec]: A list of ToolSpecs. """ all_tools = self.get_all_tools_config() - tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] + tools: list[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] return tools def register_dynamic_tool(self, tool: AgentTool) -> None: @@ -604,7 +639,8 @@ def validate_tool_spec(self, tool_spec: ToolSpec) -> None: if "$ref" in prop_def: continue - if "type" not in prop_def: + has_composition = any(kw in prop_def for kw in _COMPOSITION_KEYWORDS) + if "type" not in prop_def and not has_composition: prop_def["type"] = "string" if "description" not in prop_def: prop_def["description"] = f"Property {prop_name}" @@ -618,7 +654,7 @@ class NewToolDict(TypedDict): spec: ToolSpec - def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict) -> None: + def _update_tool_config(self, tool_config: dict[str, Any], new_tool: NewToolDict) -> None: """Update tool configuration with a new tool. Args: @@ -655,7 +691,7 @@ def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict tool_config["tools"].append(new_tool_entry) logger.debug("tool_name=<%s> | added new tool", new_tool_name) - def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: + def _scan_module_for_tools(self, module: Any) -> list[AgentTool]: """Scan a module for function-based tools. Args: @@ -664,7 +700,7 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: Returns: List of FunctionTool instances found in the module. """ - tools: List[AgentTool] = [] + tools: list[AgentTool] = [] for name, obj in inspect.getmembers(module): if isinstance(obj, DecoratedFunctionTool): diff --git a/src/strands/tools/structured_output/__init__.py b/src/strands/tools/structured_output/__init__.py index 777d5d846..a3a12d000 100644 --- a/src/strands/tools/structured_output/__init__.py +++ b/src/strands/tools/structured_output/__init__.py @@ -1,5 +1,6 @@ """Structured output tools for the Strands Agents framework.""" +from ._structured_output_context import DEFAULT_STRUCTURED_OUTPUT_PROMPT from .structured_output_utils import convert_pydantic_to_tool_spec -__all__ = ["convert_pydantic_to_tool_spec"] +__all__ = ["convert_pydantic_to_tool_spec", "DEFAULT_STRUCTURED_OUTPUT_PROMPT"] diff --git a/src/strands/tools/structured_output/_structured_output_context.py b/src/strands/tools/structured_output/_structured_output_context.py index f33a06915..9a5190d9d 100644 --- a/src/strands/tools/structured_output/_structured_output_context.py +++ b/src/strands/tools/structured_output/_structured_output_context.py @@ -1,7 +1,7 @@ """Context management for structured output in the event loop.""" import logging -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import BaseModel @@ -13,24 +13,33 @@ logger = logging.getLogger(__name__) +DEFAULT_STRUCTURED_OUTPUT_PROMPT = "You must format the previous response as structured output." + class StructuredOutputContext: """Per-invocation context for structured output execution.""" - def __init__(self, structured_output_model: Type[BaseModel] | None = None): + def __init__( + self, + structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, + ): """Initialize a new structured output context. Args: structured_output_model: Optional Pydantic model type for structured output. + structured_output_prompt: Optional custom prompt message to use when forcing structured output. + Defaults to "You must format the previous response as structured output." """ self.results: dict[str, BaseModel] = {} - self.structured_output_model: Type[BaseModel] | None = structured_output_model + self.structured_output_model: type[BaseModel] | None = structured_output_model self.structured_output_tool: StructuredOutputTool | None = None self.forced_mode: bool = False self.force_attempted: bool = False self.tool_choice: ToolChoice | None = None self.stop_loop: bool = False - self.expected_tool_name: Optional[str] = None + self.expected_tool_name: str | None = None + self.structured_output_prompt: str = structured_output_prompt or DEFAULT_STRUCTURED_OUTPUT_PROMPT if structured_output_model: self.structured_output_tool = StructuredOutputTool(structured_output_model) @@ -91,7 +100,7 @@ def has_structured_output_tool(self, tool_uses: list[ToolUse]) -> bool: return False return any(tool_use.get("name") == self.expected_tool_name for tool_use in tool_uses) - def get_tool_spec(self) -> Optional[ToolSpec]: + def get_tool_spec(self) -> ToolSpec | None: """Get the tool specification for structured output. Returns: diff --git a/src/strands/tools/structured_output/structured_output_tool.py b/src/strands/tools/structured_output/structured_output_tool.py index 25173d048..fa20f526c 100644 --- a/src/strands/tools/structured_output/structured_output_tool.py +++ b/src/strands/tools/structured_output/structured_output_tool.py @@ -6,7 +6,7 @@ import logging from copy import deepcopy -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ValidationError from typing_extensions import override @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -_TOOL_SPEC_CACHE: dict[Type[BaseModel], ToolSpec] = {} +_TOOL_SPEC_CACHE: dict[type[BaseModel], ToolSpec] = {} if TYPE_CHECKING: from ._structured_output_context import StructuredOutputContext @@ -26,7 +26,7 @@ class StructuredOutputTool(AgentTool): """Tool implementation for structured output validation.""" - def __init__(self, structured_output_model: Type[BaseModel]) -> None: + def __init__(self, structured_output_model: type[BaseModel]) -> None: """Initialize a structured output tool. Args: @@ -43,7 +43,7 @@ def __init__(self, structured_output_model: Type[BaseModel]) -> None: self._tool_name = self._tool_spec.get("name", "StructuredOutputTool") @classmethod - def _get_tool_spec(cls, structured_output_model: Type[BaseModel]) -> ToolSpec: + def _get_tool_spec(cls, structured_output_model: type[BaseModel]) -> ToolSpec: """Get a cached tool spec for the given output type. Args: @@ -84,7 +84,7 @@ def tool_type(self) -> str: return "structured_output" @property - def structured_output_model(self) -> Type[BaseModel]: + def structured_output_model(self) -> type[BaseModel]: """Get the Pydantic model type for this tool. Returns: diff --git a/src/strands/tools/structured_output/structured_output_utils.py b/src/strands/tools/structured_output/structured_output_utils.py index 093d67f7c..a78ec6195 100644 --- a/src/strands/tools/structured_output/structured_output_utils.py +++ b/src/strands/tools/structured_output/structured_output_utils.py @@ -1,13 +1,13 @@ """Tools for converting Pydantic models to Bedrock tools.""" -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Union from pydantic import BaseModel from ...types.tools import ToolSpec -def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: +def _flatten_schema(schema: dict[str, Any]) -> dict[str, Any]: """Flattens a JSON schema by removing $defs and resolving $ref references. Handles required vs optional fields properly. @@ -80,11 +80,11 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: def _process_property( - prop: Dict[str, Any], - defs: Dict[str, Any], + prop: dict[str, Any], + defs: dict[str, Any], is_required: bool = False, fully_expand: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Process a property in a schema, resolving any references. Args: @@ -174,8 +174,8 @@ def _process_property( def _process_schema_object( - schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True -) -> Dict[str, Any]: + schema_obj: dict[str, Any], defs: dict[str, Any], fully_expand: bool = True +) -> dict[str, Any]: """Process a schema object, typically from $defs, to resolve all nested properties. Args: @@ -218,7 +218,7 @@ def _process_schema_object( return result -def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: +def _process_nested_dict(d: dict[str, Any], defs: dict[str, Any]) -> dict[str, Any]: """Recursively processes nested dictionaries and resolves $ref references. Args: @@ -228,7 +228,7 @@ def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, A Returns: Processed dictionary """ - result: Dict[str, Any] = {} + result: dict[str, Any] = {} # Handle direct reference if "$ref" in d: @@ -258,8 +258,8 @@ def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, A def convert_pydantic_to_tool_spec( - model: Type[BaseModel], - description: Optional[str] = None, + model: type[BaseModel], + description: str | None = None, ) -> ToolSpec: """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. @@ -302,7 +302,7 @@ def convert_pydantic_to_tool_spec( ) -def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _expand_nested_properties(schema: dict[str, Any], model: type[BaseModel]) -> None: """Expand the properties of nested models in the schema to include their full structure. This updates the schema in place. @@ -348,7 +348,7 @@ def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> schema["properties"][prop_name] = expanded_object -def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_referenced_models(schema: dict[str, Any], model: type[BaseModel]) -> None: """Process referenced models to ensure their docstrings are included. This updates the schema in place. @@ -388,7 +388,7 @@ def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) - _process_properties(ref_def, field_type) -def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_properties(schema_def: dict[str, Any], model: type[BaseModel]) -> None: """Process properties in a schema definition to add descriptions from field metadata. Args: diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/tools/tool_provider.py similarity index 92% rename from src/strands/experimental/tools/tool_provider.py rename to src/strands/tools/tool_provider.py index 2c79ceafc..002c57d73 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/tools/tool_provider.py @@ -1,10 +1,11 @@ """Tool provider interface.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from ...types.tools import AgentTool + from ..types.tools import AgentTool class ToolProvider(ABC): diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 48b969bc3..ccfeac323 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -17,6 +17,12 @@ logger = logging.getLogger(__name__) +_COMPOSITION_KEYWORDS = ("anyOf", "oneOf", "allOf", "not") +"""JSON Schema composition keywords that define type constraints. + +Properties with these should not get a default type: "string" added. +""" + class InvalidToolUseNameException(Exception): """Exception raised when a tool use has an invalid name.""" @@ -88,7 +94,9 @@ def _normalize_property(prop_name: str, prop_def: Any) -> dict[str, Any]: if "$ref" in normalized_prop: return normalized_prop - normalized_prop.setdefault("type", "string") + has_composition = any(kw in normalized_prop for kw in _COMPOSITION_KEYWORDS) + if not has_composition: + normalized_prop.setdefault("type", "string") normalized_prop.setdefault("description", f"Property {prop_name}") return normalized_prop @@ -189,6 +197,31 @@ def tool_spec(self) -> ToolSpec: """ return self._tool_spec + @tool_spec.setter + def tool_spec(self, value: ToolSpec) -> None: + """Set the tool specification. + + This allows runtime modification of the tool's schema, enabling dynamic + tool configurations based on feature flags or other runtime conditions. + + Args: + value: The new tool specification. + + Raises: + ValueError: If the spec fails structural validation (wrong name or + missing required field). + """ + if value.get("name") != self._tool_name: + raise ValueError( + f"cannot change tool name via tool_spec (expected '{self._tool_name}', got '{value.get('name')}')" + ) + + for field in ("description", "inputSchema"): + if field not in value: + raise ValueError(f"tool_spec must contain '{field}'") + + self._tool_spec = value + @property def supports_hot_reload(self) -> bool: """Check if this tool supports automatic reloading when modified. diff --git a/src/strands/tools/watcher.py b/src/strands/tools/watcher.py index 44f2ed512..c7f50fccd 100644 --- a/src/strands/tools/watcher.py +++ b/src/strands/tools/watcher.py @@ -6,7 +6,7 @@ import logging from pathlib import Path -from typing import Any, Dict, Set +from typing import Any from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer @@ -25,9 +25,9 @@ class ToolWatcher: # design pattern avoids conflicts when multiple tool registries are watching the same directories. _shared_observer = None - _watched_dirs: Set[str] = set() + _watched_dirs: set[str] = set() _observer_started = False - _registry_handlers: Dict[str, Dict[int, "ToolWatcher.ToolChangeHandler"]] = {} + _registry_handlers: dict[str, dict[int, "ToolWatcher.ToolChangeHandler"]] = {} def __init__(self, tool_registry: ToolRegistry) -> None: """Initialize a tool watcher for the given tool registry. diff --git a/src/strands/types/__init__.py b/src/strands/types/__init__.py index 7eef60cb4..60d6b3a17 100644 --- a/src/strands/types/__init__.py +++ b/src/strands/types/__init__.py @@ -1,5 +1,6 @@ """SDK type definitions.""" +from ._snapshot import Snapshot from .collections import PaginatedList -__all__ = ["PaginatedList"] +__all__ = ["PaginatedList", "Snapshot"] diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index c3890f428..1d5a5de79 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,7 +5,8 @@ agent lifecycle. """ -from typing import TYPE_CHECKING, Any, Sequence, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, cast from pydantic import BaseModel from typing_extensions import override @@ -20,6 +21,7 @@ if TYPE_CHECKING: from ..agent import AgentResult + from ..agent._agent_as_tool import _AgentAsTool from ..multiagent.base import MultiAgentResult, NodeResult @@ -161,7 +163,7 @@ class CitationStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, citation: Citation) -> None: """Initialize with delta and citation content.""" - super().__init__({"callback": {"citation": citation, "delta": delta}}) + super().__init__({"citation": citation, "delta": delta}) class ReasoningTextStreamEvent(ModelStreamEvent): @@ -275,13 +277,18 @@ def prepare(self, invocation_state: dict) -> None: class ToolResultEvent(TypedEvent): """Event emitted when a tool execution completes.""" - def __init__(self, tool_result: ToolResult) -> None: - """Initialize with the completed tool result. + def __init__(self, tool_result: ToolResult, exception: Exception | None = None) -> None: + """Initialize tool result event.""" + super().__init__({"type": "tool_result", "tool_result": tool_result}) + self._exception = exception - Args: - tool_result: Final result from the tool execution + @property + def exception(self) -> Exception | None: + """The original exception that occurred, if any. + + Can be used for re-raising or type-based error handling. """ - super().__init__({"type": "tool_result", "tool_result": tool_result}) + return self._exception @property def tool_use_id(self) -> str: @@ -317,6 +324,31 @@ def tool_use_id(self) -> str: return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use"))["toolUseId"] +class AgentAsToolStreamEvent(ToolStreamEvent): + """Event emitted when an agent-as-tool yields intermediate events during execution. + + Extends ToolStreamEvent with a reference to the originating _AgentAsTool so callers + can distinguish sub-agent stream events from regular tool stream events and access + the wrapped agent, tool name, description, etc. + """ + + def __init__(self, tool_use: ToolUse, tool_stream_data: Any, agent_as_tool: "_AgentAsTool") -> None: + """Initialize with tool streaming data and agent-tool reference. + + Args: + tool_use: The tool invocation producing the stream. + tool_stream_data: The yielded event from the sub-agent execution. + agent_as_tool: The _AgentAsTool instance that produced this event. + """ + super().__init__(tool_use, tool_stream_data) + self._agent_as_tool = agent_as_tool + + @property + def agent_as_tool(self) -> "_AgentAsTool": + """The _AgentAsTool instance that produced this event.""" + return self._agent_as_tool + + class ToolCancelEvent(TypedEvent): """Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook.""" diff --git a/src/strands/types/_snapshot.py b/src/strands/types/_snapshot.py new file mode 100644 index 000000000..407b811f2 --- /dev/null +++ b/src/strands/types/_snapshot.py @@ -0,0 +1,145 @@ +"""Snapshot types, constants, and helpers for agent state capture.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal, TypedDict + +from .exceptions import SnapshotException + +SnapshotField = Literal["messages", "state", "conversation_manager_state", "interrupt_state", "system_prompt"] +SnapshotPreset = Literal["session"] +Scope = Literal["agent"] + +ALL_SNAPSHOT_FIELDS: tuple[SnapshotField, ...] = ( + "messages", + "state", + "conversation_manager_state", + "interrupt_state", + "system_prompt", +) + +VALID_SCOPES: tuple[Scope, ...] = ("agent",) + +SNAPSHOT_SCHEMA_VERSION = "1.0" + +SNAPSHOT_PRESETS: dict[str, tuple[SnapshotField, ...]] = { + "session": ("messages", "state", "conversation_manager_state", "interrupt_state"), +} + + +class TakeSnapshotOptions(TypedDict, total=False): + """Internal options for take_snapshot. Not exported publicly.""" + + preset: SnapshotPreset + include: list[SnapshotField] + exclude: list[SnapshotField] + app_data: dict[str, Any] + + +@dataclass +class Snapshot: + """Point-in-time capture of agent state as a versioned JSON-compatible object.""" + + scope: Scope + schema_version: str + data: dict[str, Any] + app_data: dict[str, Any] + created_at: str = field(default="") # ISO 8601 UTC; auto-filled if empty + + def __post_init__(self) -> None: + if not self.created_at: + self.created_at = _utc_now_iso() + + def validate(self) -> None: + """Validate that this snapshot can be loaded by the current SDK version. + + Raises: + SnapshotException: If schema_version is not "1.0" or scope is invalid. + """ + if self.schema_version != SNAPSHOT_SCHEMA_VERSION: + raise SnapshotException( + f"Unsupported snapshot schema version: {self.schema_version!r}. " + f"Current version: {SNAPSHOT_SCHEMA_VERSION}" + ) + if self.scope not in VALID_SCOPES: + raise SnapshotException(f"Invalid snapshot scope: {self.scope!r}. Valid scopes: {sorted(VALID_SCOPES)}") + + def to_dict(self) -> dict[str, Any]: + """Serialize to a plain JSON-compatible dict.""" + return { + "scope": self.scope, + "schema_version": self.schema_version, + "created_at": self.created_at, + "data": self.data, + "app_data": self.app_data, + } + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> Snapshot: + """Reconstruct a Snapshot from a dict produced by to_dict(). + + Raises: + SnapshotException: If schema_version is not "1.0". + """ + snapshot = cls( + scope=d.get("scope", "agent"), + schema_version=d.get("schema_version", ""), + created_at=d["created_at"], + data=d["data"], + app_data=d.get("app_data", {}), + ) + snapshot.validate() + return snapshot + + +def resolve_snapshot_fields( + *, + preset: SnapshotPreset | None = None, + include: list[SnapshotField] | None = None, + exclude: list[SnapshotField] | None = None, +) -> set[SnapshotField]: + """Resolve the set of fields to capture based on options. + + Applies: preset → include → exclude (in that order). + + Raises: + SnapshotException: If any field name is invalid or the resolved set is empty. + """ + valid = set(ALL_SNAPSHOT_FIELDS) + + # Validate include/exclude field names + for f in include or []: + if f not in valid: + raise SnapshotException(f"Invalid snapshot field: {f!r}. Valid fields: {sorted(valid)}") + for f in exclude or []: + if f not in valid: + raise SnapshotException(f"Invalid snapshot field: {f!r}. Valid fields: {sorted(valid)}") + + # Step 1: start with preset + if preset is not None: + fields: set[SnapshotField] = set(SNAPSHOT_PRESETS[preset]) + else: + fields = set() + + # Step 2: union with include + if include: + fields |= set(include) + + # Step 3: subtract exclude + if exclude: + fields -= set(exclude) + + if not fields: + raise SnapshotException( + "No snapshot fields resolved. Provide a preset or at least one field in 'include'. " + "Note: passing only 'exclude' without a preset or 'include' always results in an empty set." + ) + + return fields + + +def _utc_now_iso() -> str: + """Return the current UTC time as an ISO 8601 string ending in 'Z'.""" + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") diff --git a/src/strands/types/a2a.py b/src/strands/types/a2a.py new file mode 100644 index 000000000..2ca444cb0 --- /dev/null +++ b/src/strands/types/a2a.py @@ -0,0 +1,38 @@ +"""Additional A2A types.""" + +from typing import Any, TypeAlias + +from a2a.types import Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent + +from ._events import TypedEvent + +A2AResponse: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message | Any + + +class A2AStreamEvent(TypedEvent): + """Event emitted for every update received from the remote A2A server. + + This event wraps all A2A response types during streaming, including: + - Partial task updates (TaskArtifactUpdateEvent) + - Status updates (TaskStatusUpdateEvent) + - Complete messages (Message) + - Final task completions + + The event is emitted for EVERY update from the server, regardless of whether + it represents a complete or partial response. When streaming completes, an + AgentResultEvent containing the final AgentResult is also emitted after all + A2AStreamEvents. + """ + + def __init__(self, a2a_event: A2AResponse) -> None: + """Initialize with A2A event. + + Args: + a2a_event: The original A2A event (Task tuple or Message) + """ + super().__init__( + { + "type": "a2a_stream", + "event": a2a_event, # Nest A2A event to avoid field conflicts + } + ) diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py index aa69149a6..cda01f8aa 100644 --- a/src/strands/types/agent.py +++ b/src/strands/types/agent.py @@ -3,9 +3,26 @@ This module defines the types used for an Agent. """ +from enum import Enum from typing import TypeAlias from .content import ContentBlock, Messages from .interrupt import InterruptResponseContent AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None + + +class ConcurrentInvocationMode(str, Enum): + """Mode controlling concurrent invocation behavior. + + Values: + THROW: Raises ConcurrencyException if concurrent invocation is attempted (default). + UNSAFE_REENTRANT: Allows concurrent invocations without locking. + + Warning: + The ``UNSAFE_REENTRANT`` mode makes no guarantees about resulting behavior and is + provided only for advanced use cases where the caller understands the risks. + """ + + THROW = "throw" + UNSAFE_REENTRANT = "unsafe_reentrant" diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index b0e28f655..2b3714ce1 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -3,7 +3,7 @@ These types are modeled after the Bedrock API. """ -from typing import List, Union +from typing import Literal from typing_extensions import TypedDict @@ -77,8 +77,56 @@ class DocumentPageLocation(TypedDict, total=False): end: int -# Union type for citation locations -CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] +class SearchResultLocation(TypedDict, total=False): + """Specifies a search result location within the content array. + + Provides positioning information for cited content using search result + index and block positions. + + Attributes: + searchResultIndex: The index of the search result content block where + the cited content is found. Minimum value of 0. + start: The starting position in the content array where the cited + content begins. Minimum value of 0. + end: The ending position in the content array where the cited + content ends. Minimum value of 0. + """ + + searchResultIndex: int + start: int + end: int + + +class WebLocation(TypedDict, total=False): + """Provides the URL and domain information for a cited website. + + Contains information about the website that was cited when performing + a web search. + + Attributes: + url: The URL that was cited when performing a web search. + domain: The domain that was cited when performing a web search. + """ + + url: str + domain: str + + +# Tagged union type aliases following the ToolChoice pattern +DocumentCharLocationDict = dict[Literal["documentChar"], DocumentCharLocation] +DocumentPageLocationDict = dict[Literal["documentPage"], DocumentPageLocation] +DocumentChunkLocationDict = dict[Literal["documentChunk"], DocumentChunkLocation] +SearchResultLocationDict = dict[Literal["searchResultLocation"], SearchResultLocation] +WebLocationDict = dict[Literal["web"], WebLocation] + +# Union type for citation locations - tagged union format matching AWS Bedrock API +CitationLocation = ( + DocumentCharLocationDict + | DocumentPageLocationDict + | DocumentChunkLocationDict + | SearchResultLocationDict + | WebLocationDict +) class CitationSourceContent(TypedDict, total=False): @@ -130,7 +178,7 @@ class Citation(TypedDict, total=False): """ location: CitationLocation - sourceContent: List[CitationSourceContent] + sourceContent: list[CitationSourceContent] title: str @@ -148,5 +196,5 @@ class CitationsContentBlock(TypedDict, total=False): citations. """ - citations: List[Citation] - content: List[CitationGeneratedContent] + citations: list[Citation] + content: list[CitationGeneratedContent] diff --git a/src/strands/types/collections.py b/src/strands/types/collections.py index df857ace0..28b4a1891 100644 --- a/src/strands/types/collections.py +++ b/src/strands/types/collections.py @@ -1,6 +1,6 @@ """Generic collection types for the Strands SDK.""" -from typing import Generic, List, Optional, TypeVar +from typing import Generic, TypeVar T = TypeVar("T") @@ -12,7 +12,7 @@ class PaginatedList(list, Generic[T]): so existing code that expects List[T] will continue to work. """ - def __init__(self, data: List[T], token: Optional[str] = None): + def __init__(self, data: list[T], token: str | None = None): """Initialize a PaginatedList with data and an optional pagination token. Args: diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 4d0bbe412..5f9cc1460 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -6,11 +6,12 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Any, Literal -from typing_extensions import TypedDict +from typing_extensions import NotRequired, TypedDict from .citations import CitationsContentBlock +from .event_loop import Metrics, Usage from .media import DocumentContent, ImageContent, VideoContent from .tools import ToolResult, ToolUse @@ -23,7 +24,7 @@ class GuardContentText(TypedDict): text: The input text details to be evaluated by the guardrail. """ - qualifiers: List[Literal["grounding_source", "query", "guard_content"]] + qualifiers: list[Literal["grounding_source", "query", "guard_content"]] text: str @@ -45,7 +46,7 @@ class ReasoningTextBlock(TypedDict, total=False): text: The reasoning that the model used to return the output. """ - signature: Optional[str] + signature: str | None text: str @@ -66,9 +67,12 @@ class CachePoint(TypedDict): Attributes: type: The type of cache point, typically "default". + ttl: Optional cache TTL duration (e.g. "5m", "1h"). Supported by providers + that accept Anthropic-compatible cache_control fields. """ type: str + ttl: NotRequired[str] class ContentBlock(TypedDict, total=False): @@ -120,7 +124,7 @@ class DeltaContent(TypedDict, total=False): """ text: str - toolUse: Dict[Literal["input"], str] + toolUse: dict[Literal["input"], str] class ContentBlockStartToolUse(TypedDict): @@ -129,10 +133,12 @@ class ContentBlockStartToolUse(TypedDict): Attributes: name: The name of the tool that the model is requesting to use. toolUseId: The ID for the tool request. + reasoningSignature: Token that ties the model's reasoning to this tool call. """ name: str toolUseId: str + reasoningSignature: NotRequired[str] class ContentBlockStart(TypedDict, total=False): @@ -142,7 +148,7 @@ class ContentBlockStart(TypedDict, total=False): toolUse: Information about a tool that the model is requesting to use. """ - toolUse: Optional[ContentBlockStartToolUse] + toolUse: ContentBlockStartToolUse | None class ContentBlockDelta(TypedDict): @@ -175,17 +181,44 @@ class ContentBlockStop(TypedDict): """ +class MessageMetadata(TypedDict, total=False): + """Optional metadata attached to a message. + + Not sent to model providers — explicitly stripped before model calls. + Persisted alongside the message in session storage. + + Attributes: + usage: Token usage information from the model response. + metrics: Performance metrics from the model response. + custom: Arbitrary user/framework metadata (e.g. compression provenance). + """ + + usage: Usage + metrics: Metrics + custom: dict[str, Any] + + class Message(TypedDict): """A message in a conversation with the agent. Attributes: content: The message content. role: The role of the message sender. + metadata: Optional metadata, stripped before model calls. """ - content: List[ContentBlock] + content: list[ContentBlock] role: Role + metadata: NotRequired[MessageMetadata] -Messages = List[Message] +Messages = list[Message] """A list of messages representing a conversation.""" + + +def get_message_metadata(message: Message) -> MessageMetadata: + """Get metadata for a message, returning empty dict if not present. + + Individual fields (usage, metrics, custom) may not be present. Use .get() to safely access them. + """ + return message.get("metadata", {}) diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 2a7ad344e..73d4e2bc0 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -37,6 +37,8 @@ class Metrics(TypedDict, total=False): StopReason = Literal[ + "cancelled", + "checkpoint", "content_filtered", "end_turn", "guardrail_intervened", @@ -47,6 +49,8 @@ class Metrics(TypedDict, total=False): ] """Reason for the model ending its response generation. +- "cancelled": Agent execution was cancelled via agent.cancel() +- "checkpoint": Agent paused for durable checkpoint persistence - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index b9c5bc769..7ad49eb24 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -77,6 +77,22 @@ class SessionException(Exception): pass +class SnapshotException(Exception): + """Exception raised when snapshot operations fail (e.g., unsupported schema version).""" + + pass + + +class ProviderTokenCountError(Exception): + """Thrown when a model provider's native token counting API fails. + + This error is used as internal control flow within provider ``count_tokens()`` overrides. + When caught, the provider falls back to the base class heuristic estimation. + """ + + pass + + class ToolProviderException(Exception): """Exception raised when a tool provider fails to load or cleanup tools.""" @@ -94,3 +110,14 @@ def __init__(self, message: str): """ self.message = message super().__init__(message) + + +class ConcurrencyException(Exception): + """Exception raised when concurrent invocations are attempted on an agent instance. + + Agent instances maintain internal state that cannot be safely accessed concurrently. + This exception is raised when an invocation is attempted while another invocation + is already in progress on the same agent instance. + """ + + pass diff --git a/src/strands/types/guardrails.py b/src/strands/types/guardrails.py index c15ba1bea..70a7aedd5 100644 --- a/src/strands/types/guardrails.py +++ b/src/strands/types/guardrails.py @@ -5,7 +5,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -22,7 +22,7 @@ class GuardrailConfig(TypedDict, total=False): guardrailIdentifier: str guardrailVersion: str - streamProcessingMode: Optional[Literal["sync", "async"]] + streamProcessingMode: Literal["sync", "async"] | None trace: Literal["enabled", "disabled"] @@ -47,7 +47,7 @@ class TopicPolicy(TypedDict): topics: The topics in the assessment. """ - topics: List[Topic] + topics: list[Topic] class ContentFilter(TypedDict): @@ -71,7 +71,7 @@ class ContentPolicy(TypedDict): filters: List of content filters to apply. """ - filters: List[ContentFilter] + filters: list[ContentFilter] class CustomWord(TypedDict): @@ -108,8 +108,8 @@ class WordPolicy(TypedDict): managedWordLists: List of managed word lists to filter. """ - customWords: List[CustomWord] - managedWordLists: List[ManagedWord] + customWords: list[CustomWord] + managedWordLists: list[ManagedWord] class PIIEntity(TypedDict): @@ -182,8 +182,8 @@ class SensitiveInformationPolicy(TypedDict): regexes: The regex queries in the assessment. """ - piiEntities: List[PIIEntity] - regexes: List[Regex] + piiEntities: list[PIIEntity] + regexes: list[Regex] class ContextualGroundingFilter(TypedDict): @@ -209,7 +209,7 @@ class ContextualGroundingPolicy(TypedDict): filters: The filter details for the guardrails contextual grounding filter. """ - filters: List[ContextualGroundingFilter] + filters: list[ContextualGroundingFilter] class GuardrailAssessment(TypedDict): @@ -239,9 +239,9 @@ class GuardrailTrace(TypedDict): outputAssessments: Assessments of output content against guardrail policies, keyed by output identifier. """ - inputAssessment: Dict[str, GuardrailAssessment] - modelOutput: List[str] - outputAssessments: Dict[str, List[GuardrailAssessment]] + inputAssessment: dict[str, GuardrailAssessment] + modelOutput: list[str] + outputAssessments: dict[str, list[GuardrailAssessment]] class Trace(TypedDict): diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py index d67148c5a..f76689762 100644 --- a/src/strands/types/interrupt.py +++ b/src/strands/types/interrupt.py @@ -16,7 +16,7 @@ ``` Example: - ```Python + ```python from typing import Any from strands import Agent, tool diff --git a/src/strands/types/json_dict.py b/src/strands/types/json_dict.py index a8636ab10..dc6ae6565 100644 --- a/src/strands/types/json_dict.py +++ b/src/strands/types/json_dict.py @@ -15,6 +15,7 @@ class JSONSerializableDict: def __init__(self, initial_state: dict[str, Any] | None = None): """Initialize JSONSerializableDict.""" self._data: dict[str, Any] + self._version: int = 0 if initial_state: self._validate_json_serializable(initial_state) self._data = copy.deepcopy(initial_state) @@ -34,6 +35,7 @@ def set(self, key: str, value: Any) -> None: self._validate_key(key) self._validate_json_serializable(value) self._data[key] = copy.deepcopy(value) + self._version += 1 def get(self, key: str | None = None) -> Any: """Get a value or entire data. @@ -57,6 +59,19 @@ def delete(self, key: str) -> None: """ self._validate_key(key) self._data.pop(key, None) + self._version += 1 + + def _get_version(self) -> int: + """Get the current version number of the store. + + The version is incremented each time set() or delete() is called. + Consumers can compare versions to detect changes without requiring + explicit dirty flag clearing. + + Returns: + The current version number. + """ + return self._version def _validate_key(self, key: str) -> None: """Validate that a key is valid. diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 69cd60cf3..b1240dffb 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -5,9 +5,9 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal, Optional +from typing import Literal, TypeAlias -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict from .citations import CitationsConfig @@ -15,14 +15,50 @@ """Supported document formats.""" -class DocumentSource(TypedDict): +class Location(TypedDict, total=False): + """A location for a document. + + This type is a generic location for a document. Its usage is determined by the underlying model provider. + """ + + type: Required[str] + + +class S3Location(Location, total=False): + """A storage location in an Amazon S3 bucket. + + Used by Bedrock to reference media files stored in S3 instead of passing raw bytes. + + - Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html + + Attributes: + type: s3 + uri: An object URI starting with `s3://`. Required. + bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. Optional. + """ + + # mypy doesn't like overriding this field since its a subclass, but since its just a literal string, this is fine. + + type: Literal["s3"] # type: ignore[misc] + uri: Required[str] + bucketOwner: str + + +SourceLocation: TypeAlias = Location | S3Location + + +class DocumentSource(TypedDict, total=False): """Contains the content of a document. + Only one of `bytes` or `s3Location` should be specified. + Attributes: bytes: The binary content of the document. + location: Location of the document. """ bytes: bytes + location: SourceLocation class DocumentContent(TypedDict, total=False): @@ -37,22 +73,26 @@ class DocumentContent(TypedDict, total=False): format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] name: str source: DocumentSource - citations: Optional[CitationsConfig] - context: Optional[str] + citations: CitationsConfig | None + context: str | None ImageFormat = Literal["png", "jpeg", "gif", "webp"] """Supported image formats.""" -class ImageSource(TypedDict): +class ImageSource(TypedDict, total=False): """Contains the content of an image. + Only one of `bytes` or `s3Location` should be specified. + Attributes: bytes: The binary content of the image. + location: Location of the image. """ bytes: bytes + location: SourceLocation class ImageContent(TypedDict): @@ -71,14 +111,18 @@ class ImageContent(TypedDict): """Supported video formats.""" -class VideoSource(TypedDict): +class VideoSource(TypedDict, total=False): """Contains the content of a video. + Only one of `bytes` or `s3Location` should be specified. + Attributes: bytes: The binary content of the video. + location: Location of the video. """ bytes: bytes + location: SourceLocation class VideoContent(TypedDict): diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 5da3dcde8..294c518d7 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..interrupt import _InterruptState from .content import Message @@ -69,7 +69,7 @@ class SessionMessage: message: Message message_id: int - redact_message: Optional[Message] = None + redact_message: Message | None = None created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) @@ -134,6 +134,7 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent": state=agent.state.get(), _internal_state={ "interrupt_state": agent._interrupt_state.to_dict(), + "model_state": agent._model_state, }, ) @@ -175,6 +176,8 @@ def initialize_internal_state(self, agent: "Agent") -> None: """Initialize internal state of agent.""" if "interrupt_state" in self._internal_state: agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) + if "model_state" in self._internal_state: + agent._model_state = self._internal_state["model_state"] def initialize_bidi_internal_state(self, agent: "BidiAgent") -> None: """Initialize internal state of BidiAgent. diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index dcfd541a8..8ec2e8d7b 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -5,8 +5,6 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Optional, Union - from typing_extensions import TypedDict from .citations import CitationLocation @@ -34,7 +32,7 @@ class ContentBlockStartEvent(TypedDict, total=False): start: Information about the content block being started. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None start: ContentBlockStart @@ -102,9 +100,9 @@ class ReasoningContentBlockDelta(TypedDict, total=False): text: The reasoning that the model used to return the output. """ - redactedContent: Optional[bytes] - signature: Optional[str] - text: Optional[str] + redactedContent: bytes | None + signature: str | None + text: str | None class ContentBlockDelta(TypedDict, total=False): @@ -131,7 +129,7 @@ class ContentBlockDeltaEvent(TypedDict, total=False): delta: The incremental content update for the content block. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None delta: ContentBlockDelta @@ -143,7 +141,7 @@ class ContentBlockStopEvent(TypedDict, total=False): This is optional to accommodate different model providers. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None class MessageStopEvent(TypedDict, total=False): @@ -154,7 +152,7 @@ class MessageStopEvent(TypedDict, total=False): stopReason: The reason why the model stopped generating content. """ - additionalModelResponseFields: Optional[Union[dict, list, int, float, str, bool, None]] + additionalModelResponseFields: dict | list | int | float | str | bool | None | None stopReason: StopReason @@ -168,7 +166,7 @@ class MetadataEvent(TypedDict, total=False): """ metrics: Metrics - trace: Optional[Trace] + trace: Trace | None usage: Usage @@ -203,8 +201,8 @@ class RedactContentEvent(TypedDict, total=False): """ - redactUserContentMessage: Optional[str] - redactAssistantContentMessage: Optional[str] + redactUserContentMessage: str | None + redactAssistantContentMessage: str | None class StreamEvent(TypedDict, total=False): diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 8f4dba6b1..088c83bdb 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -7,8 +7,9 @@ import uuid from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass -from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union +from typing import Any, Literal, Protocol from typing_extensions import NotRequired, TypedDict @@ -57,11 +58,13 @@ class ToolUse(TypedDict): Can be any JSON-serializable type. name: The name of the tool to invoke. toolUseId: A unique identifier for this specific tool use request. + reasoningSignature: Token that ties the model's reasoning to this tool call. """ input: Any name: str toolUseId: str + reasoningSignature: NotRequired[str] class ToolResultContent(TypedDict, total=False): @@ -164,11 +167,7 @@ def _interrupt_id(self, name: str) -> str: ToolChoiceAnyDict = dict[Literal["any"], ToolChoiceAny] ToolChoiceToolDict = dict[Literal["tool"], ToolChoiceTool] -ToolChoice = Union[ - ToolChoiceAutoDict, - ToolChoiceAnyDict, - ToolChoiceToolDict, -] +ToolChoice = ToolChoiceAutoDict | ToolChoiceAnyDict | ToolChoiceToolDict """ Configuration for how the model should choose tools. @@ -201,12 +200,7 @@ class ToolFunc(Protocol): __name__: str - def __call__( - self, *args: Any, **kwargs: Any - ) -> Union[ - ToolResult, - Awaitable[ToolResult], - ]: + def __call__(self, *args: Any, **kwargs: Any) -> ToolResult | Awaitable[ToolResult]: """Function signature for Python decorated and module based tools. Returns: diff --git a/src/strands/types/traces.py b/src/strands/types/traces.py index af6188adb..c5c3aaa64 100644 --- a/src/strands/types/traces.py +++ b/src/strands/types/traces.py @@ -1,20 +1,20 @@ """Tracing type definitions for the SDK.""" -from typing import List, Mapping, Optional, Sequence, Union +from collections.abc import Mapping, Sequence -AttributeValue = Union[ - str, - bool, - float, - int, - List[str], - List[bool], - List[float], - List[int], - Sequence[str], - Sequence[bool], - Sequence[int], - Sequence[float], -] +AttributeValue = ( + str + | bool + | float + | int + | list[str] + | list[bool] + | list[float] + | list[int] + | Sequence[str] + | Sequence[bool] + | Sequence[int] + | Sequence[float] +) -Attributes = Optional[Mapping[str, AttributeValue]] +Attributes = Mapping[str, AttributeValue] | None diff --git a/src/strands/vended_plugins/__init__.py b/src/strands/vended_plugins/__init__.py new file mode 100644 index 000000000..78e8047df --- /dev/null +++ b/src/strands/vended_plugins/__init__.py @@ -0,0 +1 @@ +"""Vended plugins for Strands agents.""" diff --git a/src/strands/vended_plugins/context_offloader/__init__.py b/src/strands/vended_plugins/context_offloader/__init__.py new file mode 100644 index 000000000..01ca6f1fc --- /dev/null +++ b/src/strands/vended_plugins/context_offloader/__init__.py @@ -0,0 +1,46 @@ +"""Context offloader plugin for Strands Agents. + +This module provides the ContextOffloader plugin which intercepts oversized +tool results, persists each content block to a storage backend, and replaces +the in-context result with a truncated preview and per-block references. + +Example Usage: + ```python + from strands import Agent + from strands.vended_plugins.context_offloader import ( + ContextOffloader, + InMemoryStorage, + FileStorage, + ) + + # In-memory storage + agent = Agent(plugins=[ + ContextOffloader(storage=InMemoryStorage()) + ]) + + # File storage with custom thresholds + agent = Agent(plugins=[ + ContextOffloader( + storage=FileStorage("./artifacts"), + max_result_tokens=5_000, + preview_tokens=2_000, + ) + ]) + ``` +""" + +from .plugin import ContextOffloader +from .storage import ( + FileStorage, + InMemoryStorage, + S3Storage, + Storage, +) + +__all__ = [ + "ContextOffloader", + "FileStorage", + "InMemoryStorage", + "S3Storage", + "Storage", +] diff --git a/src/strands/vended_plugins/context_offloader/plugin.py b/src/strands/vended_plugins/context_offloader/plugin.py new file mode 100644 index 000000000..6cb98b31b --- /dev/null +++ b/src/strands/vended_plugins/context_offloader/plugin.py @@ -0,0 +1,328 @@ +"""ContextOffloader plugin for managing large tool outputs. + +This module provides the ContextOffloader plugin that intercepts oversized +tool results, persists each content block to a storage backend, and replaces +the in-context result with a truncated preview and per-block references. + +Example: + ```python + from strands import Agent + from strands.vended_plugins.context_offloader import ( + ContextOffloader, + InMemoryStorage, + FileStorage, + ) + + # In-memory storage + agent = Agent(plugins=[ + ContextOffloader(storage=InMemoryStorage()) + ]) + + # File storage with custom thresholds and retrieval tool enabled + agent = Agent(plugins=[ + ContextOffloader( + storage=FileStorage("./artifacts"), + max_result_tokens=5_000, + preview_tokens=2_000, + include_retrieval_tool=True, + ) + ]) + ``` +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING + +from ...hooks.events import AfterToolCallEvent +from ...plugins import Plugin, hook +from ...tools.decorator import tool +from ...types.content import Message +from ...types.tools import ToolContext, ToolResult, ToolResultContent +from .storage import Storage + +if TYPE_CHECKING: + from ...agent.agent import Agent + +logger = logging.getLogger(__name__) + +_DEFAULT_MAX_RESULT_TOKENS = 2_500 +"""Default token threshold above which tool results are offloaded.""" + +_DEFAULT_PREVIEW_TOKENS = 1_000 +"""Default number of tokens to keep as a preview in context.""" + +_CHARS_PER_TOKEN = 4 +"""Approximate characters per token, fallback for preview slicing without tiktoken.""" + + +class ContextOffloader(Plugin): + """Plugin that offloads oversized tool results to reduce context consumption. + + When a tool result exceeds the configured token threshold, this plugin + stores each content block individually to a storage backend and replaces + the in-context result with a truncated text preview plus per-block references. + + Token estimation uses the agent's model ``count_tokens`` method, which + leverages tiktoken when available and falls back to character-based heuristics. + + Content type handling: + + - **Text**: stored as ``text/plain``, replaced with a preview + - **JSON**: stored as ``application/json``, replaced with a preview + - **Image**: stored in its native format (e.g., ``image/png``), replaced with a + placeholder showing format and size + - **Document**: stored in its native format (e.g., ``application/pdf``), replaced + with a placeholder showing format, name, and size + - **Unknown types**: passed through unchanged + + This operates proactively at tool execution time via ``AfterToolCallEvent``, + before the result enters the conversation — unlike ``SlidingWindowConversationManager`` + which truncates reactively after context overflow. + + Args: + storage: Backend for storing offloaded content (required). + max_result_tokens: Offload results whose estimated token count exceeds this threshold. + preview_tokens: Number of tokens to keep as a text preview in context. + include_retrieval_tool: Whether to register the ``retrieve_offloaded_content`` tool. + Defaults to True. + + Example: + ```python + from strands import Agent + from strands.vended_plugins.context_offloader import ContextOffloader, InMemoryStorage + + agent = Agent(plugins=[ + ContextOffloader(storage=InMemoryStorage()) + ]) + ``` + """ + + name = "context_offloader" + + def __init__( + self, + storage: Storage, + max_result_tokens: int = _DEFAULT_MAX_RESULT_TOKENS, + preview_tokens: int = _DEFAULT_PREVIEW_TOKENS, + *, + include_retrieval_tool: bool = True, + ) -> None: + """Initialize the ContextOffloader plugin. + + Args: + storage: Backend for storing offloaded content. + max_result_tokens: Offload results whose estimated token count exceeds this + threshold. Defaults to ``_DEFAULT_MAX_RESULT_TOKENS`` (2,500). + preview_tokens: Number of tokens to keep as a text preview in context. + Uses tiktoken for exact slicing when available, falls back to + chars/4 heuristic. Defaults to ``_DEFAULT_PREVIEW_TOKENS`` (1,000). + include_retrieval_tool: Whether to register the ``retrieve_offloaded_content`` + tool so the agent can fetch offloaded content. Defaults to True. + + Raises: + ValueError: If max_result_tokens is not positive, preview_tokens is negative, + or preview_tokens >= max_result_tokens. + """ + if max_result_tokens <= 0: + raise ValueError("max_result_tokens must be positive") + if preview_tokens < 0: + raise ValueError("preview_tokens must be non-negative") + if preview_tokens >= max_result_tokens: + raise ValueError("preview_tokens must be less than max_result_tokens") + + self._storage = storage + self._max_result_tokens = max_result_tokens + self._preview_tokens = preview_tokens + self._include_retrieval_tool = include_retrieval_tool + super().__init__() + + def init_agent(self, agent: Agent) -> None: + """Conditionally register the retrieval tool.""" + if not self._include_retrieval_tool: + # Remove the auto-discovered retrieval tool + self._tools = [t for t in self._tools if t.tool_name != "retrieve_offloaded_content"] + + @tool(context=True) + def retrieve_offloaded_content( + self, + reference: str, + tool_context: ToolContext, + ) -> dict | str: + """Retrieve offloaded content by reference. + + Use this tool when you see a placeholder with a reference (ref: ...) + and need the full content. Only use this as a fallback if the data + cannot be accessed using your existing tools. + + Args: + reference: The reference string from the offload placeholder. + tool_context: Injected by the framework. Not user-facing. + """ + try: + content_bytes, content_type = self._storage.retrieve(reference) + except KeyError: + return f"Error: reference not found: {reference}" + + if content_type.startswith("text/"): + return content_bytes.decode("utf-8") + + if content_type == "application/json": + return {"status": "success", "content": [{"json": json.loads(content_bytes)}]} + + if content_type.startswith("image/"): + img_format = content_type.split("/")[-1] + return { + "status": "success", + "content": [{"image": {"format": img_format, "source": {"bytes": content_bytes}}}], + } + + if content_type.startswith("application/"): + doc_format = content_type.split("/")[-1] + doc_block = {"format": doc_format, "name": reference, "source": {"bytes": content_bytes}} + return {"status": "success", "content": [{"document": doc_block}]} + + return content_bytes.decode("utf-8", errors="replace") + + @hook + async def _handle_tool_result(self, event: AfterToolCallEvent) -> None: + """Intercept oversized tool results, offload per-block, and replace with preview.""" + if event.cancel_message is not None: + return + + if self._include_retrieval_tool and event.tool_use.get("name") == self.retrieve_offloaded_content.tool_name: + return + + result = event.result + content = result["content"] + tool_use_id = event.tool_use["toolUseId"] + + # Estimate token count by wrapping the tool result as a message for count_tokens + tool_result_message: Message = {"role": "user", "content": [{"toolResult": result}]} + token_count = await event.agent.model.count_tokens([tool_result_message]) + + if token_count <= self._max_result_tokens: + return + + # Build text preview from text+JSON blocks. + # Empty text blocks are intentionally excluded — they add no content value. + text_preview_parts: list[str] = [] + for block in content: + if block.get("text"): + text_preview_parts.append(block["text"]) + elif "json" in block: + text_preview_parts.append(json.dumps(block["json"], indent=2)) + + full_text = "\n".join(text_preview_parts) if text_preview_parts else "" + + # Store each content block individually + references: list[tuple[str, str, str]] = [] # (ref, content_type, description) + try: + for i, block in enumerate(content): + key = f"{tool_use_id}_{i}" + if block.get("text"): + ref = self._storage.store(key, block["text"].encode("utf-8"), "text/plain") + references.append((ref, "text/plain", f"text, {len(block['text']):,} chars")) + elif "json" in block: + json_bytes = json.dumps(block["json"], indent=2).encode("utf-8") + ref = self._storage.store(key, json_bytes, "application/json") + references.append((ref, "application/json", f"json, {len(json_bytes):,} bytes")) + elif "image" in block: + image = block["image"] + img_format = image.get("format", "unknown") + img_bytes = image.get("source", {}).get("bytes", b"") + if img_bytes: + ref = self._storage.store(key, img_bytes, f"image/{img_format}") + references.append((ref, f"image/{img_format}", f"image/{img_format}, {len(img_bytes):,} bytes")) + else: + references.append(("", f"image/{img_format}", f"image/{img_format}, 0 bytes")) + elif "document" in block: + doc = block["document"] + doc_format = doc.get("format", "unknown") + doc_name = doc.get("name", "unknown") + doc_bytes = doc.get("source", {}).get("bytes", b"") + if doc_bytes: + ref = self._storage.store(key, doc_bytes, f"application/{doc_format}") + references.append((ref, f"application/{doc_format}", f"{doc_name}, {len(doc_bytes):,} bytes")) + else: + references.append(("", f"application/{doc_format}", f"{doc_name}, 0 bytes")) + except Exception: + logger.warning( + "tool_use_id=<%s> | failed to offload tool result, keeping original", + tool_use_id, + exc_info=True, + ) + return + + logger.debug( + "tool_use_id=<%s>, blocks=<%d>, tokens=<%d> | tool result offloaded", + tool_use_id, + len(references), + token_count, + ) + + # Build preview text — use tiktoken for exact slicing when available + preview = self._slice_preview(full_text) if full_text else "" + ref_lines = "\n".join(f" {ref} ({desc})" for ref, _, desc in references if ref) + + guidance = ( + "Tool result was offloaded to external storage due to size.\n" + "Use the preview below to answer if possible.\n" + "Use your available tools to selectively access the data you need." + ) + if self._include_retrieval_tool: + guidance += "\nYou can also use retrieve_offloaded_content with a reference to get the full content." + + preview_text = ( + f"[Offloaded: {len(content)} blocks, ~{token_count:,} tokens]\n" + f"{guidance}\n\n" + f"{preview}\n\n" + f"[Stored references:]\n{ref_lines}" + ) + + # Build new content with preview + placeholders for non-text blocks + new_content: list[ToolResultContent] = [ToolResultContent(text=preview_text)] + for i, block in enumerate(content): + ref = references[i][0] if i < len(references) else "" + if "text" in block or "json" in block: + continue + elif "image" in block: + image = block["image"] + img_format = image.get("format", "unknown") + img_bytes = image.get("source", {}).get("bytes", b"") + placeholder = f"[image: {img_format}, {len(img_bytes) if img_bytes else 0} bytes" + if ref: + placeholder += f" | ref: {ref}" + placeholder += "]" + new_content.append(ToolResultContent(text=placeholder)) + elif "document" in block: + doc = block["document"] + doc_format = doc.get("format", "unknown") + doc_name = doc.get("name", "unknown") + doc_bytes = doc.get("source", {}).get("bytes", b"") + placeholder = f"[document: {doc_format}, {doc_name}, {len(doc_bytes) if doc_bytes else 0} bytes" + if ref: + placeholder += f" | ref: {ref}" + placeholder += "]" + new_content.append(ToolResultContent(text=placeholder)) + else: + new_content.append(block) + + event.result = ToolResult( + toolUseId=result["toolUseId"], + status=result["status"], + content=new_content, + ) + + def _slice_preview(self, text: str) -> str: + """Slice text to approximately preview_tokens using character-based estimation. + + Args: + text: The full text to slice. + + Returns: + The preview text. + """ + return text[: self._preview_tokens * _CHARS_PER_TOKEN] diff --git a/src/strands/vended_plugins/context_offloader/storage.py b/src/strands/vended_plugins/context_offloader/storage.py new file mode 100644 index 000000000..645d2cb09 --- /dev/null +++ b/src/strands/vended_plugins/context_offloader/storage.py @@ -0,0 +1,394 @@ +"""Storage backends for offloaded tool result content. + +This module defines the Storage protocol and provides three built-in +implementations: file-based, in-memory, and S3 storage. Each content block +from a tool result is stored individually with its content type preserved. + +Example: + ```python + from strands.vended_plugins.context_offloader import ( + FileStorage, + InMemoryStorage, + S3Storage, + ) + + # File-based storage + storage = FileStorage(artifact_dir="./artifacts") + ref = storage.store("tool_123_0", b"large output content...", "text/plain") + content, content_type = storage.retrieve(ref) + + # In-memory storage (useful for testing and serverless) + storage = InMemoryStorage() + + # S3 storage + storage = S3Storage(bucket="my-bucket", prefix="artifacts/") + ``` +""" + +import json +import re +import threading +import time +from pathlib import Path +from typing import Any, Protocol, runtime_checkable + +import boto3 +from botocore.config import Config as BotocoreConfig +from botocore.exceptions import ClientError + + +def _sanitize_id(raw_id: str) -> str: + """Sanitize an ID for safe use in filenames and object keys. + + Replaces path separators, parent directory references, and other + unsafe characters with underscores. + + Args: + raw_id: The raw ID string. + + Returns: + A sanitized string safe for use in filenames. + """ + sanitized = raw_id.replace("..", "_").replace("/", "_").replace("\\", "_") + sanitized = re.sub(r"[^\w\-.]", "_", sanitized) + return sanitized + + +@runtime_checkable +class Storage(Protocol): + """Backend for storing and retrieving offloaded content blocks. + + Each content block from a tool result is stored individually with its + content type preserved. The SDK ships three built-in implementations: + ``InMemoryStorage``, ``FileStorage``, and ``S3Storage``. Implement this + protocol to create custom storage backends (e.g., Redis, DynamoDB). + + Lifecycle: + This protocol intentionally does not include eviction or deletion methods. + Stored content accumulates for the lifetime of the storage instance. For + long-running agents, create a new storage instance per session or use a + backend with built-in lifecycle management (e.g., S3 lifecycle policies). + """ + + def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: + """Store content and return a reference identifier. + + Args: + key: A unique key for this content block. + content: The raw content bytes to store. + content_type: MIME type of the content (e.g., "text/plain", + "application/json", "image/png", "application/pdf"). + + Returns: + A reference string that can be used to retrieve the content later. + """ + ... + + def retrieve(self, reference: str) -> tuple[bytes, str]: + """Retrieve stored content by reference. + + Args: + reference: The reference returned by a previous store() call. + + Returns: + A tuple of (content bytes, content type). + + Raises: + KeyError: If the reference is not found. + """ + ... + + +class FileStorage: + """Store offloaded content as files on disk. + + Files are written to the configured artifact directory with unique names. + File extensions are derived from the content type. A ``.metadata.json`` + sidecar file tracks content types so they survive process restarts. + + Args: + artifact_dir: Directory path where artifact files will be stored. + """ + + _METADATA_FILE = ".metadata.json" + + def __init__(self, artifact_dir: str = "./artifacts") -> None: + """Initialize file-based storage. + + Args: + artifact_dir: Directory path where artifact files will be stored. + """ + self._artifact_dir = Path(artifact_dir) + self._counter: int = 0 + self._lock = threading.Lock() + self._content_types: dict[str, str] = self._load_metadata() + + @staticmethod + def _extension_for(content_type: str) -> str: + """Return a file extension for the given content type.""" + if content_type == "text/plain": + return ".txt" + return f".{content_type.split('/')[-1]}" + + def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: + """Store content as a file and return the path as reference. + + The returned path preserves the form of ``artifact_dir`` passed to + the constructor: a relative ``artifact_dir`` yields a relative + reference, an absolute one yields an absolute reference. + + Args: + key: A unique key for this content block. + content: The raw content bytes to store. + content_type: MIME type of the content. + + Returns: + The file path (e.g., ``./artifacts/1234_1_key.txt``). + """ + self._artifact_dir.mkdir(parents=True, exist_ok=True) + + sanitized_key = _sanitize_id(key) + timestamp_ms = int(time.time() * 1000) + ext = self._extension_for(content_type) + with self._lock: + self._counter += 1 + counter = self._counter + filename = f"{timestamp_ms}_{counter}_{sanitized_key}{ext}" + self._content_types[filename] = content_type + self._save_metadata() + + file_path = self._artifact_dir / filename + file_path.write_bytes(content) + + return str(file_path) + + def retrieve(self, reference: str) -> tuple[bytes, str]: + """Retrieve content from a stored file. + + Accepts both full paths (as returned by ``store()``) and bare + filenames for backward compatibility. + + Args: + reference: The file path or filename returned by store(). + + Returns: + A tuple of (content bytes, content type). + + Raises: + KeyError: If the file does not exist. + """ + resolved_dir = self._artifact_dir.resolve() + ref_path = Path(reference) + file_path = ref_path.resolve() if len(ref_path.parts) > 1 else (self._artifact_dir / reference).resolve() + if not file_path.is_relative_to(resolved_dir): + file_path = (self._artifact_dir / reference).resolve() + if not file_path.is_relative_to(resolved_dir): + raise KeyError(f"Reference not found: {reference}") + if not file_path.is_file(): + raise KeyError(f"Reference not found: {reference}") + filename = file_path.name + content_type = self._content_types.get(filename, "application/octet-stream") + return file_path.read_bytes(), content_type + + def _load_metadata(self) -> dict[str, str]: + """Load content type metadata from the sidecar file.""" + metadata_path = self._artifact_dir / self._METADATA_FILE + if metadata_path.is_file(): + try: + result: dict[str, str] = json.loads(metadata_path.read_text(encoding="utf-8")) + return result + except (json.JSONDecodeError, OSError): + return {} + return {} + + def _save_metadata(self) -> None: + """Save content type metadata to the sidecar file.""" + metadata_path = self._artifact_dir / self._METADATA_FILE + metadata_path.write_text(json.dumps(self._content_types), encoding="utf-8") + + +class InMemoryStorage: + """Store offloaded content in memory. + + Useful for testing and serverless environments where disk access + is not available or not desired. Thread-safe. + + Note: + Content accumulates for the lifetime of this instance. For long-running + agents, consider creating a new instance per session or switching to + ``FileStorage`` or ``S3Storage`` for persistent storage with external + lifecycle management. + """ + + def __init__(self) -> None: + """Initialize in-memory storage.""" + self._store: dict[str, tuple[bytes, str]] = {} + self._counter: int = 0 + self._lock = threading.Lock() + + def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: + """Store content in memory and return a reference. + + Args: + key: A unique key for this content block. + content: The raw content bytes to store. + content_type: MIME type of the content. + + Returns: + A unique reference string. + """ + with self._lock: + self._counter += 1 + reference = f"mem_{self._counter}_{key}" + self._store[reference] = (content, content_type) + return reference + + def retrieve(self, reference: str) -> tuple[bytes, str]: + """Retrieve content from memory. + + Args: + reference: The reference returned by store(). + + Returns: + A tuple of (content bytes, content type). + + Raises: + KeyError: If the reference is not found. + """ + with self._lock: + if reference not in self._store: + raise KeyError(f"Reference not found: {reference}") + return self._store[reference] + + def clear(self) -> None: + """Remove all stored content. + + Call this to free memory when offloaded results are no longer needed, + e.g., between sessions or after an invocation completes. + """ + with self._lock: + self._store.clear() + + +class S3Storage: + """Store offloaded content in Amazon S3. + + Objects are stored with unique keys under the configured prefix. + Content type is preserved as S3 object metadata. + + Args: + bucket: S3 bucket name. + prefix: S3 key prefix for organizing stored artifacts. + boto_session: Optional boto3 session. If not provided, a new session + is created using the given region_name. + boto_client_config: Optional botocore client configuration. + region_name: AWS region. Used only when boto_session is not provided. + + Example: + ```python + from strands.vended_plugins.context_offloader import S3Storage + + storage = S3Storage( + bucket="my-agent-artifacts", + prefix="tool-results/", + ) + ``` + """ + + def __init__( + self, + bucket: str, + prefix: str = "", + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, + ) -> None: + """Initialize S3-based storage. + + Args: + bucket: S3 bucket name. + prefix: S3 key prefix for organizing stored artifacts. + boto_session: Optional boto3 session. If not provided, a new session + is created using the given region_name. + boto_client_config: Optional botocore client configuration. + region_name: AWS region. Used only when boto_session is not provided. + """ + self._bucket = bucket + self._prefix = prefix.strip("/") + if self._prefix: + self._prefix += "/" + + session = boto_session or boto3.Session(region_name=region_name) + + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + new_user_agent = f"{existing_user_agent} strands-agents" if existing_user_agent else "strands-agents" + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + + self._client: Any = session.client(service_name="s3", config=client_config) + self._counter: int = 0 + self._lock = threading.Lock() + + def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: + """Store content as an S3 object and return an ``s3://`` URI as reference. + + Args: + key: A unique key for this content block. + content: The raw content bytes to store. + content_type: MIME type of the content. + + Returns: + An S3 URI (e.g., ``s3://bucket/prefix/1234_1_key``). + + Raises: + botocore.exceptions.ClientError: If the S3 operation fails (e.g., bucket + does not exist, permission denied). + """ + sanitized_key = _sanitize_id(key) + timestamp_ms = int(time.time() * 1000) + with self._lock: + self._counter += 1 + counter = self._counter + s3_key = f"{self._prefix}{timestamp_ms}_{counter}_{sanitized_key}" + + self._client.put_object( + Bucket=self._bucket, + Key=s3_key, + Body=content, + ContentType=content_type, + ) + + return f"s3://{self._bucket}/{s3_key}" + + def retrieve(self, reference: str) -> tuple[bytes, str]: + """Retrieve content from an S3 object. + + Accepts both ``s3://`` URIs (as returned by ``store()``) and raw + S3 keys for backward compatibility. + + Args: + reference: The S3 URI or object key returned by store(). + + Returns: + A tuple of (content bytes, content type). + + Raises: + KeyError: If the object does not exist. + """ + s3_key = reference + if reference.startswith("s3://"): + expected_prefix = f"s3://{self._bucket}/" + if not reference.startswith(expected_prefix): + raise KeyError(f"Reference not found: {reference}") + s3_key = reference[len(expected_prefix) :] + try: + response = self._client.get_object(Bucket=self._bucket, Key=s3_key) + content: bytes = response["Body"].read() + content_type: str = response.get("ContentType", "application/octet-stream") + return content, content_type + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + raise KeyError(f"Reference not found: {reference}") from e + raise diff --git a/src/strands/vended_plugins/skills/__init__.py b/src/strands/vended_plugins/skills/__init__.py new file mode 100644 index 000000000..abd6063b9 --- /dev/null +++ b/src/strands/vended_plugins/skills/__init__.py @@ -0,0 +1,31 @@ +"""AgentSkills.io integration for Strands Agents. + +This module provides the AgentSkills plugin for integrating AgentSkills.io skills +into Strands agents. Skills enable progressive disclosure of instructions: +metadata is injected into the system prompt upfront, and full instructions +are loaded on demand via a tool. + +Example Usage: + ```python + from strands import Agent + from strands.vended_plugins.skills import Skill, AgentSkills + + # Load from filesystem via classmethods + skill = Skill.from_file("./skills/pdf-processing") + skills = Skill.from_directory("./skills/") + + # Or let the plugin resolve paths automatically + plugin = AgentSkills(skills=["./skills/pdf-processing"]) + agent = Agent(plugins=[plugin]) + ``` +""" + +from .agent_skills import AgentSkills, SkillSource, SkillSources +from .skill import Skill + +__all__ = [ + "AgentSkills", + "Skill", + "SkillSource", + "SkillSources", +] diff --git a/src/strands/vended_plugins/skills/agent_skills.py b/src/strands/vended_plugins/skills/agent_skills.py new file mode 100644 index 000000000..ded2afb79 --- /dev/null +++ b/src/strands/vended_plugins/skills/agent_skills.py @@ -0,0 +1,421 @@ +"""AgentSkills plugin for integrating Agent Skills into Strands agents. + +This module provides the AgentSkills class that extends the Plugin base class +to add Agent Skills support. The plugin registers a tool for activating +skills, and injects skill metadata into the system prompt. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, TypeAlias +from xml.sax.saxutils import escape + +from ...hooks.events import BeforeInvocationEvent +from ...plugins import Plugin, hook +from ...tools.decorator import tool +from ...types.content import SystemContentBlock +from ...types.tools import ToolContext +from .skill import Skill + +if TYPE_CHECKING: + from ...agent.agent import Agent + +logger = logging.getLogger(__name__) + +_DEFAULT_STATE_KEY = "agent_skills" +_RESOURCE_DIRS = ("scripts", "references", "assets") +_DEFAULT_MAX_RESOURCE_FILES = 20 + +SkillSource: TypeAlias = str | Path | Skill +"""A single skill source: path string, Path object, or Skill instance.""" + +SkillSources: TypeAlias = SkillSource | list[SkillSource] +"""One or more skill sources.""" + + +def _normalize_sources(sources: SkillSources) -> list[SkillSource]: + """Normalize a single source or list of sources into a list.""" + if isinstance(sources, list): + return sources + return [sources] + + +class AgentSkills(Plugin): + """Plugin that integrates Agent Skills into a Strands agent. + + The AgentSkills plugin extends the Plugin base class and provides: + + 1. A ``skills`` tool that allows the agent to activate skills on demand + 2. System prompt injection of available skill metadata before each invocation + 3. Session persistence of active skill state via ``agent.state`` + + Skills can be provided as filesystem paths (to individual skill directories or + parent directories containing multiple skills) or as pre-built ``Skill`` instances. + + Example: + ```python + from strands import Agent + from strands.vended_plugins.skills import Skill, AgentSkills + + # Load from filesystem + plugin = AgentSkills(skills=["./skills/pdf-processing", "./skills/"]) + + # Or provide Skill instances directly + skill = Skill(name="my-skill", description="A custom skill", instructions="Do the thing") + plugin = AgentSkills(skills=[skill]) + + agent = Agent(plugins=[plugin]) + ``` + """ + + name = "agent_skills" + + def __init__( + self, + skills: SkillSources, + state_key: str = _DEFAULT_STATE_KEY, + max_resource_files: int = _DEFAULT_MAX_RESOURCE_FILES, + strict: bool = False, + ) -> None: + """Initialize the AgentSkills plugin. + + Args: + skills: One or more skill sources. Can be a single value or a list. Each element can be: + + - A ``str`` or ``Path`` to a skill directory (containing SKILL.md) + - A ``str`` or ``Path`` to a parent directory (containing skill subdirectories) + - A ``Skill`` dataclass instance + - An ``https://`` URL pointing directly to raw SKILL.md content + state_key: Key used to store plugin state in ``agent.state``. + max_resource_files: Maximum number of resource files to list in skill responses. + strict: If True, raise on skill validation issues. If False (default), warn and load anyway. + """ + self._strict = strict + self._skills: dict[str, Skill] = self._resolve_skills(_normalize_sources(skills)) + self._state_key = state_key + self._max_resource_files = max_resource_files + super().__init__() + + def init_agent(self, agent: Agent) -> None: + """Initialize the plugin with an agent instance. + + Decorated hooks and tools are auto-registered by the plugin registry. + + Args: + agent: The agent instance to extend with skills support. + """ + if not self._skills: + logger.warning("no skills were loaded, the agent will have no skills available") + logger.debug("skill_count=<%d> | skills plugin initialized", len(self._skills)) + + @tool(context=True) + def skills(self, skill_name: str, tool_context: ToolContext) -> str: # noqa: D417 + """Activate a skill to load its full instructions. + + Use this tool to load the complete instructions for a skill listed in + the available_skills section of your system prompt. + + Args: + skill_name: Name of the skill to activate. + """ + if not skill_name: + available = ", ".join(self._skills) + return f"Error: skill_name is required. Available skills: {available}" + + found = self._skills.get(skill_name) + if found is None: + available = ", ".join(self._skills) + return f"Skill '{skill_name}' not found. Available skills: {available}" + + logger.debug("skill_name=<%s> | skill activated", skill_name) + self._track_activated_skill(tool_context.agent, skill_name) + return self._format_skill_response(found) + + @hook + def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: + """Inject skill metadata into the system prompt before each invocation. + + Removes the previously injected XML block (if any) via exact match, + then appends a fresh one. Uses agent state to track the injected XML + per-agent, so a single plugin instance can be shared across multiple + agents safely. + + When the agent has a structured system prompt (list of SystemContentBlock), + the injection is done at the block level so that cache points and other + structured blocks are preserved. Otherwise falls back to string manipulation. + + Args: + event: The before-invocation event containing the agent reference. + """ + agent = event.agent + + state_data = agent.state.get(self._state_key) + last_injected_xml = state_data.get("last_injected_xml") if isinstance(state_data, dict) else None + + skills_xml = self._generate_skills_xml() + content = agent.system_prompt_content + + if content is not None: + # Content-block path: preserve cache points and other structured blocks + blocks: list[SystemContentBlock] = list(content) + if last_injected_xml is not None: + injected_block: SystemContentBlock = {"text": last_injected_xml} + if injected_block in blocks: + blocks.remove(injected_block) + else: + logger.warning("unable to find previously injected skills XML in system prompt, re-appending") + blocks.append({"text": skills_xml}) + self._set_state_field(agent, "last_injected_xml", skills_xml) + agent.system_prompt = blocks + else: + # String path: legacy behaviour for plain-string system prompts + current_prompt = agent.system_prompt or "" + if last_injected_xml is not None: + if last_injected_xml in current_prompt: + current_prompt = current_prompt.replace(last_injected_xml, "") + else: + logger.warning("unable to find previously injected skills XML in system prompt, re-appending") + injection = f"\n\n{skills_xml}" + new_prompt = f"{current_prompt}{injection}" if current_prompt else skills_xml + new_injected_xml = injection if current_prompt else skills_xml + self._set_state_field(agent, "last_injected_xml", new_injected_xml) + agent.system_prompt = new_prompt + + def get_available_skills(self) -> list[Skill]: + """Get the list of available skills. + + Returns: + A copy of the current skills list. + """ + return list(self._skills.values()) + + def set_available_skills(self, skills: SkillSources) -> None: + """Set the available skills, replacing any existing ones. + + Each element can be a ``Skill`` instance, a ``str`` or ``Path`` to a + skill directory (containing SKILL.md), a ``str`` or ``Path`` to a + parent directory containing skill subdirectories, or an ``https://`` + URL pointing directly to raw SKILL.md content. + + Note: this does not persist state or deactivate skills on any agent. + Active skill state is managed per-agent and will be reconciled on the + next tool call or invocation. + + Args: + skills: One or more skill sources to resolve and set. + """ + self._skills = self._resolve_skills(_normalize_sources(skills)) + + def _format_skill_response(self, skill: Skill) -> str: + """Format the tool response when a skill is activated. + + Includes the full instructions along with relevant metadata fields + and a listing of available resource files (scripts, references, assets) + for filesystem-based skills. + + Args: + skill: The activated skill. + + Returns: + Formatted string with skill instructions and metadata. + """ + if not skill.instructions: + return f"Skill '{skill.name}' activated (no instructions available)." + + parts: list[str] = [skill.instructions] + + metadata_lines: list[str] = [] + if skill.allowed_tools: + metadata_lines.append(f"Allowed tools: {', '.join(skill.allowed_tools)}") + if skill.compatibility: + metadata_lines.append(f"Compatibility: {skill.compatibility}") + if skill.path is not None: + metadata_lines.append(f"Location: {skill.path / 'SKILL.md'}") + + if metadata_lines: + parts.append("\n---\n" + "\n".join(metadata_lines)) + + if skill.path is not None: + resources = self._list_skill_resources(skill.path) + if resources: + parts.append("\nAvailable resources:\n" + "\n".join(f" {r}" for r in resources)) + + return "\n".join(parts) + + def _list_skill_resources(self, skill_path: Path) -> list[str]: + """List resource files in a skill's optional directories. + + Scans the ``scripts/``, ``references/``, and ``assets/`` subdirectories + for files, returning relative paths. Results are capped at + ``max_resource_files`` to avoid context bloat. + + Args: + skill_path: Path to the skill directory. + + Returns: + List of relative file paths (e.g. ``scripts/extract.py``). + """ + files: list[str] = [] + + for dir_name in _RESOURCE_DIRS: + resource_dir = skill_path / dir_name + if not resource_dir.is_dir(): + continue + + for file_path in sorted(resource_dir.rglob("*")): + if not file_path.is_file(): + continue + files.append(file_path.relative_to(skill_path).as_posix()) + if len(files) >= self._max_resource_files: + files.append(f"... (truncated at {self._max_resource_files} files)") + return files + + return files + + def _generate_skills_xml(self) -> str: + """Generate the XML block listing available skills for the system prompt. + + When no skills are loaded, returns a block indicating no skills are available. + Otherwise includes a ```` element for skills loaded from the filesystem, + following the AgentSkills.io integration spec. + + Returns: + XML-formatted string with skill metadata. + """ + if not self._skills: + return "\nNo skills are currently available.\n" + + lines: list[str] = [""] + + for skill in self._skills.values(): + lines.append("") + lines.append(f"{escape(skill.name)}") + lines.append(f"{escape(skill.description)}") + if skill.path is not None: + lines.append(f"{escape(str(skill.path / 'SKILL.md'))}") + lines.append("") + + lines.append("") + return "\n".join(lines) + + def _resolve_skills(self, sources: list[SkillSource]) -> dict[str, Skill]: + """Resolve a list of skill sources into Skill instances. + + Each source can be a Skill instance, a path to a skill directory, + a path to a parent directory containing multiple skills, or an + HTTPS URL pointing to a SKILL.md file. + + Args: + sources: List of skill sources to resolve. + + Returns: + Dict mapping skill names to Skill instances. + """ + resolved: dict[str, Skill] = {} + + for source in sources: + if isinstance(source, Skill): + if source.name in resolved: + logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", source.name) + resolved[source.name] = source + elif isinstance(source, str) and source.startswith("https://"): + try: + skill = Skill.from_url(source, strict=self._strict) + if skill.name in resolved: + logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", skill.name) + resolved[skill.name] = skill + except (RuntimeError, ValueError) as e: + logger.warning("url=<%s> | failed to load skill from URL: %s", source, e) + else: + path = Path(source).resolve() + if not path.exists(): + logger.warning("path=<%s> | skill source path does not exist, skipping", path) + continue + + if path.is_dir(): + # Check if this directory itself is a skill (has SKILL.md) + has_skill_md = (path / "SKILL.md").is_file() or (path / "skill.md").is_file() + + if has_skill_md: + try: + skill = Skill.from_file(path, strict=self._strict) + if skill.name in resolved: + logger.warning( + "name=<%s> | duplicate skill name, overwriting previous skill", skill.name + ) + resolved[skill.name] = skill + except (ValueError, FileNotFoundError) as e: + logger.warning("path=<%s> | failed to load skill: %s", path, e) + else: + # Treat as parent directory containing skill subdirectories + for skill in Skill.from_directory(path, strict=self._strict): + if skill.name in resolved: + logger.warning( + "name=<%s> | duplicate skill name, overwriting previous skill", skill.name + ) + resolved[skill.name] = skill + elif path.is_file() and path.name.lower() == "skill.md": + try: + skill = Skill.from_file(path, strict=self._strict) + if skill.name in resolved: + logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", skill.name) + resolved[skill.name] = skill + except (ValueError, FileNotFoundError) as e: + logger.warning("path=<%s> | failed to load skill: %s", path, e) + + logger.debug("source_count=<%d>, resolved_count=<%d> | skills resolved", len(sources), len(resolved)) + return resolved + + def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: + """Set a single field in the plugin's agent state dict. + + Args: + agent: The agent whose state to update. + key: The state field key. + value: The value to set. + + Raises: + TypeError: If the existing state value is not a dict. + """ + state_data = agent.state.get(self._state_key) + if state_data is not None and not isinstance(state_data, dict): + raise TypeError(f"expected dict for state key '{self._state_key}', got {type(state_data).__name__}") + if state_data is None: + state_data = {} + state_data[key] = value + agent.state.set(self._state_key, state_data) + + def _track_activated_skill(self, agent: Agent, skill_name: str) -> None: + """Record a skill activation in agent state. + + Maintains an ordered list of activated skill names (most recent last), + without duplicates. + + Args: + agent: The agent whose state to update. + skill_name: Name of the activated skill. + """ + state_data = agent.state.get(self._state_key) + activated: list[str] = state_data.get("activated_skills", []) if isinstance(state_data, dict) else [] + if skill_name in activated: + activated.remove(skill_name) + activated.append(skill_name) + self._set_state_field(agent, "activated_skills", activated) + + def get_activated_skills(self, agent: Agent) -> list[str]: + """Get the list of skills activated by this agent. + + Returns skill names in activation order (most recent last). + + Args: + agent: The agent to query. + + Returns: + List of activated skill names. + """ + state_data = agent.state.get(self._state_key) + if isinstance(state_data, dict): + return list(state_data.get("activated_skills", [])) + return [] diff --git a/src/strands/vended_plugins/skills/skill.py b/src/strands/vended_plugins/skills/skill.py new file mode 100644 index 000000000..a60c1cd6c --- /dev/null +++ b/src/strands/vended_plugins/skills/skill.py @@ -0,0 +1,424 @@ +"""Skill data model and loading utilities for AgentSkills.io skills. + +This module defines the Skill dataclass and provides classmethods for +discovering, parsing, and loading skills from the filesystem, raw content, +or HTTPS URLs. Skills are directories containing a SKILL.md file with YAML +frontmatter metadata and markdown instructions. +""" + +from __future__ import annotations + +import logging +import re +import urllib.error +import urllib.request +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import yaml + +logger = logging.getLogger(__name__) + +_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]([a-z0-9-]*[a-z0-9])?$") +_MAX_SKILL_NAME_LENGTH = 64 + + +def _find_skill_md(skill_dir: Path) -> Path: + """Find the SKILL.md file in a skill directory. + + Searches for SKILL.md (case-sensitive preferred) or skill.md as a fallback. + + Args: + skill_dir: Path to the skill directory. + + Returns: + Path to the SKILL.md file. + + Raises: + FileNotFoundError: If no SKILL.md file is found in the directory. + """ + for name in ("SKILL.md", "skill.md"): + candidate = skill_dir / name + if candidate.is_file(): + return candidate + + raise FileNotFoundError(f"path=<{skill_dir}> | no SKILL.md found in skill directory") + + +def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: + """Parse YAML frontmatter and body from SKILL.md content. + + Extracts the YAML frontmatter between ``---`` delimiters at line boundaries + and returns parsed key-value pairs along with the remaining markdown body. + + Args: + content: Full content of a SKILL.md file. + + Returns: + Tuple of (frontmatter_dict, body_string). + + Raises: + ValueError: If the frontmatter is malformed or missing required delimiters. + """ + stripped = content.strip() + if not stripped.startswith("---"): + raise ValueError("SKILL.md must start with --- frontmatter delimiter") + + # Find the closing --- delimiter (first line after the opener that is only dashes) + match = re.search(r"\n^---\s*$", stripped, re.MULTILINE) + if match is None: + raise ValueError("SKILL.md frontmatter missing closing --- delimiter") + + frontmatter_str = stripped[3 : match.start()].strip() + body = stripped[match.end() :].strip() + + try: + result = yaml.safe_load(frontmatter_str) + except yaml.YAMLError: + # AgentSkills spec recommends handling malformed YAML (e.g. unquoted colons in values) + # to improve cross-client compatibility. See: agentskills.io/client-implementation/adding-skills-support + logger.warning("YAML parse failed, retrying with colon-quoting fallback") + fixed = _fix_yaml_colons(frontmatter_str) + result = yaml.safe_load(fixed) + + frontmatter: dict[str, Any] = result if isinstance(result, dict) else {} + return frontmatter, body + + +def _fix_yaml_colons(yaml_str: str) -> str: + """Attempt to fix common YAML issues like unquoted colons in values. + + Wraps values containing colons in double quotes to handle cases like: + ``description: Use this skill when: the user asks about PDFs`` + + Args: + yaml_str: The raw YAML string to fix. + + Returns: + The fixed YAML string. + """ + lines: list[str] = [] + for line in yaml_str.splitlines(): + # Match key: value where value contains another colon + match = re.match(r"^(\s*\w[\w-]*):\s+(.+)$", line) + if match: + key, value = match.group(1), match.group(2) + # If value contains a colon and isn't already quoted + if ":" in value and not (value.startswith('"') or value.startswith("'")): + line = f'{key}: "{value}"' + lines.append(line) + return "\n".join(lines) + + +def _validate_skill_name(name: str, dir_path: Path | None = None, *, strict: bool = False) -> None: + """Validate a skill name per the AgentSkills.io specification. + + In lenient mode (default), logs warnings for cosmetic issues but does not raise. + In strict mode, raises ValueError for any validation failure. + + Rules checked: + - 1-64 characters long + - Lowercase alphanumeric characters and hyphens only + - Cannot start or end with a hyphen + - No consecutive hyphens + - Must match parent directory name (if loaded from disk) + + Args: + name: The skill name to validate. + dir_path: Optional path to the skill directory for name matching. + strict: If True, raise ValueError on any issue. If False (default), log warnings. + + Raises: + ValueError: If the skill name is empty, or if strict=True and any rule is violated. + """ + if not name: + raise ValueError("Skill name cannot be empty") + + if len(name) > _MAX_SKILL_NAME_LENGTH: + msg = "name=<%s> | skill name exceeds %d character limit" + if strict: + raise ValueError(msg % (name, _MAX_SKILL_NAME_LENGTH)) + logger.warning(msg, name, _MAX_SKILL_NAME_LENGTH) + + if not _SKILL_NAME_PATTERN.match(name): + msg = ( + "name=<%s> | skill name should be 1-64 lowercase alphanumeric characters or hyphens, " + "should not start/end with hyphen" + ) + if strict: + raise ValueError(msg % name) + logger.warning(msg, name) + + if "--" in name: + msg = "name=<%s> | skill name contains consecutive hyphens" + if strict: + raise ValueError(msg % name) + logger.warning(msg, name) + + if dir_path is not None and dir_path.name != name: + msg = "name=<%s>, directory=<%s> | skill name does not match parent directory name" + if strict: + raise ValueError(msg % (name, dir_path.name)) + logger.warning(msg, name, dir_path.name) + + +def _build_skill_from_frontmatter( + frontmatter: dict[str, Any], + body: str, +) -> Skill: + """Build a Skill instance from parsed frontmatter and body. + + Args: + frontmatter: Parsed YAML frontmatter dict. + body: Markdown body content. + + Returns: + A populated Skill instance. + """ + # Parse allowed-tools (space-delimited string or YAML list) + allowed_tools_raw = frontmatter.get("allowed-tools") or frontmatter.get("allowed_tools") + allowed_tools: list[str] | None = None + if isinstance(allowed_tools_raw, str) and allowed_tools_raw.strip(): + allowed_tools = allowed_tools_raw.strip().split() + elif isinstance(allowed_tools_raw, list): + allowed_tools = [str(item) for item in allowed_tools_raw if item] + + # Parse metadata (nested mapping) + metadata_raw = frontmatter.get("metadata", {}) + metadata: dict[str, Any] = {} + if isinstance(metadata_raw, dict): + metadata = {str(k): v for k, v in metadata_raw.items()} + + skill_license = frontmatter.get("license") + compatibility = frontmatter.get("compatibility") + + return Skill( + name=frontmatter["name"], + description=frontmatter["description"], + instructions=body, + allowed_tools=allowed_tools, + metadata=metadata, + license=str(skill_license) if skill_license else None, + compatibility=str(compatibility) if compatibility else None, + ) + + +@dataclass +class Skill: + r"""Represents an agent skill with metadata and instructions. + + A skill encapsulates a set of instructions and metadata that can be + dynamically loaded by an agent at runtime. Skills support progressive + disclosure: metadata is shown upfront in the system prompt, and full + instructions are loaded on demand via a tool. + + Skills can be created directly or via convenience classmethods:: + + # From a skill directory on disk + skill = Skill.from_file("./skills/my-skill") + + # From raw SKILL.md content + skill = Skill.from_content("---\nname: my-skill\n...") + + # Load all skills from a parent directory + skills = Skill.from_directory("./skills/") + + # From an HTTPS URL + skill = Skill.from_url("https://example.com/SKILL.md") + + Attributes: + name: Unique identifier for the skill (1-64 chars, lowercase alphanumeric + hyphens). + description: Human-readable description of what the skill does. + instructions: Full markdown instructions from the SKILL.md body. + path: Filesystem path to the skill directory, if loaded from disk. + allowed_tools: List of tool names the skill is allowed to use. (Experimental: not yet enforced) + metadata: Additional key-value metadata from the SKILL.md frontmatter. + license: License identifier (e.g., "Apache-2.0"). + compatibility: Compatibility information string. + """ + + name: str + description: str + instructions: str = "" + path: Path | None = None + allowed_tools: list[str] | None = None + metadata: dict[str, Any] = field(default_factory=dict) + license: str | None = None + compatibility: str | None = None + + @classmethod + def from_file(cls, skill_path: str | Path, *, strict: bool = False) -> Skill: + """Load a single skill from a directory containing SKILL.md. + + Resolves the filesystem path, reads the file content, and delegates + to ``from_content`` for parsing. After loading, sets the skill's + ``path`` and validates the skill name against the parent directory. + + Args: + skill_path: Path to the skill directory or the SKILL.md file itself. + strict: If True, raise on any validation issue. If False (default), warn and load anyway. + + Returns: + A Skill instance populated from the SKILL.md file. + + Raises: + FileNotFoundError: If the path does not exist or SKILL.md is not found. + ValueError: If the skill metadata is invalid. + """ + skill_path = Path(skill_path).resolve() + + if skill_path.is_file() and skill_path.name.lower() == "skill.md": + skill_md_path = skill_path + skill_dir = skill_path.parent + elif skill_path.is_dir(): + skill_dir = skill_path + skill_md_path = _find_skill_md(skill_dir) + else: + raise FileNotFoundError( + f"path=<{skill_path}> | skill path does not exist or is not a valid skill directory" + ) + + logger.debug("path=<%s> | loading skill", skill_md_path) + + content = skill_md_path.read_text(encoding="utf-8") + skill = cls.from_content(content, strict=strict) + + # Set path and check directory name match (from_content already validated the name format) + skill.path = skill_dir + if skill_dir.name != skill.name: + msg = "name=<%s>, directory=<%s> | skill name does not match parent directory name" + if strict: + raise ValueError(msg % (skill.name, skill_dir.name)) + logger.warning(msg, skill.name, skill_dir.name) + + logger.debug("name=<%s>, path=<%s> | skill loaded successfully", skill.name, skill.path) + return skill + + @classmethod + def from_content(cls, content: str, *, strict: bool = False) -> Skill: + """Parse SKILL.md content into a Skill instance. + + This is a convenience method for creating a Skill from raw SKILL.md + content (YAML frontmatter + markdown body) without requiring a file on + disk. + + Example:: + + content = '''--- + name: my-skill + description: Does something useful + --- + # Instructions + Follow these steps... + ''' + skill = Skill.from_content(content) + + Args: + content: Raw SKILL.md content with YAML frontmatter and markdown body. + strict: If True, raise on any validation issue. If False (default), warn and load anyway. + + Returns: + A Skill instance populated from the parsed content. + + Raises: + ValueError: If the content is missing required fields or has invalid frontmatter. + """ + frontmatter, body = _parse_frontmatter(content) + + name = frontmatter.get("name") + if not isinstance(name, str) or not name: + raise ValueError("SKILL.md content must have a 'name' field in frontmatter") + + description = frontmatter.get("description") + if not isinstance(description, str) or not description: + raise ValueError("SKILL.md content must have a 'description' field in frontmatter") + + _validate_skill_name(name, strict=strict) + + return _build_skill_from_frontmatter(frontmatter, body) + + @classmethod + def from_url(cls, url: str, *, strict: bool = False) -> Skill: + """Load a skill by fetching its SKILL.md content from an HTTPS URL. + + Fetches the raw SKILL.md content over HTTPS and parses it using + :meth:`from_content`. The URL must point directly to the raw + file content (not an HTML page). + + Example:: + + skill = Skill.from_url( + "https://raw.githubusercontent.com/org/repo/main/SKILL.md" + ) + + Args: + url: An ``https://`` URL pointing directly to raw SKILL.md content. + strict: If True, raise on any validation issue. If False (default), + warn and load anyway. + + Returns: + A Skill instance populated from the fetched SKILL.md content. + + Raises: + ValueError: If ``url`` is not an ``https://`` URL. + RuntimeError: If the SKILL.md content cannot be fetched. + """ + if not url.startswith("https://"): + raise ValueError(f"url=<{url}> | not a valid HTTPS URL") + + logger.info("url=<%s> | fetching skill content", url) + + try: + req = urllib.request.Request(url, headers={"User-Agent": "strands-agents-sdk"}) # noqa: S310 + with urllib.request.urlopen(req, timeout=30) as response: # noqa: S310 + content: str = response.read().decode("utf-8") + except urllib.error.HTTPError as e: + raise RuntimeError(f"url=<{url}> | HTTP {e.code}: {e.reason}") from e + except urllib.error.URLError as e: + raise RuntimeError(f"url=<{url}> | failed to fetch skill: {e.reason}") from e + + return cls.from_content(content, strict=strict) + + @classmethod + def from_directory(cls, skills_dir: str | Path, *, strict: bool = False) -> list[Skill]: + """Load all skills from a parent directory containing skill subdirectories. + + Each subdirectory containing a SKILL.md file is treated as a skill. + Subdirectories without SKILL.md are silently skipped. + + Args: + skills_dir: Path to the parent directory containing skill subdirectories. + strict: If True, raise on any validation issue. If False (default), warn and load anyway. + + Returns: + List of Skill instances loaded from the directory. + + Raises: + FileNotFoundError: If the skills directory does not exist. + """ + skills_dir = Path(skills_dir).resolve() + + if not skills_dir.is_dir(): + raise FileNotFoundError(f"path=<{skills_dir}> | skills directory does not exist") + + skills: list[Skill] = [] + + for child in sorted(skills_dir.iterdir()): + if not child.is_dir(): + continue + + try: + _find_skill_md(child) + except FileNotFoundError: + logger.debug("path=<%s> | skipping directory without SKILL.md", child) + continue + + try: + skill = cls.from_file(child, strict=strict) + skills.append(skill) + except (ValueError, FileNotFoundError) as e: + logger.warning("path=<%s> | skipping skill due to error: %s", child, e) + + logger.debug("path=<%s>, count=<%d> | loaded skills from directory", skills_dir, len(skills)) + return skills diff --git a/src/strands/vended_plugins/steering/__init__.py b/src/strands/vended_plugins/steering/__init__.py new file mode 100644 index 000000000..c928d0c63 --- /dev/null +++ b/src/strands/vended_plugins/steering/__init__.py @@ -0,0 +1,47 @@ +"""Steering system for Strands agents. + +Provides contextual guidance for agents through modular prompting with progressive disclosure. +Instead of front-loading all instructions, steering handlers provide just-in-time feedback +based on local context data populated by context callbacks. + +Core components: + +- SteeringHandler: Base class for guidance logic with local context +- SteeringContextCallback: Protocol for context update functions +- SteeringContextProvider: Protocol for multi-event context providers +- ToolSteeringAction/ModelSteeringAction: Proceed/Guide/Interrupt decisions + +Usage: + handler = LLMSteeringHandler(system_prompt="...") + agent = Agent(tools=[...], plugins=[handler]) +""" + +# Core primitives +# Context providers +from .context_providers.ledger_provider import ( + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, +) +from .core.action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction +from .core.context import SteeringContextCallback, SteeringContextProvider +from .core.handler import SteeringHandler + +# Handler implementations +from .handlers.llm import LLMPromptMapper, LLMSteeringHandler + +__all__ = [ + "ToolSteeringAction", + "ModelSteeringAction", + "Proceed", + "Guide", + "Interrupt", + "SteeringHandler", + "SteeringContextCallback", + "SteeringContextProvider", + "LedgerBeforeToolCall", + "LedgerAfterToolCall", + "LedgerProvider", + "LLMSteeringHandler", + "LLMPromptMapper", +] diff --git a/src/strands/vended_plugins/steering/context_providers/__init__.py b/src/strands/vended_plugins/steering/context_providers/__init__.py new file mode 100644 index 000000000..242ed9cf1 --- /dev/null +++ b/src/strands/vended_plugins/steering/context_providers/__init__.py @@ -0,0 +1,13 @@ +"""Context providers for steering evaluation.""" + +from .ledger_provider import ( + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, +) + +__all__ = [ + "LedgerAfterToolCall", + "LedgerBeforeToolCall", + "LedgerProvider", +] diff --git a/src/strands/vended_plugins/steering/context_providers/ledger_provider.py b/src/strands/vended_plugins/steering/context_providers/ledger_provider.py new file mode 100644 index 000000000..43f56717a --- /dev/null +++ b/src/strands/vended_plugins/steering/context_providers/ledger_provider.py @@ -0,0 +1,91 @@ +"""Ledger context provider for comprehensive agent activity tracking. + +Tracks complete agent activity ledger including tool calls, conversation history, +and timing information. This comprehensive audit trail enables steering handlers +to make informed guidance decisions based on agent behavior patterns and history. + +Data captured: + + - Tool call history with inputs, outputs, timing, success/failure + - Conversation messages and agent responses + - Session metadata and timing information + - Error patterns and recovery attempts + +Usage: + Use as context provider functions or mix into steering handlers. +""" + +import logging +from datetime import datetime +from typing import Any + +from ....hooks.events import AfterToolCallEvent, BeforeToolCallEvent +from ..core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider + +logger = logging.getLogger(__name__) + + +class LedgerBeforeToolCall(SteeringContextCallback[BeforeToolCallEvent]): + """Context provider for ledger tracking before tool calls.""" + + def __init__(self) -> None: + """Initialize the ledger provider.""" + self.session_start = datetime.now().isoformat() + + def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None: + """Update ledger before tool call.""" + ledger = steering_context.data.get("ledger") or {} + + if not ledger: + ledger = { + "session_start": self.session_start, + "tool_calls": [], + "conversation_history": [], + "session_metadata": {}, + } + + tool_call_entry = { + "timestamp": datetime.now().isoformat(), + "tool_use_id": event.tool_use.get("toolUseId"), + "tool_name": event.tool_use.get("name"), + "tool_args": event.tool_use.get("input", {}), + "status": "pending", + } + ledger["tool_calls"].append(tool_call_entry) + steering_context.data.set("ledger", ledger) + + +class LedgerAfterToolCall(SteeringContextCallback[AfterToolCallEvent]): + """Context provider for ledger tracking after tool calls.""" + + def __call__(self, event: AfterToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None: + """Update ledger after tool call.""" + ledger = steering_context.data.get("ledger") or {} + + if ledger.get("tool_calls"): + tool_use_id = event.tool_use.get("toolUseId") + + # Search for the matching tool call in the ledger to update it + for call in reversed(ledger["tool_calls"]): + if call.get("tool_use_id") == tool_use_id and call.get("status") == "pending": + call.update( + { + "completion_timestamp": datetime.now().isoformat(), + "status": event.result["status"], + "result": event.result["content"], + "error": str(event.exception) if event.exception else None, + } + ) + steering_context.data.set("ledger", ledger) + break + + +class LedgerProvider(SteeringContextProvider): + """Combined ledger context provider for both before and after tool calls.""" + + def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: + """Return ledger context providers with shared state.""" + return [ + LedgerBeforeToolCall(), + LedgerAfterToolCall(), + ] diff --git a/src/strands/vended_plugins/steering/core/__init__.py b/src/strands/vended_plugins/steering/core/__init__.py new file mode 100644 index 000000000..bb229b175 --- /dev/null +++ b/src/strands/vended_plugins/steering/core/__init__.py @@ -0,0 +1,17 @@ +"""Core steering system interfaces and base classes.""" + +from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction +from .context import SteeringContext, SteeringContextCallback, SteeringContextProvider +from .handler import SteeringHandler + +__all__ = [ + "ToolSteeringAction", + "ModelSteeringAction", + "Proceed", + "Guide", + "Interrupt", + "SteeringHandler", + "SteeringContext", + "SteeringContextCallback", + "SteeringContextProvider", +] diff --git a/src/strands/vended_plugins/steering/core/action.py b/src/strands/vended_plugins/steering/core/action.py new file mode 100644 index 000000000..b1f124b40 --- /dev/null +++ b/src/strands/vended_plugins/steering/core/action.py @@ -0,0 +1,76 @@ +"""SteeringAction types for steering evaluation results. + +Defines structured outcomes from steering handlers that determine how agent actions +should be handled. SteeringActions enable modular prompting by providing just-in-time +feedback rather than front-loading all instructions in monolithic prompts. + +Flow: + SteeringHandler.steer_*() → SteeringAction → Event handling + ↓ ↓ ↓ + Evaluate context Action type Execution modified + +SteeringAction types: + Proceed: Allow execution to continue without intervention + Guide: Provide contextual guidance to redirect the agent + Interrupt: Pause execution for human input + +Extensibility: + New action types can be added to the union. Always handle the default + case in pattern matching to maintain backward compatibility. +""" + +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + + +class Proceed(BaseModel): + """Allow execution to continue without intervention. + + The action proceeds as planned. The reason provides context + for logging and debugging purposes. + """ + + type: Literal["proceed"] = "proceed" + reason: str + + +class Guide(BaseModel): + """Provide contextual guidance to redirect the agent. + + The agent receives the reason as contextual feedback to help guide + its behavior. The specific handling depends on the steering context + (e.g., tool call vs. model response). + """ + + type: Literal["guide"] = "guide" + reason: str + + +class Interrupt(BaseModel): + """Pause execution for human input via interrupt system. + + Execution is paused and human input is requested through Strands' + interrupt system. The human can approve or deny the operation, and their + decision determines whether execution continues or is cancelled. + """ + + type: Literal["interrupt"] = "interrupt" + reason: str + + +# Context-specific steering action types +ToolSteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] +"""Steering actions valid for tool steering (steer_before_tool). + +- Proceed: Allow tool execution to continue +- Guide: Cancel tool and provide feedback for alternative approaches +- Interrupt: Pause for human input before tool execution +""" + +ModelSteeringAction = Annotated[Proceed | Guide, Field(discriminator="type")] +"""Steering actions valid for model steering (steer_after_model). + +- Proceed: Accept model response without modification +- Guide: Discard model response and retry with guidance +""" diff --git a/src/strands/vended_plugins/steering/core/context.py b/src/strands/vended_plugins/steering/core/context.py new file mode 100644 index 000000000..446c4c9f9 --- /dev/null +++ b/src/strands/vended_plugins/steering/core/context.py @@ -0,0 +1,77 @@ +"""Steering context protocols for contextual guidance. + +Defines protocols for context callbacks and providers that populate +steering context data used by handlers to make guidance decisions. + +Architecture: + SteeringContextCallback → Handler.steering_context → SteeringHandler.steer() + ↓ ↓ ↓ + Update local context Store in handler Access via self.steering_context + +Context lifecycle: + 1. Handler registers context callbacks for hook events + 2. Callbacks update handler's local steering_context on events + 3. Handler accesses self.steering_context in steer() method + 4. Context persists across calls within handler instance + +Implementation: + Each handler maintains its own JSONSerializableDict context. + Callbacks are registered per handler instance for isolation. + Providers can supply multiple callbacks for different events. +""" + +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Generic, TypeVar, cast, get_args, get_origin + +from ....hooks.registry import HookEvent +from ....types.json_dict import JSONSerializableDict + +logger = logging.getLogger(__name__) + + +@dataclass +class SteeringContext: + """Container for steering context data.""" + + """Container for steering context data. + + This class should not be instantiated directly - it is intended for internal use only. + """ + + data: JSONSerializableDict = field(default_factory=JSONSerializableDict) + + +EventType = TypeVar("EventType", bound=HookEvent, contravariant=True) + + +class SteeringContextCallback(ABC, Generic[EventType]): + """Abstract base class for steering context update callbacks.""" + + @property + def event_type(self) -> type[HookEvent]: + """Return the event type this callback handles.""" + for base in getattr(self.__class__, "__orig_bases__", ()): + if get_origin(base) is SteeringContextCallback: + return cast(type[HookEvent], get_args(base)[0]) + raise ValueError("Could not determine event type from generic parameter") + + def __call__(self, event: EventType, steering_context: "SteeringContext", **kwargs: Any) -> None: + """Update steering context based on hook event. + + Args: + event: The hook event that triggered the callback + steering_context: The steering context to update + **kwargs: Additional keyword arguments for context updates + """ + ... + + +class SteeringContextProvider(ABC): + """Abstract base class for context providers that handle multiple event types.""" + + @abstractmethod + def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: + """Return list of context callbacks with event types extracted from generics.""" + ... diff --git a/src/strands/vended_plugins/steering/core/handler.py b/src/strands/vended_plugins/steering/core/handler.py new file mode 100644 index 000000000..214118d4f --- /dev/null +++ b/src/strands/vended_plugins/steering/core/handler.py @@ -0,0 +1,218 @@ +"""Steering handler base class for providing contextual guidance to agents. + +Provides modular prompting through contextual guidance that appears when relevant, +rather than front-loading all instructions. Handlers integrate with the Strands hook +system to intercept actions and provide just-in-time feedback based on local context. + +Architecture: + Hook Event → Context Callbacks → Update steering_context → steer_*() → SteeringAction + ↓ ↓ ↓ ↓ ↓ + Hook triggered Populate context Handler evaluates Handler decides Action taken + +Lifecycle: + 1. Context callbacks update handler's steering_context on hook events + 2. BeforeToolCallEvent triggers steer_before_tool() for tool steering + 3. AfterModelCallEvent triggers steer_after_model() for model steering + 4. Handler accesses self.steering_context for guidance decisions + 5. SteeringAction determines execution flow + +Implementation: + Subclass SteeringHandler and override steer_before_tool() and/or steer_after_model(). + Both methods have default implementations that return Proceed, so you only need to + override the methods you want to customize. + Pass context_providers in constructor to register context update functions. + Each handler maintains isolated steering_context that persists across calls. + +SteeringAction handling for steer_before_tool: + Proceed: Tool executes immediately + Guide: Tool cancelled, agent receives contextual feedback to explore alternatives + Interrupt: Tool execution paused for human input via interrupt system + +SteeringAction handling for steer_after_model: + Proceed: Model response accepted without modification + Guide: Discard model response and retry (message is dropped, model is called again) + Interrupt: Model response handling paused for human input via interrupt system +""" + +import logging +from typing import TYPE_CHECKING, Any + +from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent +from ....plugins import Plugin, hook +from ....types.content import Message +from ....types.streaming import StopReason +from ....types.tools import ToolUse +from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction +from .context import SteeringContext, SteeringContextProvider + +if TYPE_CHECKING: + from ....agent import Agent + +logger = logging.getLogger(__name__) + + +class SteeringHandler(Plugin): + """Base class for steering handlers that provide contextual guidance to agents. + + Steering handlers maintain local context and register hook callbacks + to populate context data as needed for guidance decisions. + """ + + name: str = "steering" + + def __init__(self, context_providers: list[SteeringContextProvider] | None = None): + """Initialize the steering handler. + + Args: + context_providers: List of context providers for context updates + """ + super().__init__() + self.steering_context = SteeringContext() + self._context_callbacks = [] + + # Collect callbacks from all providers + for provider in context_providers or []: + self._context_callbacks.extend(provider.context_providers()) + + logger.debug("handler_class=<%s> | initialized", self.__class__.__name__) + + def init_agent(self, agent: "Agent") -> None: + """Initialize the steering handler with an agent. + + Registers hook callbacks for steering guidance and context updates. + + Args: + agent: The agent instance to attach steering to. + """ + # Register context update callbacks + for callback in self._context_callbacks: + agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type) + + @hook + async def provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: + """Provide steering guidance for tool call.""" + tool_name = event.tool_use["name"] + logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name) + + try: + action = await self.steer_before_tool(agent=event.agent, tool_use=event.tool_use) + except Exception as e: + logger.debug("tool_name=<%s>, error=<%s> | tool steering handler guidance failed", tool_name, e) + return + + self._handle_tool_steering_action(action, event, tool_name) + + def _handle_tool_steering_action( + self, action: ToolSteeringAction, event: BeforeToolCallEvent, tool_name: str + ) -> None: + """Handle the steering action for tool calls by modifying tool execution flow. + + Proceed: Tool executes normally + Guide: Tool cancelled with contextual feedback for agent to consider alternatives + Interrupt: Tool execution paused for human input via interrupt system + """ + if isinstance(action, Proceed): + logger.debug("tool_name=<%s> | tool call proceeding", tool_name) + elif isinstance(action, Guide): + logger.debug("tool_name=<%s> | tool call guided: %s", tool_name, action.reason) + event.cancel_tool = f"Tool call cancelled. {action.reason} You MUST follow this guidance immediately." + elif isinstance(action, Interrupt): + logger.debug("tool_name=<%s> | tool call requires human input: %s", tool_name, action.reason) + can_proceed: bool = event.interrupt(name=f"steering_input_{tool_name}", reason={"message": action.reason}) + logger.debug("tool_name=<%s> | received human input for tool call", tool_name) + + if not can_proceed: + event.cancel_tool = f"Manual approval denied: {action.reason}" + logger.debug("tool_name=<%s> | tool call denied by manual approval", tool_name) + else: + logger.debug("tool_name=<%s> | tool call approved manually", tool_name) + else: + raise ValueError(f"Unknown steering action type for tool call: {action}") + + @hook + async def provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: + """Provide steering guidance for model response.""" + logger.debug("providing model steering guidance") + + # Only steer on successful model responses + if event.stop_response is None: + logger.debug("no stop response available | skipping model steering") + return + + try: + action = await self.steer_after_model( + agent=event.agent, message=event.stop_response.message, stop_reason=event.stop_response.stop_reason + ) + except Exception as e: + logger.debug("error=<%s> | model steering handler guidance failed", e) + return + + await self._handle_model_steering_action(action, event) + + async def _handle_model_steering_action(self, action: ModelSteeringAction, event: AfterModelCallEvent) -> None: + """Handle the steering action for model responses by modifying response handling flow. + + Proceed: Model response accepted without modification + Guide: Discard model response and retry with guidance message added to conversation + """ + if isinstance(action, Proceed): + logger.debug("model response proceeding") + elif isinstance(action, Guide): + logger.debug("model response guided (retrying): %s", action.reason) + # Set retry flag to discard current response + event.retry = True + # Add guidance message to agent's conversation so model sees it on retry + await event.agent._append_messages({"role": "user", "content": [{"text": action.reason}]}) + logger.debug("added guidance message to conversation for model retry") + else: + raise ValueError(f"Unknown steering action type for model response: {action}") + + async def steer_before_tool(self, *, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: + """Provide contextual guidance before tool execution. + + This method is called before a tool is executed, allowing the handler to: + - Proceed: Allow tool execution to continue + - Guide: Cancel tool and provide feedback for alternative approaches + - Interrupt: Pause for human input before tool execution + + Args: + agent: The agent instance + tool_use: The tool use object with name and arguments + **kwargs: Additional keyword arguments for guidance evaluation + + Returns: + ToolSteeringAction indicating how to guide the tool execution + + Note: + Access steering context via self.steering_context + Default implementation returns Proceed (allow tool execution) + Override this method to implement custom tool steering logic + """ + return Proceed(reason="Default implementation: allowing tool execution") + + async def steer_after_model( + self, *, agent: "Agent", message: Message, stop_reason: StopReason, **kwargs: Any + ) -> ModelSteeringAction: + """Provide contextual guidance after model response. + + This method is called after the model generates a response, allowing the handler to: + - Proceed: Accept the model response without modification + - Guide: Discard the response and retry (message is dropped, model is called again) + + Note: Interrupt is not supported for model steering as the model has already responded. + + Args: + agent: The agent instance + message: The model's generated message + stop_reason: The reason the model stopped generating + **kwargs: Additional keyword arguments for guidance evaluation + + Returns: + ModelSteeringAction indicating how to handle the model response + + Note: + Access steering context via self.steering_context + Default implementation returns Proceed (accept response as-is) + Override this method to implement custom model steering logic + """ + return Proceed(reason="Default implementation: accepting model response") diff --git a/src/strands/vended_plugins/steering/handlers/__init__.py b/src/strands/vended_plugins/steering/handlers/__init__.py new file mode 100644 index 000000000..fe364a5a2 --- /dev/null +++ b/src/strands/vended_plugins/steering/handlers/__init__.py @@ -0,0 +1,5 @@ +"""Steering handler implementations.""" + +from collections.abc import Sequence + +__all__: Sequence[str] = [] diff --git a/src/strands/vended_plugins/steering/handlers/llm/__init__.py b/src/strands/vended_plugins/steering/handlers/llm/__init__.py new file mode 100644 index 000000000..4dcccbe80 --- /dev/null +++ b/src/strands/vended_plugins/steering/handlers/llm/__init__.py @@ -0,0 +1,6 @@ +"""LLM steering handler with prompt mapping.""" + +from .llm_handler import LLMSteeringHandler +from .mappers import DefaultPromptMapper, LLMPromptMapper, ToolUse + +__all__ = ["LLMSteeringHandler", "LLMPromptMapper", "DefaultPromptMapper", "ToolUse"] diff --git a/src/strands/vended_plugins/steering/handlers/llm/llm_handler.py b/src/strands/vended_plugins/steering/handlers/llm/llm_handler.py new file mode 100644 index 000000000..6d0a31eeb --- /dev/null +++ b/src/strands/vended_plugins/steering/handlers/llm/llm_handler.py @@ -0,0 +1,99 @@ +"""LLM-based steering handler that uses an LLM to provide contextual guidance.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Literal, cast + +from pydantic import BaseModel, Field + +from .....models import Model +from .....types.tools import ToolUse +from ...context_providers.ledger_provider import LedgerProvider +from ...core.action import Guide, Interrupt, Proceed, ToolSteeringAction +from ...core.context import SteeringContextProvider +from ...core.handler import SteeringHandler +from .mappers import DefaultPromptMapper, LLMPromptMapper + +if TYPE_CHECKING: + from .....agent import Agent + +logger = logging.getLogger(__name__) + + +class _LLMSteering(BaseModel): + """Structured output model for LLM steering decisions.""" + + decision: Literal["proceed", "guide", "interrupt"] = Field( + description="Steering decision: 'proceed' to continue, 'guide' to provide feedback, 'interrupt' for human input" + ) + reason: str = Field(description="Clear explanation of the steering decision and any guidance provided") + + +class LLMSteeringHandler(SteeringHandler): + """Steering handler that uses an LLM to provide contextual guidance. + + Uses natural language prompts to evaluate tool calls and provide + contextual steering guidance to help agents navigate complex workflows. + """ + + def __init__( + self, + system_prompt: str, + prompt_mapper: LLMPromptMapper | None = None, + model: Model | None = None, + context_providers: list[SteeringContextProvider] | None = None, + ): + """Initialize the LLMSteeringHandler. + + Args: + system_prompt: System prompt defining steering guidance rules + prompt_mapper: Custom prompt mapper for evaluation prompts + model: Optional model override for steering evaluation + context_providers: List of context providers for populating steering context. + Defaults to [LedgerProvider()] if None. Pass an empty list to disable + context providers. + """ + providers: list[SteeringContextProvider] = ( + [LedgerProvider()] if context_providers is None else context_providers + ) + super().__init__(context_providers=providers) + self.system_prompt = system_prompt + self.prompt_mapper = prompt_mapper or DefaultPromptMapper() + self.model = model + + async def steer_before_tool(self, *, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: + """Provide contextual guidance for tool usage. + + Args: + agent: The agent instance + tool_use: The tool use object with name and arguments + **kwargs: Additional keyword arguments for steering evaluation + + Returns: + SteeringAction indicating how to guide the tool execution + """ + # Generate steering prompt + prompt = self.prompt_mapper.create_steering_prompt(self.steering_context, tool_use=tool_use) + + # Create isolated agent for steering evaluation (no shared conversation state) + from .....agent import Agent + + steering_agent = Agent(system_prompt=self.system_prompt, model=self.model or agent.model, callback_handler=None) + + # Get LLM decision + llm_result: _LLMSteering = cast( + _LLMSteering, steering_agent(prompt, structured_output_model=_LLMSteering).structured_output + ) + + # Convert LLM decision to steering action + match llm_result.decision: + case "proceed": + return Proceed(reason=llm_result.reason) + case "guide": + return Guide(reason=llm_result.reason) + case "interrupt": + return Interrupt(reason=llm_result.reason) + case _: + logger.warning("decision=<%s> | unknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] + return Proceed(reason="Unknown LLM decision, defaulting to proceed") diff --git a/src/strands/vended_plugins/steering/handlers/llm/mappers.py b/src/strands/vended_plugins/steering/handlers/llm/mappers.py new file mode 100644 index 000000000..ade018d32 --- /dev/null +++ b/src/strands/vended_plugins/steering/handlers/llm/mappers.py @@ -0,0 +1,130 @@ +"""LLM steering prompt mappers for generating evaluation prompts.""" + +import json +from typing import Any, Protocol + +from .....types.tools import ToolUse +from ...core.context import SteeringContext + +# Agent SOP format - see https://github.com/strands-agents/agent-sop +_STEERING_PROMPT_TEMPLATE = """# Steering Evaluation + +## Overview + +You are a STEERING AGENT that evaluates a {action_type} that ANOTHER AGENT is attempting to make. +Your job is to provide contextual guidance to help the other agent navigate workflows effectively. +You act as a safety net that can intervene when patterns in the context data suggest the agent +should try a different approach or get human input. + +**YOUR ROLE:** +- Analyze context data for concerning patterns (repeated failures, inappropriate timing, etc.) +- Provide just-in-time guidance when the agent is going down an ineffective path +- Allow normal operations to proceed when context shows no issues + +**CRITICAL CONSTRAINTS:** +- Base decisions ONLY on the context data provided below +- Do NOT use external knowledge about domains, URLs, or tool purposes +- Do NOT make assumptions about what tools "should" or "shouldn't" do +- Focus ONLY on patterns in the context data + +## Context + +{context_str} + +### Understanding Ledger Tool States + +If the context includes a ledger with tool_calls, the "status" field indicates: + +- **"pending"**: The tool is CURRENTLY being evaluated by you (the steering agent). +This is NOT a duplicate call - it's the tool you're deciding whether to approve. +The tool has NOT started executing yet. +- **"success"**: The tool completed successfully in a previous turn +- **"error"**: The tool failed or was cancelled in a previous turn + +**IMPORTANT**: When you see a tool with status="pending" that matches the tool you're evaluating, +that IS the current tool being evaluated. +It is NOT already executing or a duplicate. + +## Event to Evaluate + +{event_description} + +## Steps + +### 1. Analyze the {action_type_title} + +Review ONLY the context data above. Look for patterns in the data that indicate: + +- Previous failures or successes with this tool +- Frequency of attempts +- Any relevant tracking information + +**Constraints:** +- You MUST base analysis ONLY on the provided context data +- You MUST NOT use external knowledge about tool purposes or domains +- You SHOULD identify patterns in the context data +- You MAY reference relevant context data to inform your decision + +### 2. Make Steering Decision + +**Constraints:** +- You MUST respond with exactly one of: "proceed", "guide", or "interrupt" +- You MUST base the decision ONLY on context data patterns +- Your reason will be shown to the AGENT as guidance + +**Decision Options:** +- "proceed" if context data shows no concerning patterns +- "guide" if context data shows patterns requiring intervention +- "interrupt" if context data shows patterns requiring human input +""" + + +class LLMPromptMapper(Protocol): + """Protocol for mapping context and events to LLM evaluation prompts.""" + + def create_steering_prompt( + self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any + ) -> str: + """Create steering prompt for LLM evaluation. + + Args: + steering_context: Steering context with populated data + tool_use: Tool use object for tool call events (None for other events) + **kwargs: Additional event data for other steering events + + Returns: + Formatted prompt string for LLM evaluation + """ + ... + + +class DefaultPromptMapper(LLMPromptMapper): + """Default prompt mapper for steering evaluation.""" + + def create_steering_prompt( + self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any + ) -> str: + """Create default steering prompt using Agent SOP structure. + + Uses Agent SOP format for structured, constraint-based prompts. + See: https://github.com/strands-agents/agent-sop + """ + context_str = ( + json.dumps(steering_context.data.get(), indent=2) if steering_context.data.get() else "No context available" + ) + + if tool_use: + event_description = ( + f"Tool: {tool_use['name']}\nArguments: {json.dumps(tool_use.get('input', {}), indent=2)}" + ) + action_type = "tool call" + else: + event_description = "General evaluation" + action_type = "action" + + return _STEERING_PROMPT_TEMPLATE.format( + action_type=action_type, + action_type_title=action_type.title(), + context_str=context_str, + event_description=event_description, + ) diff --git a/tests/conftest.py b/tests/conftest.py index f2a8909cb..1c0083e85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,9 @@ def moto_env(monkeypatch): monkeypatch.setenv("AWS_SECURITY_TOKEN", "test") monkeypatch.setenv("AWS_DEFAULT_REGION", "us-west-2") monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False) monkeypatch.delenv("OTEL_EXPORTER_OTLP_HEADERS", raising=False) + monkeypatch.delenv("LANGFUSE_BASE_URL", raising=False) @pytest.fixture diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py index 091f44d06..cf17bb470 100644 --- a/tests/fixtures/mock_hook_provider.py +++ b/tests/fixtures/mock_hook_provider.py @@ -1,4 +1,5 @@ -from typing import Iterator, Literal, Tuple, Type +from collections.abc import Iterator +from typing import Literal from strands import Agent from strands.hooks import ( @@ -17,7 +18,7 @@ class MockHookProvider(HookProvider): - def __init__(self, event_types: list[Type] | Literal["all"]): + def __init__(self, event_types: list[type] | Literal["all"]): if event_types == "all": event_types = [ AgentInitializedEvent, @@ -37,7 +38,7 @@ def __init__(self, event_types: list[Type] | Literal["all"]): def event_types_received(self): return [type(event) for event in self.events_received] - def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + def get_events(self) -> tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) def register_hooks(self, registry: HookRegistry) -> None: diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py index 727d28a48..a89d5aca8 100644 --- a/tests/fixtures/mock_multiagent_hook_provider.py +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -1,20 +1,19 @@ -from typing import Iterator, Literal, Tuple, Type +from collections.abc import Iterator +from typing import Literal -from strands.experimental.hooks.multiagent.events import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeNodeCallEvent, - MultiAgentInitializedEvent, -) -from strands.hooks import ( HookEvent, HookProvider, HookRegistry, + MultiAgentInitializedEvent, ) class MockMultiAgentHookProvider(HookProvider): - def __init__(self, event_types: list[Type] | Literal["all"]): + def __init__(self, event_types: list[type] | Literal["all"]): if event_types == "all": event_types = [ MultiAgentInitializedEvent, @@ -30,7 +29,7 @@ def __init__(self, event_types: list[Type] | Literal["all"]): def event_types_received(self): return [type(event) for event in self.events_received] - def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + def get_events(self) -> tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) def register_hooks(self, registry: HookRegistry) -> None: diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index 24de958bc..f1c5cae77 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -1,5 +1,6 @@ import json -from typing import Any, AsyncGenerator, Iterable, Optional, Sequence, Type, TypedDict, TypeVar, Union +from collections.abc import AsyncGenerator, Iterable, Sequence +from typing import Any, TypedDict, TypeVar from pydantic import BaseModel @@ -25,7 +26,7 @@ class MockedModelProvider(Model): to stream mock responses as events. """ - def __init__(self, agent_responses: Sequence[Union[Message, RedactionMessage]]): + def __init__(self, agent_responses: Sequence[Message | RedactionMessage]): self.agent_responses = [*agent_responses] self.index = 0 @@ -33,7 +34,7 @@ def format_chunk(self, event: Any) -> StreamEvent: return event def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> Any: return None @@ -45,9 +46,9 @@ def update_config(self, **model_config: Any) -> None: async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, ) -> AsyncGenerator[Any, None]: pass @@ -55,9 +56,9 @@ async def structured_output( async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, - tool_choice: Optional[Any] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: Any | None = None, *, system_prompt_content=None, **kwargs: Any, @@ -68,7 +69,7 @@ async def stream( self.index += 1 - def map_agent_message_to_events(self, agent_message: Union[Message, RedactionMessage]) -> Iterable[dict[str, Any]]: + def map_agent_message_to_events(self, agent_message: Message | RedactionMessage) -> Iterable[dict[str, Any]]: stop_reason: StopReason = "end_turn" yield {"messageStart": {"role": "assistant"}} if agent_message.get("redactedAssistantContent"): diff --git a/tests/strands/agent/conftest.py b/tests/strands/agent/conftest.py new file mode 100644 index 000000000..d3af90dc8 --- /dev/null +++ b/tests/strands/agent/conftest.py @@ -0,0 +1,22 @@ +"""Fixtures for agent tests.""" + +import asyncio +from unittest.mock import AsyncMock + +import pytest + + +@pytest.fixture +def mock_sleep(monkeypatch): + """Mock asyncio.sleep to avoid delays in tests and track sleep calls.""" + sleep_calls = [] + + async def _mock_sleep(delay): + sleep_calls.append(delay) + + mock = AsyncMock(side_effect=_mock_sleep) + monkeypatch.setattr(asyncio, "sleep", mock) + + # Return both the mock and the sleep_calls list for verification + mock.sleep_calls = sleep_calls + return mock diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 7b189a5c6..1f09579b0 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -1,6 +1,6 @@ import asyncio import unittest.mock -from unittest.mock import ANY, MagicMock, call +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest from pydantic import BaseModel @@ -34,9 +34,7 @@ async def streaming_tool(): @pytest.fixture def mock_sleep(): - with unittest.mock.patch.object( - strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock - ) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock @@ -86,7 +84,7 @@ async def test_stream_e2e_success(alist): mock_callback = unittest.mock.Mock() agent = Agent(model=mock_provider, tools=[async_tool, normal_tool, streaming_tool], callback_handler=mock_callback) - stream = agent.stream_async("Do the stuff", arg1=1013) + stream = agent.stream_async("Do the stuff", invocation_state={"arg1": 1013}) tool_config = { "toolChoice": {"auto": {}}, @@ -149,6 +147,7 @@ async def test_stream_e2e_success(alist): {"toolUse": {"input": {}, "name": "normal_tool", "toolUseId": "123"}}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -207,6 +206,7 @@ async def test_stream_e2e_success(alist): {"toolUse": {"input": {}, "name": "async_tool", "toolUseId": "1234"}}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -265,6 +265,7 @@ async def test_stream_e2e_success(alist): {"toolUse": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -309,11 +310,11 @@ async def test_stream_e2e_success(alist): }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "end_turn"}}}, - {"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant"}}, + {"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant", "metadata": ANY}}, { "result": AgentResult( stop_reason="end_turn", - message={"content": [{"text": "I invoked the tools!"}], "role": "assistant"}, + message={"content": [{"text": "I invoked the tools!"}], "role": "assistant", "metadata": ANY}, metrics=ANY, state={}, ), @@ -346,7 +347,7 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): mock_callback = unittest.mock.Mock() agent = Agent(model=model, tools=[normal_tool], callback_handler=mock_callback) - stream = agent.stream_async("Do the stuff", arg1=1013) + stream = agent.stream_async("Do the stuff", invocation_state={"arg1": 1013}) # Base object with common properties throttle_props = { @@ -359,8 +360,8 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): {"arg1": 1013, "init_event_loop": True}, {"start": True}, {"start_event_loop": True}, + {"event_loop_throttled_delay": 4, **throttle_props}, {"event_loop_throttled_delay": 8, **throttle_props}, - {"event_loop_throttled_delay": 16, **throttle_props}, {"event": {"messageStart": {"role": "assistant"}}}, {"event": {"redactContent": {"redactUserContentMessage": "BLOCKED!"}}}, {"event": {"contentBlockStart": {"start": {}}}}, @@ -373,11 +374,11 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, - {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}}, + {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant", "metadata": ANY}}, { "result": AgentResult( stop_reason="guardrail_intervened", - message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}, + message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant", "metadata": ANY}, metrics=ANY, state={}, ), @@ -444,6 +445,7 @@ async def test_stream_e2e_reasoning_redacted_content(alist): {"text": "Response with redacted reasoning"}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -455,6 +457,7 @@ async def test_stream_e2e_reasoning_redacted_content(alist): {"text": "Response with redacted reasoning"}, ], "role": "assistant", + "metadata": ANY, }, metrics=ANY, state={}, @@ -494,7 +497,7 @@ async def test_event_loop_cycle_text_response_throttling_early_end( # Because we're throwing an exception, we manually collect the items here tru_events = [] - stream = agent.stream_async("Do the stuff", arg1=1013) + stream = agent.stream_async("Do the stuff", invocation_state={"arg1": 1013}) async for event in stream: tru_events.append(event) @@ -508,11 +511,11 @@ async def test_event_loop_cycle_text_response_throttling_early_end( {"init_event_loop": True, "arg1": 1013}, {"start": True}, {"start_event_loop": True}, + {"event_loop_throttled_delay": 4, **common_props}, {"event_loop_throttled_delay": 8, **common_props}, {"event_loop_throttled_delay": 16, **common_props}, {"event_loop_throttled_delay": 32, **common_props}, {"event_loop_throttled_delay": 64, **common_props}, - {"event_loop_throttled_delay": 128, **common_props}, {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"}, ] @@ -527,6 +530,7 @@ async def test_event_loop_cycle_text_response_throttling_early_end( assert typed_events == [] +@pytest.mark.filterwarnings("ignore:Agent.structured_output_async method is deprecated:DeprecationWarning") @pytest.mark.asyncio async def test_structured_output(agenerator): # we use bedrock here as it uses the tool implementation diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 8bbd89c17..6771774d3 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -2,14 +2,18 @@ import pytest +from strands.agent.agent_result import AgentResult from strands.hooks import ( AfterInvocationEvent, + AfterModelCallEvent, AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, + BeforeModelCallEvent, BeforeToolCallEvent, MessageAddedEvent, ) +from strands.types.content import Message, Messages from strands.types.tools import ToolResult, ToolUse @@ -18,6 +22,11 @@ def agent(): return Mock() +@pytest.fixture +def sample_messages() -> Messages: + return [{"role": "user", "content": [{"text": "Hello, agent!"}]}] + + @pytest.fixture def tool(): tool = Mock() @@ -50,6 +59,11 @@ def start_request_event(agent): return BeforeInvocationEvent(agent=agent) +@pytest.fixture +def start_request_event_with_messages(agent, sample_messages): + return BeforeInvocationEvent(agent=agent, messages=sample_messages) + + @pytest.fixture def messaged_added_event(agent): return MessageAddedEvent(agent=agent, message=Mock()) @@ -138,3 +152,130 @@ def test_after_tool_invocation_event_cannot_write_properties(after_tool_event): after_tool_event.invocation_state = {} with pytest.raises(AttributeError, match="Property exception is not writable"): after_tool_event.exception = Exception("test") + + +def test_after_invocation_event_properties_not_writable(agent): + """Test that properties are not writable after initialization.""" + mock_message: Message = {"role": "assistant", "content": [{"text": "test"}]} + mock_result = AgentResult( + stop_reason="end_turn", + message=mock_message, + metrics={}, + state={}, + ) + + event = AfterInvocationEvent(agent=agent, result=None) + + with pytest.raises(AttributeError, match="Property result is not writable"): + event.result = mock_result + + with pytest.raises(AttributeError, match="Property agent is not writable"): + event.agent = Mock() + + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + event.invocation_state = {} + + +def test_invocation_state_is_available_in_invocation_events(agent): + """Test that invocation_state is accessible in BeforeInvocationEvent and AfterInvocationEvent.""" + invocation_state = {"session_id": "test-123", "request_id": "req-456"} + + before_event = BeforeInvocationEvent(agent=agent, invocation_state=invocation_state) + assert before_event.invocation_state == invocation_state + assert before_event.invocation_state["session_id"] == "test-123" + assert before_event.invocation_state["request_id"] == "req-456" + + after_event = AfterInvocationEvent(agent=agent, invocation_state=invocation_state, result=None) + assert after_event.invocation_state == invocation_state + assert after_event.invocation_state["session_id"] == "test-123" + assert after_event.invocation_state["request_id"] == "req-456" + + +def test_invocation_state_is_available_in_model_call_events(agent): + """Test that invocation_state is accessible in BeforeModelCallEvent and AfterModelCallEvent.""" + invocation_state = {"session_id": "test-123", "request_id": "req-456"} + + before_event = BeforeModelCallEvent(agent=agent, invocation_state=invocation_state) + assert before_event.invocation_state == invocation_state + assert before_event.invocation_state["session_id"] == "test-123" + assert before_event.invocation_state["request_id"] == "req-456" + + after_event = AfterModelCallEvent(agent=agent, invocation_state=invocation_state) + assert after_event.invocation_state == invocation_state + assert after_event.invocation_state["session_id"] == "test-123" + assert after_event.invocation_state["request_id"] == "req-456" + + +def test_before_invocation_event_messages_default_none(agent): + """Test that BeforeInvocationEvent.messages defaults to None for backward compatibility.""" + event = BeforeInvocationEvent(agent=agent) + assert event.messages is None + + +def test_before_invocation_event_messages_writable(agent, sample_messages): + """Test that BeforeInvocationEvent.messages can be modified in-place for guardrail redaction.""" + event = BeforeInvocationEvent(agent=agent, messages=sample_messages) + + # Should be able to modify the messages list in-place + event.messages[0]["content"] = [{"text": "[REDACTED]"}] + assert event.messages[0]["content"] == [{"text": "[REDACTED]"}] + + # Should be able to reassign messages entirely + new_messages: Messages = [{"role": "user", "content": [{"text": "Different message"}]}] + event.messages = new_messages + assert event.messages == new_messages + + +def test_before_invocation_event_agent_not_writable(start_request_event_with_messages): + """Test that BeforeInvocationEvent.agent is not writable.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + start_request_event_with_messages.agent = Mock() + + +def test_after_invocation_event_resume_defaults_to_none(agent): + """Test that AfterInvocationEvent.resume defaults to None.""" + event = AfterInvocationEvent(agent=agent, result=None) + assert event.resume is None + + +def test_after_invocation_event_resume_is_writable(agent): + """Test that AfterInvocationEvent.resume can be set by hooks.""" + event = AfterInvocationEvent(agent=agent, result=None) + event.resume = "continue with this input" + assert event.resume == "continue with this input" + + +def test_after_invocation_event_resume_accepts_various_input_types(agent): + """Test that resume accepts all AgentInput types.""" + event = AfterInvocationEvent(agent=agent, result=None) + + # String input + event.resume = "hello" + assert event.resume == "hello" + + # Content block list + event.resume = [{"text": "hello"}] + assert event.resume == [{"text": "hello"}] + + # None to stop + event.resume = None + assert event.resume is None + + +def test_before_model_call_event_projected_input_tokens_default(agent): + """Test that projected_input_tokens defaults to None.""" + event = BeforeModelCallEvent(agent=agent) + assert event.projected_input_tokens is None + + +def test_before_model_call_event_projected_input_tokens_set(agent): + """Test that projected_input_tokens can be set at construction.""" + event = BeforeModelCallEvent(agent=agent, projected_input_tokens=500) + assert event.projected_input_tokens == 500 + + +def test_before_model_call_event_projected_input_tokens_not_writable(agent): + """Test that projected_input_tokens is not writable after construction.""" + event = BeforeModelCallEvent(agent=agent, projected_input_tokens=500) + with pytest.raises(AttributeError, match="Property projected_input_tokens is not writable"): + event.projected_input_tokens = 1000 diff --git a/tests/strands/agent/hooks/test_hook_registry.py b/tests/strands/agent/hooks/test_hook_registry.py index ad1415f22..12b5af42c 100644 --- a/tests/strands/agent/hooks/test_hook_registry.py +++ b/tests/strands/agent/hooks/test_hook_registry.py @@ -1,6 +1,5 @@ import unittest.mock from dataclasses import dataclass -from typing import List from unittest.mock import MagicMock, Mock import pytest @@ -139,7 +138,7 @@ async def test_invoke_callbacks_async_no_registered_callbacks(hook_registry, nor @pytest.mark.asyncio async def test_invoke_callbacks_async_after_event(hook_registry, after_event): """Test that invoke_callbacks_async calls callbacks in reverse order for after events.""" - call_order: List[str] = [] + call_order: list[str] = [] def callback1(_event): call_order.append("callback1") diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py new file mode 100644 index 000000000..9c3be7917 --- /dev/null +++ b/tests/strands/agent/test_a2a_agent.py @@ -0,0 +1,876 @@ +"""Tests for A2AAgent class.""" + +import warnings +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from a2a.client import ClientConfig +from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart + +from strands.agent.a2a_agent import A2AAgent +from strands.agent.agent_result import AgentResult + + +@pytest.fixture +def mock_agent_card(): + """Mock AgentCard for testing.""" + return AgentCard( + name="test-agent", + description="Test agent", + url="http://localhost:8000", + version="1.0.0", + capabilities={}, + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + skills=[], + ) + + +@pytest.fixture +def a2a_agent(): + """Create A2AAgent instance for testing.""" + return A2AAgent(endpoint="http://localhost:8000") + + +@pytest.fixture +def mock_httpx_client(): + """Create a mock httpx.AsyncClient that works as async context manager.""" + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + return mock_client + + +@asynccontextmanager +async def mock_a2a_client_context(send_message_func): + """Helper to create mock A2A client setup for _send_message tests.""" + mock_client = MagicMock() + mock_client.send_message = send_message_func + with patch("strands.agent.a2a_agent.httpx.AsyncClient") as mock_httpx_class: + mock_httpx = AsyncMock() + mock_httpx.__aenter__.return_value = mock_httpx + mock_httpx.__aexit__.return_value = None + mock_httpx_class.return_value = mock_httpx + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = mock_client + mock_factory_class.return_value = mock_factory + yield mock_httpx_class, mock_factory_class + + +# === Init Tests === + + +def test_init_with_defaults(): + """Test initialization with default parameters.""" + agent = A2AAgent(endpoint="http://localhost:8000") + assert agent.endpoint == "http://localhost:8000" + assert agent.timeout == 300 + assert agent._agent_card is None + assert agent.name is None + assert agent.description is None + + +def test_init_with_name_and_description(): + """Test initialization with custom name and description.""" + agent = A2AAgent(endpoint="http://localhost:8000", name="my-agent", description="My custom agent") + assert agent.name == "my-agent" + assert agent.description == "My custom agent" + + +def test_init_with_custom_timeout(): + """Test initialization with custom timeout.""" + agent = A2AAgent(endpoint="http://localhost:8000", timeout=600) + assert agent.timeout == 600 + + +def test_init_with_client_config(): + """Test initialization with client_config.""" + config = ClientConfig() + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + assert agent._client_config is config + + +def test_init_with_external_a2a_client_factory(): + """Test initialization with external A2A client factory emits deprecation warning.""" + external_factory = MagicMock() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + assert agent._a2a_client_factory is external_factory + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "a2a_client_factory is deprecated" in str(w[0].message) + assert "client_config" in str(w[0].message) + + +def test_init_with_both_client_config_and_factory_raises(): + """Test that providing both client_config and factory raises ValueError.""" + config = ClientConfig() + factory = MagicMock() + with pytest.raises(ValueError, match="Cannot provide both client_config and a2a_client_factory"): + A2AAgent(endpoint="http://localhost:8000", client_config=config, a2a_client_factory=factory) + + +def test_init_no_asyncio_lock(): + """Test that A2AAgent does not create an asyncio.Lock in __init__.""" + agent = A2AAgent(endpoint="http://localhost:8000") + assert not hasattr(agent, "_card_lock") + + +# === Card Resolution Tests === + + +@pytest.mark.asyncio +async def test_get_agent_card(a2a_agent, mock_agent_card, mock_httpx_client): + """Test agent card discovery.""" + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + card = await a2a_agent.get_agent_card() + + assert card == mock_agent_card + assert a2a_agent._agent_card == mock_agent_card + + +@pytest.mark.asyncio +async def test_get_agent_card_cached(a2a_agent, mock_agent_card): + """Test that agent card is cached after first discovery.""" + a2a_agent._agent_card = mock_agent_card + + card = await a2a_agent.get_agent_card() + + assert card == mock_agent_card + + +@pytest.mark.asyncio +async def test_get_agent_card_populates_name_and_description(mock_agent_card, mock_httpx_client): + """Test that agent card populates name and description if not set.""" + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + assert agent.name == mock_agent_card.name + assert agent.description == mock_agent_card.description + + +@pytest.mark.asyncio +async def test_get_agent_card_preserves_custom_name_and_description(mock_agent_card, mock_httpx_client): + """Test that custom name and description are not overridden by agent card.""" + agent = A2AAgent(endpoint="http://localhost:8000", name="custom-name", description="Custom description") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + assert agent.name == "custom-name" + assert agent.description == "Custom description" + + +@pytest.mark.asyncio +async def test_get_agent_card_handles_empty_string_name_and_description(mock_httpx_client): + """Test that empty string name/description from card are preserved (not treated as None).""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "" + mock_card.description = "" + + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Empty strings should be set (not treated as falsy/None) + assert agent.name == "" + assert agent.description == "" + + +@pytest.mark.asyncio +async def test_get_agent_card_with_client_config_uses_auth_client(): + """Test that client_config's httpx_client is used for card resolution (fixes auth bug).""" + mock_auth_client = MagicMock() + config = ClientConfig(httpx_client=mock_auth_client) + + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + resolver_httpx_client = None + + def track_resolver_init(*, httpx_client, base_url): + nonlocal resolver_httpx_client + resolver_httpx_client = httpx_client + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + return mock_resolver + + with patch("strands.agent.a2a_agent.A2ACardResolver", side_effect=track_resolver_init): + await agent.get_agent_card() + + # CRITICAL: Verify the authenticated client was used for card resolution + assert resolver_httpx_client is mock_auth_client, ( + "Bug not fixed: authenticated httpx client was not used for card resolution" + ) + + +@pytest.mark.asyncio +async def test_get_agent_card_without_client_config_uses_default_httpx(mock_httpx_client): + """Test that card resolution uses bare httpx when no client_config is provided.""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client) as mock_httpx_class: + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Should use bare httpx with timeout + mock_httpx_class.assert_called_once_with(timeout=300) + + +@pytest.mark.asyncio +async def test_get_agent_card_factory_only_uses_default_httpx(mock_httpx_client): + """Test that deprecated factory without client_config still uses bare httpx for card resolution.""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + mock_factory = MagicMock() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=mock_factory) + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client) as mock_httpx_class: + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Factory alone does NOT provide auth for card resolution — uses bare httpx + mock_httpx_class.assert_called_once_with(timeout=300) + + +@pytest.mark.asyncio +async def test_get_agent_card_client_config_without_httpx_uses_default(mock_httpx_client): + """Test that client_config without httpx_client falls through to managed httpx (same as no config).""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + config = ClientConfig(polling=True) # No httpx_client + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client) as mock_httpx_class: + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Should use managed httpx with timeout (same as no config path) + mock_httpx_class.assert_called_once_with(timeout=300) + + +# === Client Creation Tests === + + +@pytest.mark.asyncio +async def test_get_a2a_client_with_client_config_preserves_user_settings(mock_agent_card): + """Test that _get_a2a_client preserves all user ClientConfig settings via dataclasses.replace.""" + mock_auth_client = MagicMock() + config = ClientConfig( + httpx_client=mock_auth_client, + streaming=False, # user set this to False + polling=True, + supported_transports=["jsonrpc"], + ) + + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Verify factory was created with a config that preserves user settings + mock_factory_class.assert_called_once() + created_config = mock_factory_class.call_args[0][0] + assert created_config.httpx_client is mock_auth_client + assert created_config.streaming is True # overridden to True + assert created_config.polling is True # preserved + assert created_config.supported_transports == ["jsonrpc"] # preserved + + +@pytest.mark.asyncio +async def test_get_a2a_client_with_client_config_does_not_mutate_original(mock_agent_card): + """Test that _get_a2a_client does not mutate the original client_config.""" + config = ClientConfig(streaming=False) + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Original config should NOT be mutated + assert config.streaming is False + + +@pytest.mark.asyncio +async def test_get_a2a_client_config_without_httpx_delegates_to_factory(mock_agent_card): + """Test that _get_a2a_client delegates to ClientFactory when config has no httpx_client. + + ClientFactory handles creating a default httpx client internally. We just pass + the config with streaming=True and let the factory do its job. + """ + config = ClientConfig(polling=True, supported_transports=["jsonrpc"]) + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config, timeout=600) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Should pass config directly to ClientFactory — factory handles httpx defaults + created_config = mock_factory_class.call_args[0][0] + assert created_config.streaming is True + assert created_config.polling is True + assert created_config.supported_transports == ["jsonrpc"] + assert created_config.httpx_client is None # factory handles default + + +@pytest.mark.asyncio +async def test_send_message_uses_provided_factory(mock_agent_card): + """Test _send_message uses provided factory instead of creating per-call client.""" + external_factory = MagicMock() + mock_a2a_client = MagicMock() + + async def mock_send_message(*args, **kwargs): + yield MagicMock() + + mock_a2a_client.send_message = mock_send_message + external_factory.create.return_value = mock_a2a_client + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + # Consume the async iterator + async for _ in agent._send_message("Hello"): + pass + + external_factory.create.assert_called_once_with(mock_agent_card) + + +@pytest.mark.asyncio +async def test_send_message_uses_client_config_httpx_client(mock_agent_card): + """Test _send_message uses client_config's httpx_client for client creation.""" + mock_auth_client = MagicMock() + config = ClientConfig(httpx_client=mock_auth_client) + + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + mock_a2a_client = MagicMock() + + async def mock_send(*args, **kwargs): + yield MagicMock() + + mock_a2a_client.send_message = mock_send + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = mock_a2a_client + mock_factory_class.return_value = mock_factory + + async for _ in agent._send_message("Hello"): + pass + + # Verify ClientFactory was created with config containing the auth client + mock_factory_class.assert_called_once() + call_args = mock_factory_class.call_args + created_config = call_args[0][0] + assert created_config.httpx_client is mock_auth_client + + +@pytest.mark.asyncio +async def test_send_message_creates_per_call_client(a2a_agent, mock_agent_card): + """Test _send_message creates a fresh httpx client for each call when no factory provided.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message) as (mock_httpx_class, _): + # Consume the async iterator + async for _ in a2a_agent._send_message("Hello"): + pass + + # Verify httpx client was created with timeout + mock_httpx_class.assert_called_once_with(timeout=300) + + +@pytest.mark.asyncio +async def test_get_a2a_client_no_config_creates_managed_httpx(): + """Test that _get_a2a_client creates a managed httpx client when no config provided.""" + mock_card = MagicMock(spec=AgentCard) + agent = A2AAgent(endpoint="http://localhost:8000", timeout=600) + + with patch.object(agent, "get_agent_card", return_value=mock_card): + with patch("strands.agent.a2a_agent.httpx.AsyncClient") as mock_httpx_class: + mock_httpx = AsyncMock() + mock_httpx.__aenter__.return_value = mock_httpx + mock_httpx.__aexit__.return_value = None + mock_httpx_class.return_value = mock_httpx + + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Verify httpx client was created with agent timeout + mock_httpx_class.assert_called_once_with(timeout=600) + # Verify ClientFactory was called with streaming=True + created_config = mock_factory_class.call_args[0][0] + assert created_config.streaming is True + + +# === Invoke/Stream Tests === + + +@pytest.mark.asyncio +async def test_invoke_async_success(a2a_agent, mock_agent_card): + """Test successful async invocation.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + result = await a2a_agent.invoke_async("Hello") + + assert isinstance(result, AgentResult) + assert result.message["content"][0]["text"] == "Response" + + +@pytest.mark.asyncio +async def test_invoke_async_no_prompt(a2a_agent): + """Test that invoke_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + await a2a_agent.invoke_async(None) + + +@pytest.mark.asyncio +async def test_invoke_async_no_response(a2a_agent, mock_agent_card): + """Test that invoke_async raises RuntimeError when no response received.""" + + async def mock_send_message(*args, **kwargs): + return + yield # Make it an async generator + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + with pytest.raises(RuntimeError, match="No response received"): + await a2a_agent.invoke_async("Hello") + + +def test_call_sync(a2a_agent): + """Test synchronous call method.""" + mock_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=MagicMock(), + state={}, + ) + + with patch("strands.agent.a2a_agent.run_async") as mock_run_async: + mock_run_async.return_value = mock_result + + result = a2a_agent("Hello") + + assert result == mock_result + mock_run_async.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_async_success(a2a_agent, mock_agent_card): + """Test successful async streaming.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + assert len(events) == 2 + # First event is A2A stream event + assert events[0]["type"] == "a2a_stream" + assert events[0]["event"] == mock_response + # Final event is AgentResult + assert "result" in events[1] + assert isinstance(events[1]["result"], AgentResult) + assert events[1]["result"].message["content"][0]["text"] == "Response" + + +@pytest.mark.asyncio +async def test_stream_async_no_prompt(a2a_agent): + """Test that stream_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + async for _ in a2a_agent.stream_async(None): + pass + + +# === Complete Event Tests === + + +def test_is_complete_event_message(a2a_agent): + """Test _is_complete_event returns True for Message.""" + mock_message = MagicMock(spec=Message) + + assert a2a_agent._is_complete_event(mock_message) is True + + +def test_is_complete_event_tuple_with_none_update(a2a_agent): + """Test _is_complete_event returns True for tuple with None update event.""" + mock_task = MagicMock() + + assert a2a_agent._is_complete_event((mock_task, None)) is True + + +def test_is_complete_event_artifact_last_chunk(a2a_agent): + """Test _is_complete_event handles TaskArtifactUpdateEvent last_chunk flag.""" + from a2a.types import TaskArtifactUpdateEvent + + mock_task = MagicMock() + + # last_chunk=True -> complete + event_complete = MagicMock(spec=TaskArtifactUpdateEvent) + event_complete.last_chunk = True + assert a2a_agent._is_complete_event((mock_task, event_complete)) is True + + # last_chunk=False -> not complete + event_incomplete = MagicMock(spec=TaskArtifactUpdateEvent) + event_incomplete.last_chunk = False + assert a2a_agent._is_complete_event((mock_task, event_incomplete)) is False + + # last_chunk=None -> not complete + event_none = MagicMock(spec=TaskArtifactUpdateEvent) + event_none.last_chunk = None + assert a2a_agent._is_complete_event((mock_task, event_none)) is False + + +def test_is_complete_event_status_update(a2a_agent): + """Test _is_complete_event handles TaskStatusUpdateEvent state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + mock_task = MagicMock() + + # completed state -> complete + event_completed = MagicMock(spec=TaskStatusUpdateEvent) + event_completed.status = MagicMock() + event_completed.status.state = TaskState.completed + assert a2a_agent._is_complete_event((mock_task, event_completed)) is True + + # working state -> not complete + event_working = MagicMock(spec=TaskStatusUpdateEvent) + event_working.status = MagicMock() + event_working.status.state = TaskState.working + assert a2a_agent._is_complete_event((mock_task, event_working)) is False + + # no status -> not complete + event_no_status = MagicMock(spec=TaskStatusUpdateEvent) + event_no_status.status = None + assert a2a_agent._is_complete_event((mock_task, event_no_status)) is False + + +def test_is_complete_event_unknown_type(a2a_agent): + """Test _is_complete_event returns False for unknown event types.""" + assert a2a_agent._is_complete_event("unknown") is False + + +@pytest.mark.asyncio +async def test_stream_async_tracks_complete_events(a2a_agent, mock_agent_card): + """Test stream_async uses last complete event for final result.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + mock_task = MagicMock() + mock_task.artifacts = None + + # First event: incomplete + incomplete_event = MagicMock(spec=TaskStatusUpdateEvent) + incomplete_event.status = MagicMock() + incomplete_event.status.state = TaskState.working + incomplete_event.status.message = None + + # Second event: complete + complete_event = MagicMock(spec=TaskStatusUpdateEvent) + complete_event.status = MagicMock() + complete_event.status.state = TaskState.completed + complete_event.status.message = MagicMock() + complete_event.status.message.parts = [] + + async def mock_send_message(*args, **kwargs): + yield (mock_task, incomplete_event) + yield (mock_task, complete_event) + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + # Should have 2 stream events + 1 result event + assert len(events) == 3 + assert "result" in events[2] + + +@pytest.mark.asyncio +async def test_stream_async_falls_back_to_last_event(a2a_agent, mock_agent_card): + """Test stream_async falls back to last event when no complete event.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + mock_task = MagicMock() + mock_task.artifacts = None + + incomplete_event = MagicMock(spec=TaskStatusUpdateEvent) + incomplete_event.status = MagicMock() + incomplete_event.status.state = TaskState.working + incomplete_event.status.message = None + + async def mock_send_message(*args, **kwargs): + yield (mock_task, incomplete_event) + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + # Should have 1 stream event + 1 result event (falls back to last) + assert len(events) == 2 + assert "result" in events[1] + + +# ========================================================================= +# NEW TESTS: Client-side lifecycle state handling +# ========================================================================= + + +def test_is_complete_event_failed_state(a2a_agent): + """Test that failed state is recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.failed + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_canceled_state(a2a_agent): + """Test that canceled state is recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.canceled + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_rejected_state(a2a_agent): + """Test that rejected state is recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.rejected + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_input_required_state(a2a_agent): + """Test that input_required state is recognized as complete (pausing).""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.input_required + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_auth_required_state(a2a_agent): + """Test that auth_required state is recognized as complete (pausing).""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.auth_required + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_working_state_not_complete(a2a_agent): + """Test that working state is NOT recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.working + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is False + + +def test_is_complete_event_submitted_state_not_complete(a2a_agent): + """Test that submitted state is NOT recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.submitted + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is False + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +@pytest.mark.parametrize( + "state,expected_complete", + [ + (TaskState.completed, True), + (TaskState.failed, True), + (TaskState.canceled, True), + (TaskState.rejected, True), + (TaskState.input_required, True), + (TaskState.auth_required, True), + (TaskState.working, False), + (TaskState.submitted, False), + (TaskState.unknown, False), + ], + ids=[ + "completed-is-complete", + "failed-is-complete", + "canceled-is-complete", + "rejected-is-complete", + "input_required-is-complete", + "auth_required-is-complete", + "working-not-complete", + "submitted-not-complete", + "unknown-not-complete", + ], +) +def test_is_complete_event_all_states_parametrized(a2a_agent, state, expected_complete): + """Minor Finding 7: Parametrized test covering ALL TaskState values. + + This replaces verbose individual tests with a single parameterized test that + covers all 9 TaskState values. When a2a-sdk adds new states, adding a row here + is trivial. + """ + from unittest.mock import MagicMock + + from a2a.types import TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = state + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is expected_complete diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index f133400a8..680a1d23c 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -3,40 +3,49 @@ import json import os import textwrap +import threading import unittest.mock import warnings +from collections.abc import AsyncGenerator +from typing import Any from uuid import uuid4 import pytest from pydantic import BaseModel import strands -from strands import Agent +from strands import Agent, Plugin, ToolContext from strands.agent import AgentResult +from strands.agent._agent_as_tool import _AgentAsTool from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler -from strands.hooks import BeforeToolCallEvent +from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, BeforeToolCallEvent from strands.interrupt import Interrupt from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize from strands.types._events import EventLoopStopEvent, ModelStreamEvent +from strands.types.agent import ConcurrentInvocationMode from strands.types.content import Messages -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException +from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider # For unit testing we will use the the us inference -FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") +FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID @pytest.fixture def mock_model(request): async def stream(*args, **kwargs): - result = mock.mock_stream(*copy.deepcopy(args), **copy.deepcopy(kwargs)) + # Skip deep copy of invocation_state which contains non-serializable objects (agent, spans, etc.) + copied_kwargs = { + key: value if key == "invocation_state" else copy.deepcopy(value) for key, value in kwargs.items() + } + result = mock.mock_stream(*copy.deepcopy(args), **copied_kwargs) # If result is already an async generator, yield from it if hasattr(result, "__aiter__"): async for item in result: @@ -49,6 +58,7 @@ async def stream(*args, **kwargs): mock = unittest.mock.Mock(spec=getattr(request, "param", None)) mock.configure_mock(mock_stream=unittest.mock.MagicMock()) mock.stream.side_effect = stream + mock.stateful = False return mock @@ -185,6 +195,29 @@ class User(BaseModel): return User(name="Jane Doe", age=30, email="jane@doe.com") +class SyncEventMockedModel(MockedModelProvider): + """A mock model that uses events to synchronize concurrent threads. + + This model signals when it starts streaming and waits for a proceed signal, + allowing deterministic testing of concurrent behavior without relying on sleeps. + """ + + def __init__(self, agent_responses): + super().__init__(agent_responses) + self.started_event = threading.Event() + self.proceed_event = threading.Event() + + async def stream( + self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs + ) -> AsyncGenerator[Any, None]: + # Signal that streaming has started + self.started_event.set() + # Wait for signal to proceed + self.proceed_event.wait() + async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs): + yield event + + def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry): _ = tool_registry @@ -303,7 +336,7 @@ def test_agent__call__( "stop_reason": result.stop_reason, } exp_result = { - "message": {"content": [{"text": "test text"}], "role": "assistant"}, + "message": {"content": [{"text": "test text"}], "role": "assistant", "metadata": unittest.mock.ANY}, "state": {}, "stop_reason": "end_turn", } @@ -325,6 +358,8 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ), unittest.mock.call( [ @@ -363,6 +398,8 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ), ], ) @@ -484,6 +521,8 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -587,7 +626,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene }, }, }, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "' + "X" * 500 + '"}'}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use"}}, ] @@ -601,12 +640,14 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene agent("test message") + large_input = "X" * 500 + truncated_text = large_input[:200] + "...\n\n... [truncated: 100 chars removed] ...\n\n..." + large_input[-200:] expected_messages = [ {"role": "user", "content": [{"text": "test message"}]}, { "role": "assistant", "content": [ - {"toolUse": {"toolUseId": "t1", "name": "tool_decorated", "input": {"random_string": "abcdEfghI123"}}} + {"toolUse": {"toolUseId": "t1", "name": "tool_decorated", "input": {"random_string": large_input}}} ], }, { @@ -615,8 +656,8 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene { "toolResult": { "toolUseId": "t1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], + "status": "success", + "content": [{"text": truncated_text}], } } ], @@ -629,6 +670,8 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ) assert conversation_manager_spy.reduce_context.call_count == 2 @@ -738,6 +781,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, {"text": "value"}, ], + "metadata": unittest.mock.ANY, }, ), unittest.mock.call( @@ -750,6 +794,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, {"text": "value"}, ], + "metadata": unittest.mock.ANY, }, metrics=unittest.mock.ANY, state={}, @@ -774,7 +819,7 @@ async def test_agent__call__in_async_context(mock_model, agent, agenerator): result = agent("test") tru_message = result.message - exp_message = {"content": [{"text": "abc"}], "role": "assistant"} + exp_message = {"content": [{"text": "abc"}], "role": "assistant", "metadata": unittest.mock.ANY} assert tru_message == exp_message @@ -794,7 +839,7 @@ async def test_agent_invoke_async(mock_model, agent, agenerator): result = await agent.invoke_async("test") tru_message = result.message - exp_message = {"content": [{"text": "abc"}], "role": "assistant"} + exp_message = {"content": [{"text": "abc"}], "role": "assistant", "metadata": unittest.mock.ANY} assert tru_message == exp_message @@ -1085,7 +1130,7 @@ async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, ali tru_message = agent.messages exp_message = [ {"content": prompt, "role": "user"}, - {"content": [{"text": "I see text and an image"}], "role": "assistant"}, + {"content": [{"text": "I see text and an image"}], "role": "assistant", "metadata": unittest.mock.ANY}, ] assert tru_message == exp_message @@ -1121,6 +1166,33 @@ def test_system_prompt_setter_none(): assert agent._system_prompt_content is None +def test_system_prompt_content_string(): + """Test that system_prompt_content returns content blocks for string prompt.""" + agent = Agent(system_prompt="hello") + assert agent.system_prompt_content == [{"text": "hello"}] + + +def test_system_prompt_content_structured(): + """Test that system_prompt_content returns structured blocks with cache points.""" + blocks = [{"text": "You are helpful"}, {"cachePoint": {"type": "default"}}] + agent = Agent(system_prompt=blocks) + assert agent.system_prompt_content == blocks + + +def test_system_prompt_content_none(): + """Test that system_prompt_content returns None when no prompt is set.""" + agent = Agent(system_prompt=None) + assert agent.system_prompt_content is None + + +def test_system_prompt_content_returns_copy(): + """Test that system_prompt_content returns a defensive copy.""" + agent = Agent(system_prompt="hello") + content = agent.system_prompt_content + content.append({"text": "injected"}) + assert agent.system_prompt_content == [{"text": "hello"}] + + @pytest.mark.asyncio async def test_stream_async_passes_invocation_state(agent, mock_model, mock_event_loop_cycle, agenerator, alist): mock_model.mock_stream.side_effect = [ @@ -1269,7 +1341,9 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model @pytest.mark.asyncio @unittest.mock.patch("strands.agent.agent.get_tracer") -async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle, alist): +async def test_agent_stream_async_creates_and_ends_span_on_success( + mock_get_tracer, mock_event_loop_cycle, mock_model, alist +): """Test that stream_async creates and ends a span when the call succeeds.""" # Setup mock tracer and span mock_tracer = unittest.mock.MagicMock() @@ -1570,10 +1644,15 @@ def test_agent_restored_from_session_management_with_correct_index(): def test_agent_with_session_and_conversation_manager(): - mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) + mock_model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "first"}]}, + {"role": "assistant", "content": [{"text": "second"}]}, + ] + ) mock_session_repository = MockedSessionRepository() session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) - conversation_manager = SlidingWindowConversationManager(window_size=1) + conversation_manager = SlidingWindowConversationManager(window_size=2) # Create an agent with a mocked model and session repository agent = Agent( session_manager=session_manager, @@ -1588,14 +1667,20 @@ def test_agent_with_session_and_conversation_manager(): agent("Hello!") - # After invoking, assert that the messages were persisted + # After first invocation: [user, assistant] — fits in window, no trimming assert len(mock_session_repository.list_messages("123", agent.agent_id)) == 2 - # Assert conversation manager reduced the messages - assert len(agent.messages) == 1 + assert len(agent.messages) == 2 + + agent("Second question") + + # After second invocation: [user, assistant, user, assistant] exceeds window_size=2 + # Conversation manager trims to 2 messages starting with a user message + assert len(agent.messages) == 2 + assert agent.messages[0]["role"] == "user" # Initialize another agent using the same session session_manager_2 = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) - conversation_manager_2 = SlidingWindowConversationManager(window_size=1) + conversation_manager_2 = SlidingWindowConversationManager(window_size=2) agent_2 = Agent( session_manager=session_manager_2, conversation_manager=conversation_manager_2, @@ -1603,7 +1688,7 @@ def test_agent_with_session_and_conversation_manager(): ) # Assert that the second agent was initialized properly, and that the messages of both agents are equal assert agent.messages == agent_2.messages - # Asser the conversation manager was initialized properly + # Assert the conversation manager was initialized properly assert agent.conversation_manager.removed_message_count == agent_2.conversation_manager.removed_message_count @@ -1910,7 +1995,11 @@ def shell(command: str): } # And that it continued to the LLM call - assert agent.messages[-1] == {"content": [{"text": "I invoked a tool!"}], "role": "assistant"} + assert agent.messages[-1] == { + "content": [{"text": "I invoked a tool!"}], + "role": "assistant", + "metadata": unittest.mock.ANY, + } def test_agent_string_system_prompt(): @@ -2182,3 +2271,532 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): # Should not have added any toolResult messages # Only the new user message and assistant response should be added assert len(agent.messages) == original_length + 2 + + +# ============================================================================ +# Concurrency Exception Tests +# ============================================================================ + + +def test_agent_concurrent_call_raises_exception(): + """Test that concurrent __call__() calls raise ConcurrencyException.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + + def invoke(): + try: + result = agent("test") + results.append(result) + except ConcurrencyException as e: + errors.append(e) + + # Start first thread and wait for it to begin streaming + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) + t2.start() + + # Give second thread time to attempt invocation and fail + t2.join(timeout=1.0) + + # Now let first thread complete + model.proceed_event.set() + t1.join() + t2.join() + + # One should succeed, one should raise ConcurrencyException + assert len(results) == 1, f"Expected 1 success, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() + + +def test_agent_concurrent_structured_output_raises_exception(): + """Test that concurrent structured_output() calls raise ConcurrencyException. + + Note: This test validates that the sync invocation path is protected. + The concurrent __call__() test already validates the core functionality. + """ + # Events for synchronization + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + ], + ) + agent = Agent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Start first thread and wait for it to begin streaming + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) + t2.start() + + # Give second thread time to attempt invocation and fail + t2.join(timeout=1.0) + + # Now let first thread complete + model.proceed_event.set() + t1.join() + t2.join() + + # One should succeed, one should raise ConcurrencyException + assert len(results) == 1, f"Expected 1 success, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() + + +def test_agent_concurrent_call_succeeds_with_unsafe_reentrant_mode(): + """Test that concurrent __call__() calls succeed when concurrent_invocation_mode is 'unsafe_reentrant'.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Start first thread and wait for it to begin streaming + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) + t2.start() + + # Let both threads proceed + model.proceed_event.set() + t1.join() + t2.join() + + # Both should succeed, no ConcurrencyException raised + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 2, f"Expected 2 successes, got {len(results)}" + + +def test_agent_concurrent_invocation_mode_default_is_throw(): + """Test that the default concurrent_invocation_mode is 'throw'.""" + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) + agent = Agent(model=model) + + # Verify the default mode + assert agent._concurrent_invocation_mode == "throw" + + +def test_agent_concurrent_invocation_mode_stores_value(): + """Test that concurrent_invocation_mode is stored correctly as instance variable.""" + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) + + agent_throw = Agent(model=model, concurrent_invocation_mode="throw") + assert agent_throw._concurrent_invocation_mode == "throw" + + agent_reentrant = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant") + assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant" + + +def test_agent_concurrent_invocation_mode_accepts_enum(): + """Test that concurrent_invocation_mode accepts enum values as well as strings.""" + + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) + + # Using enum values + agent_throw = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.THROW) + assert agent_throw._concurrent_invocation_mode == "throw" + assert agent_throw._concurrent_invocation_mode == ConcurrentInvocationMode.THROW + + agent_reentrant = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.UNSAFE_REENTRANT) + assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant" + assert agent_reentrant._concurrent_invocation_mode == ConcurrentInvocationMode.UNSAFE_REENTRANT + + +@pytest.mark.asyncio +async def test_agent_sequential_invocations_work(): + """Test that sequential invocations work correctly after lock is released.""" + model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + {"role": "assistant", "content": [{"text": "response3"}]}, + ] + ) + agent = Agent(model=model) + + # All sequential calls should succeed + result1 = await agent.invoke_async("test1") + assert result1.message["content"][0]["text"] == "response1" + + result2 = await agent.invoke_async("test2") + assert result2.message["content"][0]["text"] == "response2" + + result3 = await agent.invoke_async("test3") + assert result3.message["content"][0]["text"] == "response3" + + +@pytest.mark.asyncio +async def test_agent_lock_released_on_exception(): + """Test that lock is released when an exception occurs during invocation.""" + + # Create a mock model that raises an explicit error + mock_model = unittest.mock.Mock() + + async def failing_stream(*args, **kwargs): + raise RuntimeError("Simulated model failure") + yield # Make this an async generator + + mock_model.stream = failing_stream + + agent = Agent(model=mock_model) + + # First call will fail due to the simulated error + with pytest.raises(RuntimeError, match="Simulated model failure"): + await agent.invoke_async("test") + + # Lock should be released, so this should not raise ConcurrencyException + # It will still raise RuntimeError, but that's expected + with pytest.raises(RuntimeError, match="Simulated model failure"): + await agent.invoke_async("test") + + +def test_agent_direct_tool_call_during_invocation_raises_exception(tool_decorated): + """Test that direct tool call during agent invocation raises ConcurrencyException.""" + + tool_calls = [] + + @strands.tool + def tool_to_invoke(): + tool_calls.append("tool_to_invoke") + return "called" + + @strands.tool(context=True) + def agent_tool(tool_context: ToolContext) -> str: + tool_context.agent.tool.tool_to_invoke(record_direct_tool_call=True) + return "tool result" + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "test-123", + "name": "agent_tool", + "input": {}, + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Done"}]}, + ] + ) + agent = Agent(model=model, tools=[agent_tool, tool_to_invoke]) + agent("Hi") + + # Tool call should have not succeeded + assert len(tool_calls) == 0 + + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [ + { + "text": "Error: ConcurrencyException - Direct tool call cannot be made while the agent is " + "in the middle of an invocation. Set record_direct_tool_call=False to allow direct tool " + "calls during agent invocation." + } + ], + "status": "error", + "toolUseId": "test-123", + } + } + ], + "role": "user", + } + + +def test_agent_direct_tool_call_during_invocation_succeeds_with_record_false(tool_decorated): + """Test that direct tool call during agent invocation succeeds when record_direct_tool_call=False.""" + tool_calls = [] + + @strands.tool + def tool_to_invoke(): + tool_calls.append("tool_to_invoke") + return "called" + + @strands.tool(context=True) + def agent_tool(tool_context: ToolContext) -> str: + tool_context.agent.tool.tool_to_invoke(record_direct_tool_call=False) + return "tool result" + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "test-123", + "name": "agent_tool", + "input": {}, + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Done"}]}, + ] + ) + agent = Agent(model=model, tools=[agent_tool, tool_to_invoke]) + agent("Hi") + + # Tool call should have succeeded + assert len(tool_calls) == 1 + + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [{"text": "tool result"}], + "status": "success", + "toolUseId": "test-123", + } + } + ], + "role": "user", + } + + +def test_agent_add_hook_registers_callback(): + """Test that add_hook registers a callback with the hooks registry.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback = unittest.mock.Mock() + + agent.add_hook(callback, BeforeModelCallEvent) + + # Verify callback was registered by checking it gets invoked + agent("test prompt") + callback.assert_called_once() + # Verify it was called with the correct event type + call_args = callback.call_args[0] + assert isinstance(call_args[0], BeforeModelCallEvent) + + +def test_agent_add_hook_delegates_to_hooks_add_callback(): + """Test that add_hook delegates to self.hooks.add_callback.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback = unittest.mock.Mock() + + # Spy on the hooks.add_callback method + with unittest.mock.patch.object(agent.hooks, "add_callback") as mock_add_callback: + agent.add_hook(callback, BeforeInvocationEvent) + mock_add_callback.assert_called_once_with(BeforeInvocationEvent, callback) + + +@pytest.mark.asyncio +async def test_agent_add_hook_works_with_async_callback(): + """Test that add_hook works with async callbacks.""" + + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + async_callback = unittest.mock.AsyncMock() + + agent.add_hook(async_callback, BeforeModelCallEvent) + + # Use stream_async to invoke the agent with async support + _ = [event async for event in agent.stream_async("test prompt")] + async_callback.assert_called_once() + # Verify it was called with the correct event type + call_args = async_callback.call_args[0] + assert isinstance(call_args[0], BeforeModelCallEvent) + + +def test_agent_add_hook_infers_event_type_from_callback(): + """Test that add_hook infers event type from callback type hint.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback_invoked = [] + + def typed_callback(event: BeforeModelCallEvent) -> None: + callback_invoked.append(event) + + agent.add_hook(typed_callback) + agent("test prompt") + + assert len(callback_invoked) == 1 + assert isinstance(callback_invoked[0], BeforeModelCallEvent) + + +def test_agent_add_hook_raises_error_when_no_type_hint(): + """Test that add_hook raises error when event type cannot be inferred.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + + def untyped_callback(event): + pass + + with pytest.raises(ValueError, match="cannot infer event type"): + agent.add_hook(untyped_callback) + + +def test_agent_plugins_sync_initialization(): + """Test that plugins with sync init_agent are initialized correctly.""" + plugin_mock = unittest.mock.Mock() + plugin_mock.name = "test-plugin" + plugin_mock.hooks = [] + plugin_mock.tools = [] + plugin_mock.init_agent = unittest.mock.Mock() + + agent = Agent( + model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), + plugins=[plugin_mock], + ) + + plugin_mock.init_agent.assert_called_once_with(agent) + + +def test_agent_plugins_async_initialization(): + """Test that plugins with async init_agent are initialized correctly.""" + plugin_mock = unittest.mock.Mock() + plugin_mock.name = "async-plugin" + plugin_mock.hooks = [] + plugin_mock.tools = [] + plugin_mock.init_agent = unittest.mock.AsyncMock() + + agent = Agent( + model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), + plugins=[plugin_mock], + ) + + plugin_mock.init_agent.assert_called_once_with(agent) + + +def test_agent_plugins_multiple_in_order(): + """Test that multiple plugins are initialized in order.""" + call_order = [] + + plugin1 = unittest.mock.Mock() + plugin1.name = "plugin1" + plugin1.hooks = [] + plugin1.tools = [] + plugin1.init_agent = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin1")) + + plugin2 = unittest.mock.Mock() + plugin2.name = "plugin2" + plugin2.hooks = [] + plugin2.tools = [] + plugin2.init_agent = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin2")) + + Agent( + model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), + plugins=[plugin1, plugin2], + ) + + assert call_order == ["plugin1", "plugin2"] + + +def test_agent_plugins_can_register_hooks(): + """Test that plugins can register hooks during initialization.""" + hook_called = [] + + class TestPlugin(Plugin): + name = "hook-plugin" + + def init_agent(self, agent): + def hook_callback(event: BeforeModelCallEvent): + hook_called.append(True) + + agent.add_hook(hook_callback) + + agent = Agent( + model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), + plugins=[TestPlugin()], + ) + + agent("test") + assert len(hook_called) == 1 + + +def test_as_tool_returns_agent_tool(): + """Test that as_tool returns an _AgentAsTool wrapping the agent.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert isinstance(tool, _AgentAsTool) + assert tool.agent is agent + + +def test_as_tool_defaults_name_from_agent(): + """Test that as_tool defaults the tool name to the agent's name.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_name == "researcher" + + +def test_as_tool_defaults_description_from_agent(): + """Test that as_tool defaults the description to the agent's description.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Finds information" + + +def test_as_tool_custom_name(): + """Test that as_tool accepts a custom name.""" + agent = Agent(name="researcher") + tool = agent.as_tool(name="custom_name") + + assert tool.tool_name == "custom_name" + + +def test_as_tool_custom_description(): + """Test that as_tool accepts a custom description.""" + agent = Agent(name="researcher", description="Original") + tool = agent.as_tool(description="Custom description") + + assert tool.tool_spec["description"] == "Custom description" + + +def test_as_tool_defaults_description_when_agent_has_none(): + """Test that as_tool generates a default description when agent has none.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py new file mode 100644 index 000000000..5a8399830 --- /dev/null +++ b/tests/strands/agent/test_agent_as_tool.py @@ -0,0 +1,722 @@ +"""Tests for _AgentAsTool - the agent-as-tool adapter.""" + +from unittest.mock import MagicMock + +import pytest + +from strands.agent._agent_as_tool import _AgentAsTool +from strands.agent.agent_result import AgentResult +from strands.interrupt import Interrupt, _InterruptState +from strands.telemetry.metrics import EventLoopMetrics +from strands.types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent + + +async def _mock_stream_async(result, intermediate_events=None): + """Helper that yields intermediate events then the final result event.""" + for event in intermediate_events or []: + yield event + yield {"result": result} + + +@pytest.fixture +def mock_agent(): + agent = MagicMock() + agent.name = "test_agent" + agent.description = "A test agent" + agent._interrupt_state = _InterruptState() + return agent + + +@pytest.fixture +def fake_agent(): + """A real Agent instance for tests that need Agent-specific features.""" + from strands.agent.agent import Agent + + return Agent(name="fake_agent", callback_handler=None) + + +@pytest.fixture +def tool(mock_agent): + return _AgentAsTool(mock_agent, name="test_agent", description="A test agent", preserve_context=True) + + +@pytest.fixture +def tool_use(): + return { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {"input": "hello"}, + } + + +@pytest.fixture +def agent_result(): + return AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "response text"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + + +# --- init --- + + +def test_init(mock_agent): + tool = _AgentAsTool(mock_agent, name="my_tool", description="custom desc", preserve_context=True) + assert tool.tool_name == "my_tool" + assert tool._description == "custom desc" + assert tool.agent is mock_agent + + +def test_init_description_defaults_to_agent_description(fake_agent): + fake_agent.description = "Agent that researches topics" + tool = _AgentAsTool(fake_agent, name="researcher", preserve_context=True) + assert tool._description == "Agent that researches topics" + + +def test_init_description_defaults_to_generic_when_agent_has_none(fake_agent): + tool = _AgentAsTool(fake_agent, name="researcher", preserve_context=True) + assert tool._description == "Use the researcher agent as a tool by providing a natural language input" + + +def test_init_description_explicit_overrides_agent_description(fake_agent): + fake_agent.description = "Agent that researches topics" + tool = _AgentAsTool(fake_agent, name="researcher", description="custom", preserve_context=True) + assert tool._description == "custom" + + +def test_init_preserve_context_defaults_false(fake_agent): + tool = _AgentAsTool(fake_agent, name="t", description="d") + assert tool._preserve_context is False + + +def test_init_preserve_context_true(mock_agent): + tool = _AgentAsTool(mock_agent, name="t", description="d", preserve_context=True) + assert tool._preserve_context is True + + +# --- properties --- + + +def test_tool_properties(tool): + assert tool.tool_name == "test_agent" + assert tool.tool_type == "agent" + + spec = tool.tool_spec + assert spec["name"] == "test_agent" + assert spec["description"] == "A test agent" + + schema = spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert "input" in schema["properties"] + assert schema["properties"]["input"]["type"] == "string" + assert schema["required"] == ["input"] + + props = tool.get_display_properties() + assert props["Agent"] == "test_agent" + assert props["Type"] == "agent" + + +# --- stream --- + + +@pytest.mark.asyncio +async def test_stream_success(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["text"] == "response text\n" + + +@pytest.mark.asyncio +async def test_stream_passes_input_to_agent(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("hello") + + +@pytest.mark.asyncio +async def test_stream_empty_input(tool, mock_agent, agent_result): + empty_tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {}, + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(empty_tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("") + + +@pytest.mark.asyncio +async def test_stream_string_input(tool, mock_agent, agent_result): + tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": "direct string", + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("direct string") + + +@pytest.mark.asyncio +async def test_stream_error(tool, mock_agent, tool_use): + mock_agent.stream_async.side_effect = RuntimeError("boom") + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "boom" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_propagates_tool_use_id(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["toolUseId"] == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_forwards_intermediate_events(tool, mock_agent, tool_use, agent_result): + intermediate = [{"data": "partial"}, {"data": "more"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 2 + assert stream_events[0]["tool_stream_event"]["data"]["data"] == "partial" + assert stream_events[1]["tool_stream_event"]["data"]["data"] == "more" + assert stream_events[0].agent_as_tool is tool + assert stream_events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_events_not_double_wrapped_by_executor(tool, mock_agent, tool_use, agent_result): + """AgentAsToolStreamEvent is a ToolStreamEvent subclass, so the executor should pass it through directly.""" + intermediate = [{"data": "chunk"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 1 + + event = stream_events[0] + # It's a ToolStreamEvent (so the executor yields it directly) + assert isinstance(event, ToolStreamEvent) + # But it's specifically an AgentAsToolStreamEvent (not re-wrapped) + assert type(event) is AgentAsToolStreamEvent + # And it references the originating _AgentAsTool + assert event.agent_as_tool is tool + + +@pytest.mark.asyncio +async def test_stream_no_result_yields_error(tool, mock_agent, tool_use): + async def _empty_stream(): + return + yield # noqa: RET504 - make it an async generator + + mock_agent.stream_async.return_value = _empty_stream() + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "did not produce a result" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_structured_output(tool, mock_agent, tool_use): + from pydantic import BaseModel + + class MyOutput(BaseModel): + answer: str + + structured = MyOutput(answer="42") + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ignored"}]}, + metrics=EventLoopMetrics(), + state={}, + structured_output=structured, + ) + mock_agent.stream_async.return_value = _mock_stream_async(result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} + + +# --- preserve_context --- + + +@pytest.mark.asyncio +async def test_stream_resets_to_initial_state_when_preserve_context_false(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("counter", 0) + + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + # Mutate agent state as if a previous invocation happened + fake_agent.messages.append({"role": "assistant", "content": [{"text": "reply"}]}) + fake_agent.state.set("counter", 5) + + # Mock stream_async so we don't need a real model + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "initial"}]}] + assert fake_agent.state.get("counter") == 0 + + +@pytest.mark.asyncio +async def test_stream_resets_on_every_invocation(fake_agent): + """Each call should reset to the same initial snapshot, not to the previous call's state.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "seed"}]}] + fake_agent.state.set("count", 1) + + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-1", + "name": "fake_agent", + "input": {"input": "first"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + fake_agent.messages.append({"role": "assistant", "content": [{"text": "added"}]}) + fake_agent.state.set("count", 99) + + tool_use["toolUseId"] = "tool-2" + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "seed"}]}] + assert fake_agent.state.get("count") == 1 + + +@pytest.mark.asyncio +async def test_stream_initial_snapshot_is_deep_copy(fake_agent): + """Mutating the agent's messages after construction should not affect the snapshot.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "original"}]}] + + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages[0]["content"][0]["text"] = "mutated" + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "original"}]}] + + +@pytest.mark.asyncio +async def test_stream_resets_empty_initial_state_when_preserve_context_false(fake_agent): + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [] + assert fake_agent.state.get() == {} + + +@pytest.mark.asyncio +async def test_stream_resets_context_by_default(fake_agent): + """Default preserve_context=False means each invocation starts fresh.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc") + + # Mutate after construction + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + fake_agent.state.set("key", "changed") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + # Should reset to construction-time snapshot + assert fake_agent.messages == [{"role": "user", "content": [{"text": "old"}]}] + assert fake_agent.state.get("key") == "value" + + +@pytest.mark.asyncio +async def test_stream_preserves_context_when_explicitly_true(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert len(fake_agent.messages) >= 1 + assert fake_agent.state.get("key") == "value" + + +def test_preserve_context_false_rejects_session_manager(fake_agent): + """preserve_context=False should raise ValueError when agent has a session manager.""" + fake_agent._session_manager = MagicMock() + + with pytest.raises(ValueError, match="cannot be used with an agent that has a session manager"): + _AgentAsTool(fake_agent, name="t", description="d", preserve_context=False) + + +# --- interrupt propagation --- + + +@pytest.fixture +def interrupt_result(): + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval") + return AgentResult( + stop_reason="interrupt", + message={"role": "assistant", "content": [{"text": "pending"}]}, + metrics=EventLoopMetrics(), + state={}, + interrupts=[interrupt], + ) + + +@pytest.mark.asyncio +async def test_stream_interrupt_yields_tool_interrupt_event(tool, mock_agent, tool_use, interrupt_result): + """When the sub-agent returns an interrupt result, _AgentAsTool should yield ToolInterruptEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result) + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + assert events[0].interrupts == interrupt_result.interrupts + assert events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_interrupt_no_tool_result_appended(tool, mock_agent, tool_use, interrupt_result): + """ToolInterruptEvent should not produce a ToolResultEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events == [] + + +@pytest.mark.asyncio +async def test_stream_interrupt_forwards_intermediate_events(tool, mock_agent, tool_use, interrupt_result): + """Intermediate events should still be yielded before the interrupt.""" + intermediate = [{"data": "partial"}] + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + interrupt_events = [e for e in events if isinstance(e, ToolInterruptEvent)] + assert len(stream_events) == 1 + assert len(interrupt_events) == 1 + + +@pytest.mark.asyncio +async def test_stream_interrupt_resume_forwards_responses(fake_agent): + """On resume, _AgentAsTool should forward interrupt responses to the sub-agent.""" + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") + + # Put the sub-agent in an activated interrupt state with the response already set + fake_agent._interrupt_state.interrupts["interrupt-1"] = interrupt + fake_agent._interrupt_state.activate() + + normal_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "approved"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + fake_agent.stream_async = MagicMock(return_value=_mock_stream_async(normal_result)) + + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + tool_use = {"toolUseId": "tool-123", "name": "fake_agent", "input": {"input": "do something"}} + + events = [event async for event in tool.stream(tool_use, {})] + + # Should have called stream_async with interrupt responses, not the original prompt + call_args = fake_agent.stream_async.call_args + agent_input = call_args[0][0] + assert isinstance(agent_input, list) + assert len(agent_input) == 1 + assert agent_input[0]["interruptResponse"]["interruptId"] == "interrupt-1" + assert agent_input[0]["interruptResponse"]["response"] == "APPROVE" + + # Should produce a normal result + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + + +@pytest.mark.asyncio +async def test_stream_interrupt_resume_skips_state_reset(fake_agent): + """When resuming from interrupt with preserve_context=False, state reset should be skipped.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("key", "value") + + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + # Simulate the sub-agent being in interrupt state after a previous invocation + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") + fake_agent._interrupt_state.interrupts["interrupt-1"] = interrupt + fake_agent._interrupt_state.activate() + + # Mutate messages to simulate sub-agent progress before interrupt + fake_agent.messages.append({"role": "assistant", "content": [{"text": "working on it"}]}) + + normal_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "done"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + fake_agent.stream_async = MagicMock(return_value=_mock_stream_async(normal_result)) + + tool_use = {"toolUseId": "tool-123", "name": "fake_agent", "input": {"input": "do something"}} + async for _ in tool.stream(tool_use, {}): + pass + + # Messages should NOT have been reset — the sub-agent needs its conversation history intact + assert len(fake_agent.messages) == 2 + + +@pytest.mark.asyncio +async def test_is_sub_agent_interrupted_false_by_default(tool): + """_is_sub_agent_interrupted returns False when no interrupts are active.""" + assert tool._is_sub_agent_interrupted() is False + + +@pytest.mark.asyncio +async def test_is_sub_agent_interrupted_true_when_activated(fake_agent): + """_is_sub_agent_interrupted returns True when the sub-agent's interrupt state is activated.""" + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + assert tool._is_sub_agent_interrupted() is False + + fake_agent._interrupt_state.activate() + assert tool._is_sub_agent_interrupted() is True + + +@pytest.mark.asyncio +async def test_build_interrupt_responses(fake_agent): + """_build_interrupt_responses packages sub-agent interrupts into response content blocks.""" + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + + interrupt_a = Interrupt(id="id-a", name="a", reason="r", response="yes") + interrupt_b = Interrupt(id="id-b", name="b", reason="r", response=None) + fake_agent._interrupt_state.interrupts = {"id-a": interrupt_a, "id-b": interrupt_b} + + responses = tool._build_interrupt_responses() + + # Only interrupt_a has a response + assert len(responses) == 1 + assert responses[0] == {"interruptResponse": {"interruptId": "id-a", "response": "yes"}} + + +# --- concurrency --- + + +@pytest.mark.asyncio +async def test_stream_rejects_concurrent_call(tool, mock_agent, tool_use, agent_result): + """A second concurrent call should get an error ToolResultEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + # Simulate the lock already being held by another invocation + tool._lock.acquire() + try: + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + assert events[0]["tool_result"]["status"] == "error" + assert "already processing" in events[0]["tool_result"]["content"][0]["text"] + mock_agent.stream_async.assert_not_called() + finally: + tool._lock.release() + + +@pytest.mark.asyncio +async def test_stream_releases_lock_after_completion(tool, mock_agent, tool_use, agent_result): + """Lock should be released after stream completes, allowing subsequent calls.""" + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + assert not tool._lock.locked() + + # A second call should succeed + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + + +@pytest.mark.asyncio +async def test_stream_releases_lock_after_error(tool, mock_agent, tool_use): + """Lock should be released even when the agent raises an exception.""" + mock_agent.stream_async.side_effect = RuntimeError("boom") + + async for _ in tool.stream(tool_use, {}): + pass + + assert not tool._lock.locked() + + +# --- Agent-as-tool sugar (passing agents directly in tools list) --- + + +def test_agent_passed_directly_in_tools_list(): + """Test that an Agent can be passed directly in another Agent's tools list.""" + from strands.agent.agent import Agent + + sub_agent = Agent(name="research_agent", description="Does research", callback_handler=None) + + # This should work without calling .as_tool() explicitly + parent_agent = Agent(name="orchestrator", tools=[sub_agent], callback_handler=None) + + assert "research_agent" in parent_agent.tool_names + + +def test_multiple_agents_passed_directly_in_tools_list(): + """Test that multiple Agents can be passed directly in another Agent's tools list.""" + from strands.agent.agent import Agent + + agent_a = Agent(name="agent_a", callback_handler=None) + agent_b = Agent(name="agent_b", callback_handler=None) + + parent = Agent(name="parent", tools=[agent_a, agent_b], callback_handler=None) + + assert "agent_a" in parent.tool_names + assert "agent_b" in parent.tool_names + + +def test_agent_mixed_with_regular_tools_in_tools_list(): + """Test that Agents can be mixed with regular tools in the tools list.""" + from strands import tool as tool_decorator + from strands.agent.agent import Agent + + @tool_decorator + def my_tool(x: str) -> str: + """A regular tool.""" + return x + + sub_agent = Agent(name="helper_agent", callback_handler=None) + + parent = Agent(name="parent", tools=[my_tool, sub_agent], callback_handler=None) + + assert "my_tool" in parent.tool_names + assert "helper_agent" in parent.tool_names diff --git a/tests/strands/agent/test_agent_cancellation.py b/tests/strands/agent/test_agent_cancellation.py new file mode 100644 index 000000000..756e96485 --- /dev/null +++ b/tests/strands/agent/test_agent_cancellation.py @@ -0,0 +1,290 @@ +"""Tests for agent cancellation functionality using agent.cancel() API.""" + +import asyncio +import threading +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.hooks import AfterModelCallEvent +from tests.fixtures.mocked_model_provider import MockedModelProvider + +# Default agent response for simple tests +DEFAULT_RESPONSE = { + "role": "assistant", + "content": [{"text": "Hello! How can I help you?"}], +} + + +@pytest.mark.asyncio +async def test_agent_cancel_before_invocation(): + """Test agent.cancel() before invocation starts. + + Verifies that calling cancel() before invoke_async() results in + immediate cancellation without any model calls. + """ + agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE])) + + # Cancel before invocation + agent.cancel() + + result = await agent.invoke_async("Hello") + + assert result.stop_reason == "cancelled" + assert result.message == {"role": "assistant", "content": [{"text": "Cancelled by user"}], "metadata": ANY} + + +@pytest.mark.asyncio +async def test_agent_cancel_during_execution(): + """Test agent.cancel() during execution. + + Verifies that calling cancel() while the agent is running + stops execution at the next checkpoint. + """ + streaming_started = asyncio.Event() + cancel_ready = asyncio.Event() + + class DelayedModelProvider(MockedModelProvider): + async def stream(self, *args, **kwargs): + streaming_started.set() + # Block until cancel has been called + await cancel_ready.wait() + async for event in super().stream(*args, **kwargs): + yield event + + agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE])) + + async def cancel_when_ready(): + await streaming_started.wait() + agent.cancel() + cancel_ready.set() + + cancel_task = asyncio.create_task(cancel_when_ready()) + result = await agent.invoke_async("Hello") + await cancel_task + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_with_tools(): + """Test agent.cancel() during tool execution. + + Verifies that cancellation works correctly when tools are being executed. + Uses AfterModelCallEvent hook to cancel deterministically after model returns tool_use. + """ + tool_executed = [] + + @tool + def slow_tool(x: int) -> int: + """A tool for testing.""" + tool_executed.append(x) + return x * 2 + + tool_use_response = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_1", + "name": "slow_tool", + "input": {"x": 5}, + } + } + ], + } + + agent = Agent( + model=MockedModelProvider([tool_use_response, DEFAULT_RESPONSE]), + tools=[slow_tool], + ) + + # Cancel deterministically after model returns tool_use + async def cancel_after_model(event: AfterModelCallEvent): + if event.stop_response and event.stop_response.stop_reason == "tool_use": + agent.cancel() + + agent.add_hook(cancel_after_model, AfterModelCallEvent) + + result = await agent.invoke_async("Use the tool") + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_idempotent(): + """Test that calling cancel() multiple times is safe. + + Verifies that multiple cancel() calls are idempotent and don't + cause any issues. + """ + agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE])) + + # Cancel multiple times + agent.cancel() + agent.cancel() + agent.cancel() + + result = await agent.invoke_async("Hello") + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_from_thread(): + """Test agent.cancel() from another thread. + + Verifies thread-safety of the cancel() method when called + from a background thread. + """ + streaming_started = asyncio.Event() + cancel_ready = asyncio.Event() + loop = asyncio.get_running_loop() + + class DelayedModelProvider(MockedModelProvider): + async def stream(self, *args, **kwargs): + streaming_started.set() + await cancel_ready.wait() + async for event in super().stream(*args, **kwargs): + yield event + + agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE])) + + def cancel_from_thread(): + # Wait for streaming to start before cancelling + asyncio.run_coroutine_threadsafe(streaming_started.wait(), loop).result() + agent.cancel() + loop.call_soon_threadsafe(cancel_ready.set) + + thread = threading.Thread(target=cancel_from_thread) + thread.start() + + result = await agent.invoke_async("Hello") + thread.join() + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_streaming(): + """Test cancellation during streaming response. + + Verifies that cancellation works correctly when using + the streaming API (stream_async). + """ + chunks_yielded = asyncio.Event() + cancel_done = asyncio.Event() + + class SlowStreamingModelProvider(MockedModelProvider): + async def stream(self, *args, **kwargs): + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + + for i in range(10): + yield {"contentBlockDelta": {"delta": {"text": f"chunk {i} "}}} + if i == 2: + # Signal after a few chunks so cancel can fire + chunks_yielded.set() + # Wait for cancel to complete before continuing + await cancel_done.wait() + + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + agent = Agent(model=SlowStreamingModelProvider([DEFAULT_RESPONSE])) + + async def cancel_after_chunks(): + await chunks_yielded.wait() + agent.cancel() + cancel_done.set() + + cancel_task = asyncio.create_task(cancel_after_chunks()) + + events = [] + async for event in agent.stream_async("Hello"): + events.append(event) + if event.get("result"): + break + + await cancel_task + + result_event = next((e for e in events if e.get("result")), None) + assert result_event is not None + assert result_event["result"].stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_before_tool_execution_adds_tool_results(): + """Test that cancelling before tool execution adds tool_result messages. + + Verifies that when cancellation occurs after model returns tool_use but before + tools execute, proper tool_result messages are added to maintain valid conversation state. + This prevents the "tool_use without tool_result" error on next invocation. + """ + + @tool + def calculator(x: int, y: int) -> int: + """Add two numbers.""" + return x + y + + tool_use_response = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_1", + "name": "calculator", + "input": {"x": 5, "y": 3}, + } + } + ], + } + + agent = Agent( + model=MockedModelProvider([tool_use_response, DEFAULT_RESPONSE]), + tools=[calculator], + ) + + async def cancel_after_model(event: AfterModelCallEvent): + if event.stop_response and event.stop_response.stop_reason == "tool_use": + agent.cancel() + + agent.add_hook(cancel_after_model, AfterModelCallEvent) + + result = await agent.invoke_async("Calculate 5 + 3") + + assert result.stop_reason == "cancelled" + + # Should have: user message, assistant message with tool_use, user message with tool_result + assert len(agent.messages) == 3 + assert agent.messages[0]["role"] == "user" + assert agent.messages[1]["role"] == "assistant" + assert agent.messages[2]["role"] == "user" + + tool_result_content = agent.messages[2]["content"] + assert len(tool_result_content) == 1 + assert "toolResult" in tool_result_content[0] + + tool_result = tool_result_content[0]["toolResult"] + assert tool_result["toolUseId"] == "tool_1" + assert tool_result["status"] == "error" + assert "cancelled" in tool_result["content"][0]["text"].lower() + + +@pytest.mark.asyncio +async def test_agent_cancel_continue_after(): + """Test that agent is reusable after cancellation. + + Verifies that the cancel signal is cleared after an invocation completes, + allowing subsequent invocations to run normally. + """ + agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE, DEFAULT_RESPONSE])) + + agent.cancel() + result1 = await agent.invoke_async("Hello") + assert result1.stop_reason == "cancelled" + + # Second invocation should work normally + result2 = await agent.invoke_async("Hello again") + assert result2.stop_reason == "end_turn" diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 32266c3eb..bc2c376c2 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1,4 +1,4 @@ -from unittest.mock import ANY, Mock +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch import pytest from pydantic import BaseModel @@ -16,6 +16,7 @@ MessageAddedEvent, ) from strands.types.content import Messages +from strands.types.exceptions import ModelThrottledException from strands.types.tools import ToolResult, ToolUse from tests.fixtures.mock_hook_provider import MockHookProvider from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -101,6 +102,12 @@ class User(BaseModel): return User(name="Jane Doe", age=30) +@pytest.fixture +def mock_sleep(): + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: + yield mock + + def test_agent__init__hooks(): """Verify that the AgentInitializedEvent is emitted on Agent construction.""" hook_provider = MockHookProvider(event_types=[AgentInitializedEvent]) @@ -147,24 +154,26 @@ def test_agent_tool_call(agent, hook_provider, agent_tool): def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_use): """Verify that the correct hook events are emitted as part of __call__.""" - agent("test message") + result = agent("test message") length, events = hook_provider.get_events() assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], ) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message={ "content": [{"toolUse": tool_use}], "role": "assistant", + "metadata": ANY, }, stop_reason="tool_use", ), @@ -186,18 +195,19 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( - message=mock_model.agent_responses[1], + message={"role": "assistant", "content": [{"text": "I invoked a tool!"}], "metadata": ANY}, stop_reason="end_turn", ), exception=None, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result) assert len(agent.messages) == 4 @@ -207,28 +217,37 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m """Verify that the correct hook events are emitted as part of stream_async.""" iterator = agent.stream_async("test message") await anext(iterator) - assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)] + + # Verify first event is BeforeInvocationEvent with invocation_state and messages + assert len(hook_provider.events_received) == 1 + assert hook_provider.events_received[0].invocation_state is not None + assert hook_provider.events_received[0].messages is not None + assert hook_provider.events_received[0].messages[0]["role"] == "user" # iterate the rest - async for _ in iterator: - pass + result = None + async for item in iterator: + if "result" in item: + result = item["result"] length, events = hook_provider.get_events() assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], ) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message={ "content": [{"toolUse": tool_use}], "role": "assistant", + "metadata": ANY, }, stop_reason="tool_use", ), @@ -250,22 +269,24 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( - message=mock_model.agent_responses[1], + message={"role": "assistant", "content": [{"text": "I invoked a tool!"}], "metadata": ANY}, stop_reason="end_turn", ), exception=None, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result) assert len(agent.messages) == 4 +@pytest.mark.filterwarnings("ignore:Agent.structured_output method is deprecated:DeprecationWarning") def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): """Verify that the correct hook events are emitted as part of structured_output.""" @@ -276,12 +297,13 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): assert length == 2 - assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY) assert len(agent.messages) == 0 # no new messages added +@pytest.mark.filterwarnings("ignore:Agent.structured_output_async method is deprecated:DeprecationWarning") @pytest.mark.asyncio async def test_agent_structured_async_output_hooks(agent, hook_provider, user, agenerator): """Verify that the correct hook events are emitted as part of structured_output_async.""" @@ -293,7 +315,763 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a assert length == 2 - assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY) assert len(agent.messages) == 0 # no new messages added + + +@pytest.mark.asyncio +async def test_hook_retry_on_successful_call(): + """Test that hooks can retry even on successful model calls based on response content.""" + + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Short"}], + }, + { + "role": "assistant", + "content": [{"text": "This is a much longer and more detailed response"}], + }, + ] + ) + + # Hook that retries if response is too short + class MinLengthRetryHook: + def __init__(self, min_length=10): + self.min_length = min_length + self.call_count = 0 + + def register_hooks(self, registry): + registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call) + + async def handle_after_model_call(self, event): + self.call_count += 1 + + # Check successful responses for minimum length + if event.stop_response: + message = event.stop_response.message + text_content = "".join(block.get("text", "") for block in message.get("content", [])) + + if len(text_content) < self.min_length: + event.retry = True + + retry_hook = MinLengthRetryHook(min_length=10) + agent = Agent(model=mock_provider, hooks=[retry_hook]) + + result = agent("Generate a response") + + # Verify hook was called twice (once for short response, once for long) + assert retry_hook.call_count == 2 + + # Verify final result is the longer response + assert result.message["content"][0]["text"] == "This is a much longer and more detailed response" + + +@pytest.mark.asyncio +async def test_hook_retry_on_exception_basic(alist, mock_sleep): + """Test that hooks can retry model calls on exceptions.""" + + class CustomException(Exception): + pass + + model = MagicMock() + model.stream.side_effect = [ + CustomException("First attempt fails"), + MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Success after retry"}], + }, + ] + ).stream([]), + ] + + # Hook that enables retry on CustomException + class RetryHook: + def __init__(self): + self.after_model_call_count = 0 + + def register_hooks(self, registry): + registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call) + + async def handle_after_model_call(self, event): + self.after_model_call_count += 1 + if event.exception and isinstance(event.exception, CustomException): + event.retry = True + + retry_hook = RetryHook() + agent = Agent(model=model, hooks=[retry_hook]) + + result = agent("Test retry") + + # Verify the hook was called twice (once for failure, once for success) + assert retry_hook.after_model_call_count == 2 + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Success after retry" + + +@pytest.mark.asyncio +async def test_hook_retry_not_set_on_success(alist): + """Test that model is not retried when hook doesn't set retry_model on success.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "First successful response"}], + }, + ] + ) + + # Hook that tries to set retry_model=True even on success + class NoRetryHook: + def __init__(self): + self.call_count = 0 + + def register_hooks(self, registry): + registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call) + + async def handle_after_model_call(self, event): + self.call_count += 1 + # Try to set retry even on success + # Don't set retry_model (leave it as False) + + retry_hook = NoRetryHook() + agent = Agent(model=mock_provider, hooks=[retry_hook]) + + result = agent("Test no retry when not set") + + # Should only be called once since retry_model was not set + assert retry_hook.call_count == 1 + assert result.message["content"][0]["text"] == "First successful response" + + +@pytest.mark.asyncio +async def test_hook_retry_with_limit(alist, mock_sleep): + """Test that hooks can control retry limits.""" + + class CustomException(Exception): + pass + + model = MagicMock() + model.stream.side_effect = [ + CustomException("Attempt 1 fails"), + CustomException("Attempt 2 fails"), + CustomException("Attempt 3 fails"), + ] + + # Hook that allows max 2 retries + class LimitedRetryHook: + def __init__(self, max_retries=2): + self.max_retries = max_retries + self.retry_count = 0 + self.call_count = 0 + + def register_hooks(self, registry): + registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call) + + async def handle_after_model_call(self, event): + self.call_count += 1 + if event.exception and isinstance(event.exception, CustomException): + if self.retry_count < self.max_retries: + self.retry_count += 1 + event.retry = True + # else: let exception propagate + + retry_hook = LimitedRetryHook(max_retries=2) + agent = Agent(model=model, hooks=[retry_hook]) + + with pytest.raises(CustomException, match="Attempt 3 fails"): + await agent("Test limited retries") + + # Should be called 3 times: initial + 2 retries + assert retry_hook.call_count == 3 + assert retry_hook.retry_count == 2 + + +@pytest.mark.asyncio +async def test_hook_retry_multiple_hooks(alist, mock_sleep): + """Test that multiple hooks can modify retry_model and last one wins.""" + + class CustomException(Exception): + pass + + model = MagicMock() + model.stream.side_effect = [ + CustomException("First attempt fails"), + MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Success"}], + }, + ] + ).stream([]), + ] + + async def retry_enabler(event: AfterModelCallEvent): + if event.exception: + event.retry = True + + async def another_retry_enabler(event: AfterModelCallEvent): + if event.exception: + event.retry = True + + agent = Agent(model=model) + agent.hooks.add_callback(AfterModelCallEvent, retry_enabler) + agent.hooks.add_callback(AfterModelCallEvent, another_retry_enabler) + + result = agent("Test multiple hooks") + + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Success" + + +@pytest.mark.asyncio +async def test_hook_retry_last_hook_wins(alist, mock_sleep): + """Test that when multiple hooks set retry_model, the last-called hook wins. + + Note: AfterModelCallEvent callbacks are invoked in reverse order, so the first + registered hook is called last. + """ + + class CustomException(Exception): + pass + + call_count = [0] + + def mock_stream(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise CustomException("First attempt fails") + else: + raise CustomException(f"Should not be called (call {call_count[0]})") + + model = MagicMock() + model.stream = mock_stream + + async def retry_enabler(event: AfterModelCallEvent): + """Called first due to reverse order.""" + if event.exception: + event.retry = True + + async def retry_disabler(event: AfterModelCallEvent): + """Called last, so it wins.""" + if event.exception: + event.retry = False + + agent = Agent(model=model) + agent.hooks.add_callback(AfterModelCallEvent, retry_disabler) # Registered first, called last + agent.hooks.add_callback(AfterModelCallEvent, retry_enabler) # Registered second, called first + + # Should raise exception since last-called hook disabled retry + with pytest.raises(CustomException, match="First attempt fails"): + agent("Test last hook wins") + + # Verify stream was only called once + assert call_count[0] == 1 + + +@pytest.mark.asyncio +async def test_hook_retry_with_throttle_exception(alist, mock_sleep): + """Test that hook retry works alongside existing throttle retry.""" + + class CustomException(Exception): + pass + + model = MagicMock() + model.stream.side_effect = [ + CustomException("Custom error"), + ModelThrottledException("ThrottlingException"), + ModelThrottledException("ThrottlingException"), + MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Success after mixed retries"}], + }, + ] + ).stream([]), + ] + + async def handle_after_model_call(event: AfterModelCallEvent): + if event.exception and isinstance(event.exception, CustomException): + event.retry = True + + agent = Agent(model=model) + agent.hooks.add_callback(AfterModelCallEvent, handle_after_model_call) + + result = agent("Test mixed retries") + + # Should succeed after: custom retry + 2 throttle retries + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Success after mixed retries" + + +def test_before_invocation_event_message_modification(): + """Test that hooks can modify messages in BeforeInvocationEvent for input guardrails.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "I received your redacted message"}], + }, + ] + ) + + modified_content = None + + async def input_guardrail_hook(event: BeforeInvocationEvent): + """Simulates a guardrail that redacts sensitive content.""" + nonlocal modified_content + if event.messages is not None: + for message in event.messages: + if message.get("role") == "user": + content = message.get("content", []) + for block in content: + if "text" in block and "SECRET" in block["text"]: + # Redact sensitive content in-place + block["text"] = block["text"].replace("SECRET", "[REDACTED]") + modified_content = event.messages[0]["content"][0]["text"] + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, input_guardrail_hook) + + agent("My password is SECRET123") + + # Verify the message was modified before being processed + assert modified_content == "My password is [REDACTED]123" + # Verify the modified message was added to agent's conversation history + assert agent.messages[0]["content"][0]["text"] == "My password is [REDACTED]123" + + +def test_before_invocation_event_message_overwrite(): + """Test that hooks can overwrite messages in BeforeInvocationEvent.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "I received your message message"}], + }, + ] + ) + + async def overwrite_input_hook(event: BeforeInvocationEvent): + event.messages = [{"role": "user", "content": [{"text": "GOODBYE"}]}] + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, overwrite_input_hook) + + agent("HELLO") + + # Verify the message was overwritten to agent's conversation history + assert agent.messages[0]["content"][0]["text"] == "GOODBYE" + + +@pytest.mark.filterwarnings("ignore:Agent.structured_output_async method is deprecated:DeprecationWarning") +@pytest.mark.asyncio +async def test_before_invocation_event_messages_none_in_structured_output(agenerator): + """Test that BeforeInvocationEvent.messages is None when called from deprecated structured_output.""" + + class Person(BaseModel): + name: str + age: int + + mock_provider = MockedModelProvider([]) + mock_provider.structured_output = Mock(return_value=agenerator([{"output": Person(name="Test", age=30)}])) + + received_messages = "not_set" + + async def capture_messages_hook(event: BeforeInvocationEvent): + nonlocal received_messages + received_messages = event.messages + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, capture_messages_hook) + + await agent.structured_output_async(Person, "Test prompt") + + # structured_output_async uses deprecated path that doesn't pass messages + assert received_messages is None + + +def test_after_invocation_resume_triggers_new_invocation(): + """Test that setting resume on AfterInvocationEvent re-invokes the agent.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "assistant", "content": [{"text": "Second response"}]}, + ] + ) + + resume_count = 0 + + async def resume_once(event: AfterInvocationEvent): + nonlocal resume_count + if resume_count == 0: + resume_count += 1 + event.resume = "continue" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, resume_once) + + result = agent("start") + + # Agent should have been invoked twice + assert resume_count == 1 + assert result.message["content"][0]["text"] == "Second response" + # 4 messages: user1, assistant1, user2 (resume), assistant2 + assert len(agent.messages) == 4 + assert agent.messages[0]["content"][0]["text"] == "start" + assert agent.messages[2]["content"][0]["text"] == "continue" + + +def test_after_invocation_resume_none_does_not_loop(): + """Test that resume=None (default) does not re-invoke the agent.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "Only response"}]}, + ] + ) + + call_count = 0 + + async def no_resume(event: AfterInvocationEvent): + nonlocal call_count + call_count += 1 + # Don't set resume - should remain None + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, no_resume) + + result = agent("hello") + + assert call_count == 1 + assert result.message["content"][0]["text"] == "Only response" + + +def test_after_invocation_resume_fires_before_invocation_event(): + """Test that resume triggers BeforeInvocationEvent on each iteration.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Second"}]}, + ] + ) + + before_invocation_count = 0 + after_invocation_count = 0 + + async def count_before(event: BeforeInvocationEvent): + nonlocal before_invocation_count + before_invocation_count += 1 + + async def resume_once(event: AfterInvocationEvent): + nonlocal after_invocation_count + after_invocation_count += 1 + if after_invocation_count == 1: + event.resume = "next" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, count_before) + agent.hooks.add_callback(AfterInvocationEvent, resume_once) + + agent("start") + + # BeforeInvocationEvent should fire for both the initial and resumed invocation + assert before_invocation_count == 2 + assert after_invocation_count == 2 + + +def test_after_invocation_resume_multiple_times(): + """Test that resume can chain multiple re-invocations.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + {"role": "assistant", "content": [{"text": "Response 3"}]}, + ] + ) + + resume_count = 0 + + async def resume_twice(event: AfterInvocationEvent): + nonlocal resume_count + if resume_count < 2: + resume_count += 1 + event.resume = f"iteration {resume_count + 1}" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, resume_twice) + + result = agent("iteration 1") + + assert resume_count == 2 + assert result.message["content"][0]["text"] == "Response 3" + # 6 messages: 3 user + 3 assistant + assert len(agent.messages) == 6 + + +def test_after_invocation_resume_handles_interrupt_with_responses(): + """Test that a hook can handle an interrupt by resuming with interrupt responses.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: model calls the tool, which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + # Second invocation (after interrupt resume): model gives final response + {"role": "assistant", "content": [{"text": "Completed after interrupt"}]}, + ] + ) + + def interrupt_tool(event: BeforeToolCallEvent): + """Interrupt before tool execution; returns stored response on second call.""" + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need human approval") + + async def handle_interrupt_via_resume(event: AfterInvocationEvent): + """Hook that automatically handles interrupts by resuming with responses.""" + if event.result and event.result.stop_reason == "interrupt": + responses = [] + for interrupt in event.result.interrupts: + responses.append({"interruptResponse": {"interruptId": interrupt.id, "response": "approved"}}) + event.resume = responses + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + agent.hooks.add_callback(AfterInvocationEvent, handle_interrupt_via_resume) + + result = agent("do something") + + # The hook handled the interrupt automatically — agent completed normally + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Completed after interrupt" + # Interrupt state should be cleared after successful resume + assert agent._interrupt_state.activated is False + + +def test_after_invocation_resume_with_invalid_input_during_interrupt(): + """Test that resuming with non-interrupt input while interrupt is active raises TypeError.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: model calls the tool, which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + ] + ) + + def interrupt_tool(event: BeforeToolCallEvent): + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need approval") + + async def resume_with_bad_input(event: AfterInvocationEvent): + """Hook that incorrectly tries to resume with a plain string during interrupt.""" + if event.result and event.result.stop_reason == "interrupt": + event.resume = "this is wrong" + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + agent.hooks.add_callback(AfterInvocationEvent, resume_with_bad_input) + + with pytest.raises(TypeError, match="must resume from interrupt with list of interruptResponse's"): + agent("do something") + + +def test_after_invocation_resume_interrupt_without_resume_returns_to_caller(): + """Test that an interrupt without resume set returns the interrupt to the caller.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: model calls the tool, which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + # Second invocation (caller resumes manually): final response + {"role": "assistant", "content": [{"text": "Done after manual resume"}]}, + ] + ) + + def interrupt_tool(event: BeforeToolCallEvent): + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need approval") + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + + # First call: hits interrupt, no hook handles it, returns to caller + result = agent("do something") + assert result.stop_reason == "interrupt" + assert len(result.interrupts) == 1 + assert result.interrupts[0].name == "approval_needed" + assert agent._interrupt_state.activated is True + + # Caller manually resumes with interrupt responses + interrupt_id = result.interrupts[0].id + result = agent([{"interruptResponse": {"interruptId": interrupt_id, "response": "yes"}}]) + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Done after manual resume" + assert agent._interrupt_state.activated is False + + +def test_after_invocation_resume_interrupt_during_resumed_invocation(): + """Test that an interrupt during a resumed invocation can be handled by the hook.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: simple text response (no tool call) + {"role": "assistant", "content": [{"text": "First response"}]}, + # Second invocation (resumed): triggers a tool call which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + # Third invocation (after interrupt handled via resume): final response + {"role": "assistant", "content": [{"text": "Final response"}]}, + ] + ) + + invocation_count = 0 + + async def resume_hook(event: AfterInvocationEvent): + """Resume with new input on first call, handle interrupt on second.""" + nonlocal invocation_count + invocation_count += 1 + if invocation_count == 1: + # First invocation done, resume with new input + event.resume = "continue" + elif event.result and event.result.stop_reason == "interrupt": + # Second invocation hit interrupt, handle it + responses = [] + for interrupt in event.result.interrupts: + responses.append({"interruptResponse": {"interruptId": interrupt.id, "response": "approved"}}) + event.resume = responses + + def interrupt_tool(event: BeforeToolCallEvent): + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need approval") + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(AfterInvocationEvent, resume_hook) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + + result = agent("start") + + # All three invocations happened within a single agent call + assert invocation_count == 3 + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Final response" + assert agent._interrupt_state.activated is False + + +def test_hooks_param_accepts_callable(): + """Verify that a plain callable can be passed via hooks parameter.""" + events_received = [] + + def my_callback(event: AgentInitializedEvent) -> None: + events_received.append(event) + + agent = Agent(hooks=[my_callback], callback_handler=None) + + assert len(events_received) == 1 + assert isinstance(events_received[0], AgentInitializedEvent) + assert events_received[0].agent is agent + + +def test_hooks_param_accepts_mixed_list(): + """Verify that a mix of HookProviders and callables can be passed.""" + callback_events = [] + + def my_callback(event: AgentInitializedEvent) -> None: + callback_events.append(event) + + provider = MockHookProvider(event_types=[AgentInitializedEvent]) + + agent = Agent(hooks=[provider, my_callback], callback_handler=None) + + assert len(callback_events) == 1 + assert callback_events[0].agent is agent + length, _ = provider.get_events() + assert length == 1 + + +def test_hooks_param_invalid_hook_raises_error(): + """Verify that passing an invalid hook raises ValueError.""" + with pytest.raises(ValueError, match="Invalid hook"): + Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore + + +def test_hooks_param_callable_invoked_during_lifecycle(): + """Verify callable hooks fire during agent lifecycle.""" + before_events = [] + + def on_before(event: BeforeInvocationEvent) -> None: + before_events.append(event) + + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) + agent = Agent(model=mock_model, hooks=[on_before], callback_handler=None) + agent("test") + + assert len(before_events) == 1 + assert isinstance(before_events[0], BeforeInvocationEvent) diff --git a/tests/strands/agent/test_agent_model_state.py b/tests/strands/agent/test_agent_model_state.py new file mode 100644 index 000000000..7e751d334 --- /dev/null +++ b/tests/strands/agent/test_agent_model_state.py @@ -0,0 +1,69 @@ +"""Tests for agent model state with server-side conversation management.""" + +import unittest.mock + +import pytest + +from strands.agent.agent import Agent +from strands.agent.conversation_manager import NullConversationManager, SlidingWindowConversationManager + + +@pytest.fixture +def mock_model(): + """Create a mock model that writes response_id to model_state.""" + model = unittest.mock.MagicMock() + model.config = {"model_id": "test-model"} + model.get_config.return_value = {"model_id": "test-model"} + + call_count = 0 + + async def mock_stream(messages, tool_specs=None, system_prompt=None, **kwargs): + nonlocal call_count + call_count += 1 + resp_id = "resp_abc123" if call_count == 1 else "resp_def456" + + model_state = kwargs.get("model_state") + if model_state is not None: + model_state["response_id"] = resp_id + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": "Hello"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + yield { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + } + + model.stream = unittest.mock.MagicMock(side_effect=mock_stream) + model.stateful = True + return model + + +def test_agent_model_state(mock_model): + """Verify model_state is populated, messages are cleared, and model_state is passed on subsequent calls.""" + agent = Agent(model=mock_model, callback_handler=None) + assert isinstance(agent.conversation_manager, NullConversationManager) + + agent("Turn 1") + assert agent._model_state.get("response_id") == "resp_abc123" + assert len(agent.messages) == 0 + + agent("Turn 2") + assert agent._model_state.get("response_id") == "resp_def456" + assert len(agent.messages) == 0 + + second_call_kwargs = mock_model.stream.call_args_list[1][1] + assert second_call_kwargs.get("model_state") is agent._model_state + + +def test_agent_model_state_raises_with_conversation_manager(): + """Passing a conversation_manager with a stateful model raises ValueError.""" + model = unittest.mock.MagicMock() + model.stateful = True + + with pytest.raises(ValueError, match="conversation_manager cannot be used with a stateful model"): + Agent(model=model, conversation_manager=SlidingWindowConversationManager()) diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 5d1f02089..7cb106182 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -1,10 +1,11 @@ import unittest.mock -from typing import Optional, cast +from typing import cast import pytest from pydantic import BaseModel from strands.agent.agent_result import AgentResult +from strands.interrupt import Interrupt from strands.telemetry.metrics import EventLoopMetrics from strands.types.content import Message from strands.types.streaming import StopReason @@ -150,7 +151,7 @@ class StructuredOutputModel(BaseModel): name: str value: int - optional_field: Optional[str] = None + optional_field: str | None = None def test__init__with_structured_output(mock_metrics, simple_message: Message): @@ -185,7 +186,7 @@ def test__init__structured_output_defaults_to_none(mock_metrics, simple_message: def test__str__with_structured_output(mock_metrics, simple_message: Message): - """Test that str() is not affected by structured_output.""" + """Test that str() returns structured output JSON when structured_output is present.""" structured_output = StructuredOutputModel(name="test", value=42) result = AgentResult( @@ -196,11 +197,11 @@ def test__str__with_structured_output(mock_metrics, simple_message: Message): structured_output=structured_output, ) - # The string representation should only include the message text, not structured output + # When structured_output is present, it takes priority over message text message_string = str(result) - assert message_string == "Hello world!\n" - assert "test" not in message_string - assert "42" not in message_string + assert message_string == structured_output.model_dump_json() + assert "test" in message_string + assert "42" in message_string def test__str__empty_message_with_structured_output(mock_metrics, empty_message: Message): @@ -225,3 +226,175 @@ def test__str__empty_message_with_structured_output(mock_metrics, empty_message: assert "example" in message_string assert "123" in message_string assert "optional" in message_string + + +@pytest.fixture +def citations_message(): + """Message with citationsContent block.""" + return { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "title": "Source Document", + "location": {"document": {"pageNumber": 1}}, + "sourceContent": [{"text": "source text"}], + } + ], + "content": [{"text": "This is cited text from the document."}], + } + } + ], + } + + +@pytest.fixture +def mixed_text_and_citations_message(): + """Message with both plain text and citationsContent blocks.""" + return { + "role": "assistant", + "content": [ + {"text": "Introduction paragraph"}, + { + "citationsContent": { + "citations": [{"title": "Doc", "location": {}, "sourceContent": []}], + "content": [{"text": "Cited content here."}], + } + }, + {"text": "Conclusion paragraph"}, + ], + } + + +def test__str__with_citations_content(mock_metrics, citations_message: Message): + """Test that str() extracts text from citationsContent blocks.""" + result = AgentResult(stop_reason="end_turn", message=citations_message, metrics=mock_metrics, state={}) + + message_string = str(result) + assert message_string == "This is cited text from the document.\n" + + +def test__str__mixed_text_and_citations_content(mock_metrics, mixed_text_and_citations_message: Message): + """Test that str() works with both plain text and citationsContent blocks.""" + result = AgentResult( + stop_reason="end_turn", message=mixed_text_and_citations_message, metrics=mock_metrics, state={} + ) + + message_string = str(result) + assert message_string == "Introduction paragraph\nCited content here.\nConclusion paragraph\n" + + +def test__str__with_interrupts(mock_metrics, simple_message: Message): + """Test that str() returns stringified interrupts when present.""" + interrupts = [ + Interrupt(id="int-1", name="approval", reason="Need user approval"), + Interrupt(id="int-2", name="input", reason="Need more info"), + ] + + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + interrupts=interrupts, + ) + + message_string = str(result) + + # Should contain stringified interrupt dicts + assert "int-1" in message_string + assert "approval" in message_string + assert "Need user approval" in message_string + assert "int-2" in message_string + assert "input" in message_string + assert "Need more info" in message_string + + +def test__str__interrupts_priority_over_structured_output(mock_metrics, simple_message: Message): + """Test that interrupts take priority over structured_output in str().""" + interrupts = [Interrupt(id="int-1", name="approval", reason="Needs approval")] + structured_output = StructuredOutputModel(name="test", value=42) + + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + interrupts=interrupts, + structured_output=structured_output, + ) + + message_string = str(result) + + # Should return interrupts, not structured output + assert "int-1" in message_string + assert "approval" in message_string + # Should NOT contain structured output + assert "test" not in message_string or "approval" in message_string # "test" might appear but not from structured + assert '"value": 42' not in message_string + + +def test__str__interrupts_priority_over_text_content(mock_metrics, simple_message: Message): + """Test that interrupts take priority over message text content in str().""" + interrupts = [Interrupt(id="int-1", name="confirm", reason="Please confirm")] + + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + interrupts=interrupts, + ) + + message_string = str(result) + + # Should return interrupts, not message text + assert "int-1" in message_string + assert "confirm" in message_string + assert "Hello world!" not in message_string + + +def test__str__empty_interrupts_returns_agent_message(mock_metrics, simple_message: Message): + """Test that empty interrupts list falls through to other content.""" + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + interrupts=[], + ) + + message_string = str(result) + + # Empty list is falsy, should fall through to text content + assert message_string == "Hello world!\n" + + +def test_context_size_delegates_to_metrics(mock_metrics, simple_message: Message): + """Test that context_size delegates to metrics.latest_context_size.""" + mock_metrics.latest_context_size = 12345 + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + assert result.context_size == 12345 + + +def test_context_size_none_when_no_data(mock_metrics, simple_message: Message): + """Test that context_size returns None when metrics has no data.""" + mock_metrics.latest_context_size = None + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + assert result.context_size is None + + +def test_projected_context_size_delegates_to_metrics(mock_metrics, simple_message: Message): + """Test that projected_context_size delegates to metrics.projected_context_size.""" + mock_metrics.projected_context_size = 15000 + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + assert result.projected_context_size == 15000 + + +def test_projected_context_size_none_when_no_data(mock_metrics, simple_message: Message): + """Test that projected_context_size returns None when metrics has no data.""" + mock_metrics.projected_context_size = None + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + assert result.projected_context_size is None diff --git a/tests/strands/agent/test_agent_retry.py b/tests/strands/agent/test_agent_retry.py new file mode 100644 index 000000000..15757865a --- /dev/null +++ b/tests/strands/agent/test_agent_retry.py @@ -0,0 +1,189 @@ +"""Integration tests for Agent retry_strategy parameter.""" + +from unittest.mock import Mock + +import pytest + +from strands import Agent, ModelRetryStrategy +from strands.event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY +from strands.hooks import AfterModelCallEvent +from strands.types.exceptions import ModelThrottledException +from tests.fixtures.mocked_model_provider import MockedModelProvider + +# Agent Retry Strategy Initialization Tests + + +def test_agent_with_default_retry_strategy(): + """Test that Agent uses ModelRetryStrategy by default when retry_strategy is not provided.""" + agent = Agent() + + # Should have a retry_strategy + assert agent._retry_strategy is not None + + # Should be ModelRetryStrategy with default parameters + assert isinstance(agent._retry_strategy, ModelRetryStrategy) + assert agent._retry_strategy._max_attempts == 6 + assert agent._retry_strategy._initial_delay == 4 + assert agent._retry_strategy._max_delay == 240 + + +def test_agent_with_retry_strategy_none_disables_retries(): + """Test that Agent disables retries when retry_strategy=None is explicitly passed.""" + agent = Agent(retry_strategy=None) + + # Should have a retry_strategy with max_attempts=1 (no retries) + assert agent._retry_strategy is not None + assert isinstance(agent._retry_strategy, ModelRetryStrategy) + assert agent._retry_strategy._max_attempts == 1 + + +def test_agent_with_custom_model_retry_strategy(): + """Test Agent initialization with custom ModelRetryStrategy parameters.""" + custom_strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + agent = Agent(retry_strategy=custom_strategy) + + assert agent._retry_strategy is custom_strategy + assert agent._retry_strategy._max_attempts == 3 + assert agent._retry_strategy._initial_delay == 2 + assert agent._retry_strategy._max_delay == 60 + + +def test_agent_rejects_invalid_retry_strategy_type(): + """Test that Agent raises ValueError for non-ModelRetryStrategy retry_strategy.""" + + class FakeRetryStrategy: + pass + + with pytest.raises(ValueError, match="retry_strategy must be an instance of ModelRetryStrategy"): + Agent(retry_strategy=FakeRetryStrategy()) + + +def test_agent_rejects_subclass_of_model_retry_strategy(): + """Test that Agent rejects subclasses of ModelRetryStrategy (strict type check).""" + + class CustomRetryStrategy(ModelRetryStrategy): + pass + + with pytest.raises(ValueError, match="retry_strategy must be an instance of ModelRetryStrategy"): + Agent(retry_strategy=CustomRetryStrategy()) + + +def test_agent_default_retry_strategy_uses_event_loop_constants(): + """Test that default retry strategy uses constants from event_loop module.""" + agent = Agent() + + assert agent._retry_strategy._max_attempts == MAX_ATTEMPTS + assert agent._retry_strategy._initial_delay == INITIAL_DELAY + assert agent._retry_strategy._max_delay == MAX_DELAY + + +def test_retry_strategy_registered_as_hook(): + """Test that retry_strategy is registered with the hook system.""" + custom_strategy = ModelRetryStrategy(max_attempts=3) + agent = Agent(retry_strategy=custom_strategy) + + # Verify retry strategy callback is registered + callbacks = list(agent.hooks.get_callbacks_for(AfterModelCallEvent(agent=agent, exception=None))) + + # Should have at least one callback (from retry strategy) + assert len(callbacks) > 0 + + # Verify one of the callbacks is from the retry strategy + assert any( + callback.__self__ is custom_strategy if hasattr(callback, "__self__") else False for callback in callbacks + ) + + +# Agent Retry Behavior Tests + + +@pytest.mark.asyncio +async def test_agent_retries_with_default_strategy(mock_sleep): + """Test that Agent retries on throttling with default ModelRetryStrategy.""" + # Create a model that fails twice with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException"), + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success after retries"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have succeeded after retries - just check we got events + assert len(events) > 0 + + # Should have slept twice (for two retries) + assert len(mock_sleep.sleep_calls) == 2 + # First retry: 4 seconds + assert mock_sleep.sleep_calls[0] == 4 + # Second retry: 8 seconds (exponential backoff) + assert mock_sleep.sleep_calls[1] == 8 + + +@pytest.mark.asyncio +async def test_agent_respects_max_attempts(mock_sleep): + """Test that Agent respects max_attempts in retry strategy.""" + # Create a model that always fails + model = Mock() + model.stream.side_effect = ModelThrottledException("ThrottlingException") + + # Use custom strategy with max 2 attempts + custom_strategy = ModelRetryStrategy(max_attempts=2, initial_delay=1, max_delay=60) + agent = Agent(model=model, retry_strategy=custom_strategy) + + with pytest.raises(ModelThrottledException): + result = agent.stream_async("test prompt") + _ = [event async for event in result] + + # Should have attempted max_attempts times, which means (max_attempts - 1) sleeps + # Attempt 0: fail, sleep + # Attempt 1: fail, no more attempts + assert len(mock_sleep.sleep_calls) == 1 + + +# Backwards Compatibility Tests + + +@pytest.mark.asyncio +async def test_event_loop_throttle_event_emitted(mock_sleep): + """Test that EventLoopThrottleEvent is still emitted for backwards compatibility.""" + # Create a model that fails once with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have EventLoopThrottleEvent in the stream + throttle_events = [e for e in events if "event_loop_throttled_delay" in e] + assert len(throttle_events) > 0 + + # Should have the correct delay value + assert throttle_events[0]["event_loop_throttled_delay"] > 0 + + +@pytest.mark.asyncio +async def test_agent_no_retry_when_retry_strategy_none(mock_sleep): + """Test that Agent does not retry when retry_strategy=None.""" + # Create a model that fails with throttling + model = Mock() + model.stream.side_effect = ModelThrottledException("ThrottlingException") + + # Explicitly disable retries + agent = Agent(model=model, retry_strategy=None) + + with pytest.raises(ModelThrottledException): + result = agent.stream_async("test prompt") + _ = [event async for event in result] + + # Should not have slept at all (no retries) + assert len(mock_sleep.sleep_calls) == 0 diff --git a/tests/strands/agent/test_agent_structured_output.py b/tests/strands/agent/test_agent_structured_output.py index b679faed0..6ab112048 100644 --- a/tests/strands/agent/test_agent_structured_output.py +++ b/tests/strands/agent/test_agent_structured_output.py @@ -1,6 +1,5 @@ """Tests for Agent structured output functionality.""" -from typing import Optional from unittest import mock from unittest.mock import Mock, patch @@ -28,7 +27,7 @@ class ProductModel(BaseModel): title: str price: float - description: Optional[str] = None + description: str | None = None @pytest.fixture @@ -412,3 +411,160 @@ async def mock_product_cycle(*args, **kwargs): mock_event_loop.side_effect = mock_product_cycle result2 = agent("Get product", structured_output_model=product_model) assert result2.structured_output is pm + + +class TestAgentStructuredOutputPrompt: + """Test Agent structured_output_prompt functionality.""" + + def test_agent_init_with_structured_output_prompt(self, user_model): + """Test that Agent can be initialized with a structured_output_prompt.""" + custom_prompt = "Please format your response using the schema." + agent = Agent(structured_output_model=user_model, structured_output_prompt=custom_prompt) + + assert agent._structured_output_prompt == custom_prompt + + def test_agent_init_without_structured_output_prompt(self): + """Test that Agent can be initialized without structured_output_prompt.""" + agent = Agent() + + assert agent._structured_output_prompt is None + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_with_default_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test Agent.__call__ uses default structured_output_prompt when not specified.""" + custom_prompt = "Use the output schema to format your response." + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_prompt == custom_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent with default structured_output_prompt + agent = Agent( + model=mock_model, + structured_output_model=user_model, + structured_output_prompt=custom_prompt, + ) + agent("Get user info") + + mock_event_loop.assert_called_once() + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_override_default_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test that invocation-level structured_output_prompt overrides default.""" + default_prompt = "Default prompt for structured output." + override_prompt = "Override prompt for this specific call." + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + # Should use override_prompt, not the default + assert structured_output_context.structured_output_prompt == override_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent with default prompt, but override at call time + agent = Agent( + model=mock_model, + structured_output_model=user_model, + structured_output_prompt=default_prompt, + ) + agent("Get user info", structured_output_prompt=override_prompt) + + mock_event_loop.assert_called_once() + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_with_invocation_prompt_no_default(self, mock_event_loop, user_model, mock_model, mock_metrics): + """Test that invocation-level prompt works when no default is set.""" + invocation_prompt = "Format as structured output now." + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context.structured_output_prompt == invocation_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent without default prompt + agent = Agent(model=mock_model, structured_output_model=user_model) + agent("Get user info", structured_output_prompt=invocation_prompt) + + mock_event_loop.assert_called_once() + + @pytest.mark.asyncio + @patch("strands.agent.agent.event_loop_cycle") + async def test_agent_invoke_async_with_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test Agent.invoke_async with structured_output_prompt.""" + custom_prompt = "Async prompt for structured output." + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context.structured_output_prompt == custom_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model, structured_output_model=user_model) + await agent.invoke_async("Get user", structured_output_prompt=custom_prompt) + + mock_event_loop.assert_called_once() + + @pytest.mark.asyncio + @patch("strands.agent.agent.event_loop_cycle") + async def test_agent_stream_async_with_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test Agent.stream_async with structured_output_prompt.""" + custom_prompt = "Stream async prompt for structured output." + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context.structured_output_prompt == custom_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model, structured_output_model=user_model) + async for _ in agent.stream_async("Get user", structured_output_prompt=custom_prompt): + pass + + mock_event_loop.assert_called_once() diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 77d7dcce8..df748241e 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -1,9 +1,16 @@ +from unittest.mock import MagicMock, patch + import pytest +from strands import tool from strands.agent.agent import Agent +from strands.agent.conversation_manager.conversation_manager import ConversationManager from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.hooks.events import BeforeModelCallEvent +from strands.hooks.registry import HookProvider, HookRegistry from strands.types.exceptions import ContextWindowOverflowException +from tests.fixtures.mocked_model_provider import MockedModelProvider @pytest.fixture @@ -72,6 +79,7 @@ def conversation_manager(request): ], ), # 5 - Remove dangling assistant message with tool use and user message without tool result + # Must start with a user message, so we skip the assistant message ( {"window_size": 3}, [ @@ -81,7 +89,6 @@ def conversation_manager(request): {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, ], [ - {"role": "assistant", "content": [{"text": "First response"}]}, {"role": "user", "content": [{"text": "Use a tool"}]}, {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, ], @@ -101,19 +108,22 @@ def conversation_manager(request): ], ), # 7 - Message count above max window size - Preserve tool use/tool result pairs + # Cannot start with assistant or orphaned toolResult, so trim advances to next plain user message ( {"window_size": 2}, [ - {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, + {"role": "user", "content": [{"text": "Hello"}]}, {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, + {"role": "assistant", "content": [{"text": "Done"}]}, + {"role": "user", "content": [{"text": "Next"}]}, ], [ - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, + {"role": "user", "content": [{"text": "Next"}]}, ], ), # 8 - Test sliding window behavior - preserve tool use/result pairs across cut boundary + # Must start with user message (not assistant, not orphaned toolResult), so trim advances to plain user msg ( {"window_size": 3}, [ @@ -121,14 +131,14 @@ def conversation_manager(request): {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Response after tool use"}]}, + {"role": "user", "content": [{"text": "Follow up"}]}, ], [ - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"text": "Response after tool use"}]}, + {"role": "user", "content": [{"text": "Follow up"}]}, ], ), # 9 - Test sliding window with multiple tool pairs that need preservation + # Must start with user message; orphaned toolResult is skipped, lands on plain user text ( {"window_size": 4}, [ @@ -138,11 +148,10 @@ def conversation_manager(request): {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Final response"}]}, + {"role": "user", "content": [{"text": "Another question"}]}, ], [ - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"text": "Final response"}]}, + {"role": "user", "content": [{"text": "Another question"}]}, ], ), ], @@ -155,6 +164,43 @@ def test_apply_management(conversation_manager, messages, expected_messages): assert messages == expected_messages +def test_sliding_window_forces_user_message_start(): + """Test that trimmed conversation always starts with a user message (GitHub #2085).""" + manager = SlidingWindowConversationManager(window_size=3, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + {"role": "user", "content": [{"text": "How are you?"}]}, + {"role": "assistant", "content": [{"text": "Good"}]}, + {"role": "user", "content": [{"text": "Great"}]}, + ] + test_agent = Agent(messages=messages) + manager.apply_management(test_agent) + + assert len(messages) == 3 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [{"text": "How are you?"}] + + +def test_sliding_window_happy_path_preserves_window_size(): + """In a typical user/assistant conversation, trimming preserves close to window_size messages.""" + manager = SlidingWindowConversationManager(window_size=4, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "user", "content": [{"text": "Second"}]}, + {"role": "assistant", "content": [{"text": "Second response"}]}, + {"role": "user", "content": [{"text": "Third"}]}, + {"role": "assistant", "content": [{"text": "Third response"}]}, + ] + test_agent = Agent(messages=messages) + manager.apply_management(test_agent) + + assert len(messages) == 4 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [{"text": "Second"}] + + def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): manager = SlidingWindowConversationManager(1, False) messages = [ @@ -165,47 +211,135 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con test_agent = Agent(messages=messages) with pytest.raises(ContextWindowOverflowException): - manager.apply_management(test_agent) + manager.reduce_context(test_agent, e=RuntimeError("context overflow")) + + assert messages == original_messages + + +def test_sliding_window_no_valid_trim_point_without_error_does_not_raise(): + """When no valid trim point exists during routine management (no error), messages are left unchanged.""" + manager = SlidingWindowConversationManager(1, False) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, + ] + original_messages = messages.copy() + test_agent = Agent(messages=messages) + + manager.apply_management(test_agent) assert messages == original_messages +def test_sliding_window_tool_heavy_conversation_falls_back_to_tool_pair_boundary(): + """Tool-heavy conversations trim to assistant(toolUse) + user(toolResult) boundary.""" + manager = SlidingWindowConversationManager(window_size=4, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "Review this PR"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "get_diff", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": "diff"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "2", "name": "get_file", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "2", "content": [{"text": "file"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "3", "name": "get_tree", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "3", "content": [{"text": "tree"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "Here is my review"}]}, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent, e=Exception("context window overflow")) + + # Should trim to first assistant(toolUse) + user(toolResult) pair after trim_index + # With 8 messages and window_size=4, trim_index starts at 4. First fallback at index 5 (toolUseId "3"). + assert len(messages) == 3 + assert messages[0]["role"] == "assistant" + assert messages[0]["content"][0]["toolUse"]["toolUseId"] == "3" + assert messages[1]["role"] == "user" + assert any("toolResult" in content for content in messages[1]["content"]) + + +def test_sliding_window_prefers_plain_user_message_over_tool_pair_fallback(): + """Plain user messages are preferred over assistant+toolResult fallback when both exist.""" + manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": "result"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Plain user message"}]}, + {"role": "assistant", "content": [{"text": "Final response"}]}, + ] + test_agent = Agent(messages=messages) + + manager.apply_management(test_agent) + + # Should prefer the plain user message, not the assistant+toolResult fallback + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [{"text": "Plain user message"}] + + def test_sliding_window_conversation_manager_with_tool_results_truncated(): + large_text = "A" * 300 + "B" * 300 + "C" * 300 manager = SlidingWindowConversationManager(1) messages = [ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, { "role": "user", - "content": [ - {"toolResult": {"toolUseId": "789", "content": [{"text": "large input"}], "status": "success"}} - ], + "content": [{"toolResult": {"toolUseId": "789", "content": [{"text": large_text}], "status": "success"}}], }, ] test_agent = Agent(messages=messages) - manager.reduce_context(test_agent) + manager.reduce_context(test_agent, e=RuntimeError("context overflow")) - expected_messages = [ + result_text = messages[1]["content"][0]["toolResult"]["content"][0]["text"] + assert result_text.startswith("A" * 200) + assert result_text.endswith("C" * 200) + assert "... [truncated:" in result_text + # Status must NOT be changed to error + assert messages[1]["content"][0]["toolResult"]["status"] == "success" + + +def test_sliding_window_proactive_compression_skips_tool_result_truncation(): + """Proactive compression (e=None) should only trim messages, not truncate tool results.""" + large_text = "A" * 300 + "B" * 300 + "C" * 300 + manager = SlidingWindowConversationManager(window_size=2) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, { "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "789", - "content": [{"text": "The tool result was too large!"}], - "status": "error", - } - } - ], + "content": [{"toolResult": {"toolUseId": "789", "content": [{"text": large_text}], "status": "success"}}], }, + {"role": "assistant", "content": [{"text": "Done"}]}, + {"role": "user", "content": [{"text": "Next question"}]}, ] + test_agent = Agent(messages=messages) - assert messages == expected_messages + manager.reduce_context(test_agent) # e=None (proactive) + # Tool results should NOT be truncated during proactive compression + for msg in messages: + for content in msg.get("content", []): + if "toolResult" in content: + for item in content["toolResult"].get("content", []): + if "text" in item: + assert "... [truncated:" not in item["text"] -def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): - """Test that NullConversationManager doesn't modify messages.""" + +def test_null_conversation_manager_reduce_context_proactive_returns_silently(): + """Proactive compression (e=None) returns silently without raising.""" manager = NullConversationManager() messages = [ {"role": "user", "content": [{"text": "Hello"}]}, @@ -216,12 +350,25 @@ def test_null_conversation_manager_reduce_context_raises_context_window_overflow manager.apply_management(test_agent) - with pytest.raises(ContextWindowOverflowException): - manager.reduce_context(messages) + # Proactive call (e=None) should not raise + manager.reduce_context(test_agent) assert messages == original_messages +def test_null_conversation_manager_reduce_context_reactive_raises_overflow(): + """Reactive overflow (e is not None) re-raises the exception.""" + manager = NullConversationManager() + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + test_agent = Agent(messages=messages) + + with pytest.raises(ContextWindowOverflowException): + manager.reduce_context(test_agent, e=ContextWindowOverflowException("overflow")) + + def test_null_conversation_manager_reduce_context_with_exception_raises_same_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = NullConversationManager() @@ -246,3 +393,646 @@ def test_null_conversation_does_not_restore_with_incorrect_state(): with pytest.raises(ValueError): manager.restore_from_session({}) + + +# ============================================================================== +# Per-Turn Management Tests +# ============================================================================== + + +def test_per_turn_parameter_validation(): + """Test per_turn parameter validation.""" + # Valid values + assert SlidingWindowConversationManager(per_turn=False).per_turn is False + assert SlidingWindowConversationManager(per_turn=True).per_turn is True + assert SlidingWindowConversationManager(per_turn=3).per_turn == 3 + + +def test_per_turn_zero_raises_value_error(): + with pytest.raises(ValueError, match="per_turn"): + SlidingWindowConversationManager(per_turn=0) + + +def test_per_turn_negative_raises_value_error(): + with pytest.raises(ValueError, match="per_turn"): + SlidingWindowConversationManager(per_turn=-5) + + +def test_conversation_manager_is_hook_provider(): + """Test that ConversationManager implements HookProvider protocol.""" + manager = NullConversationManager() + assert isinstance(manager, HookProvider) + + +def test_derived_class_does_not_need_to_implement_register_hooks(): + """Test that derived classes don't need to override register_hooks for backwards compatibility.""" + from strands.agent.conversation_manager.conversation_manager import ConversationManager + + class MinimalConversationManager(ConversationManager): + """A minimal implementation that only implements abstract methods.""" + + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + pass + + # Should be able to instantiate without implementing register_hooks + manager = MinimalConversationManager() + registry = HookRegistry() + + # Should work without error — the base class always registers the hook + manager.register_hooks(registry) + # Base class always registers the proactive compression hook + assert registry.has_callbacks() + + +def test_per_turn_hooks_registration(): + """Test that hooks are registered when conversation_manager implements HookProvider.""" + manager = SlidingWindowConversationManager(per_turn=True) + assert isinstance(manager, HookProvider) + + registry = HookRegistry() + manager.register_hooks(registry) + assert registry.has_callbacks() + + +def test_per_turn_false_no_management_during_loop(): + """Test that per_turn=False only manages in finally block.""" + manager = SlidingWindowConversationManager(per_turn=False, window_size=100) + responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3 + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: + agent("Test") + # Should only be called once in finally block (per_turn disabled) + assert mock.call_count == 1 + + +def test_per_turn_true_manages_each_model_call(): + """Test that per_turn=True applies management before each model call.""" + manager = SlidingWindowConversationManager(per_turn=True, window_size=100) + responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3 + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: + agent("Test") + # Should be called for each model call + finally block + # With simple text responses, agent makes 1 model call then stops + assert mock.call_count >= 1 + + +def test_per_turn_integer_manages_every_n_calls(): + """Test that per_turn=N applies management every N model calls.""" + manager = SlidingWindowConversationManager(per_turn=2, window_size=100) + # Create responses that trigger multiple model calls + responses = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": f"{i}", "name": "test", "input": {}}}]} + for i in range(5) + ] + [{"role": "assistant", "content": [{"text": "Done"}]}] + model = MockedModelProvider(responses) + + @tool(name="test") + def test_tool(query: str = "") -> str: + return "result" + + agent = Agent(model=model, conversation_manager=manager, tools=[test_tool]) + + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: + agent("Test") + # With 6 model calls and per_turn=2: called on 2nd, 4th, 6th + finally + assert mock.call_count == 4 + + +def test_per_turn_dynamic_change(): + """Test that per_turn can be changed dynamically.""" + manager = SlidingWindowConversationManager(per_turn=False) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [] + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) + + # Initially disabled + with patch.object(manager, "apply_management") as mock_apply: + registry.invoke_callbacks(event) + assert mock_apply.call_count == 0 + + # Enable dynamically + manager.per_turn = True + with patch.object(manager, "apply_management") as mock_apply: + registry.invoke_callbacks(event) + assert mock_apply.call_count == 1 + + +def test_per_turn_reduces_message_count(): + """Test that per_turn actually reduces message count during execution.""" + manager = SlidingWindowConversationManager(per_turn=1, window_size=4) + responses = [{"role": "assistant", "content": [{"text": f"Response {i}"}]} for i in range(10)] + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + + message_counts = [] + original_apply = manager.apply_management + + def track_apply(agent_instance): + message_counts.append(len(agent_instance.messages)) + return original_apply(agent_instance) + + with patch.object(manager, "apply_management", side_effect=track_apply): + agent("Test") + + # Verify message count stayed around window_size + assert any(count <= manager.window_size for count in message_counts) + + +def test_per_turn_state_persistence(): + """Test that model_call_count is persisted in state.""" + manager = SlidingWindowConversationManager(per_turn=3) + manager._model_call_count = 7 + + state = manager.get_state() + assert state["model_call_count"] == 7 + + new_manager = SlidingWindowConversationManager(per_turn=3) + new_manager.restore_from_session(state) + assert new_manager._model_call_count == 7 + + +def test_per_turn_backward_compatibility(): + """Test that existing code without per_turn still works.""" + manager = SlidingWindowConversationManager(window_size=40) + assert manager.per_turn is False + + responses = [{"role": "assistant", "content": [{"text": "Hello"}]}] + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + result = agent("Hello") + assert result is not None + + +# ============================================================================== +# Improved Truncation Strategy Tests +# ============================================================================== + + +def test_truncation_targets_oldest_message_first(): + """Oldest message with tool results is truncated before newer ones.""" + large_text = "X" * 20000 + manager = SlidingWindowConversationManager(window_size=10) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": large_text}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "2", "name": "tool2", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "2", "content": [{"text": large_text}], "status": "success"}}], + }, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent, e=RuntimeError("context overflow")) + + # The oldest tool result (index 1) must be truncated + oldest_text = messages[1]["content"][0]["toolResult"]["content"][0]["text"] + assert "... [truncated:" in oldest_text + + # The newest tool result (index 3) must remain untouched after the first reduce_context call + newest_text = messages[3]["content"][0]["toolResult"]["content"][0]["text"] + assert "... [truncated:" not in newest_text + + +def test_large_tool_result_partially_truncated_with_context_preserved(): + """Large tool results are truncated in the middle while the beginning and end are preserved.""" + preserve = 200 # matches _PRESERVE_CHARS + # Build text with distinct prefix, middle, and suffix + prefix_text = "P" * preserve + middle_text = "M" * 500 + suffix_text = "S" * preserve + large_text = prefix_text + middle_text + suffix_text + + manager = SlidingWindowConversationManager(window_size=10) + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": large_text}], "status": "success"}}], + } + ] + + truncated = manager._truncate_tool_results(messages, 0) + + assert truncated + result_text = messages[0]["content"][0]["toolResult"]["content"][0]["text"] + assert result_text.startswith(prefix_text) + assert result_text.endswith(suffix_text) + assert "... [truncated:" in result_text + removed = len(large_text) - 2 * preserve + assert f"... [truncated: {removed} chars removed] ..." in result_text + + +def test_truncation_does_not_change_status_to_error(): + """Partial truncation must not change the tool result status.""" + large_text = "Z" * 15000 + manager = SlidingWindowConversationManager(window_size=10) + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": large_text}], "status": "success"}}], + } + ] + + manager._truncate_tool_results(messages, 0) + + assert messages[0]["content"][0]["toolResult"]["status"] == "success" + + +def test_image_blocks_inside_tool_result_replaced_with_placeholder(): + """Image blocks nested inside toolResult content are replaced with a text placeholder.""" + manager = SlidingWindowConversationManager(window_size=10) + image_data = b"base64encodeddata" + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [ + {"text": "some text"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": image_data}, + } + }, + ], + "status": "success", + } + } + ], + } + ] + + changed = manager._truncate_tool_results(messages, 0) + + assert changed + tool_result_items = messages[0]["content"][0]["toolResult"]["content"] + assert not any(isinstance(item, dict) and "image" in item for item in tool_result_items) + expected_placeholder = f"[image: jpeg, {len(image_data)} bytes]" + assert any(isinstance(item, dict) and item.get("text") == expected_placeholder for item in tool_result_items) + + +def test_already_truncated_text_not_truncated_again(): + """A text block that already contains the truncation marker is not truncated a second time.""" + manager = SlidingWindowConversationManager(window_size=10) + already_truncated = "A" * 200 + "...\n\n... [truncated: 990 chars removed] ...\n\n..." + "Z" * 200 + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": already_truncated}], + "status": "success", + } + } + ], + } + ] + + changed = manager._truncate_tool_results(messages, 0) + + assert not changed + assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == already_truncated + + +def test_short_text_in_tool_result_not_truncated(): + """Text no longer than 2 * _PRESERVE_CHARS must not be modified.""" + manager = SlidingWindowConversationManager(window_size=10) + short_text = "X" * 100 # 100 < 2 * 200 + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": short_text}], "status": "success"}}], + } + ] + + changed = manager._truncate_tool_results(messages, 0) + + assert not changed + assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == short_text + + +def test_boundary_text_in_tool_result_not_truncated(): + """Text of exactly 2 * _PRESERVE_CHARS must not be truncated.""" + manager = SlidingWindowConversationManager(window_size=10) + boundary_text = "X" * 400 # exactly 2 * 200 + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": boundary_text}], "status": "success"}}], + } + ] + + changed = manager._truncate_tool_results(messages, 0) + + assert not changed + assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == boundary_text + + +# ============================================================================== +# window_size=0 and negative window_size validation tests +# ============================================================================== + + +def test_window_size_negative_raises_value_error(): + with pytest.raises(ValueError, match="window_size"): + SlidingWindowConversationManager(window_size=-1) + + +def test_window_size_zero_clears_all_messages_on_apply_management(): + """window_size=0 should remove all messages, matching TypeScript SDK behaviour (issue #2205).""" + manager = SlidingWindowConversationManager(window_size=0, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + test_agent = Agent(messages=messages) + manager.apply_management(test_agent) + + assert messages == [] + assert manager.removed_message_count == 2 + + +def test_window_size_zero_clears_all_messages_on_reduce_context(): + """reduce_context with window_size=0 should clear all messages even without overflow.""" + manager = SlidingWindowConversationManager(window_size=0, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Second"}]}, + {"role": "user", "content": [{"text": "Third"}]}, + ] + test_agent = Agent(messages=messages) + manager.reduce_context(test_agent) + + assert messages == [] + assert manager.removed_message_count == 3 + + +def test_window_size_zero_clears_on_overflow(): + """reduce_context with window_size=0 should clear messages even when called with an overflow exception.""" + manager = SlidingWindowConversationManager(window_size=0, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + ] + test_agent = Agent(messages=messages) + manager.reduce_context(test_agent, e=Exception("overflow")) + + assert messages == [] + + +# ============================================================================== +# Proactive Compression Tests (proactive_compression parameter) +# ============================================================================== + + +class _MinimalManager(ConversationManager): + """Manager that only implements abstract methods.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.reduce_context_call_count = 0 + + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + self.reduce_context_call_count += 1 + if agent.messages: + agent.messages.pop(0) + + +def _make_mock_agent(messages=None, context_window_limit=1000): + agent = MagicMock() + agent.messages = messages if messages is not None else [] + agent.model = MagicMock() + agent.model.context_window_limit = context_window_limit + return agent + + +def _make_threshold_event(agent, projected_input_tokens=None): + return BeforeModelCallEvent( + agent=agent, + invocation_state={}, + projected_input_tokens=projected_input_tokens, + ) + + +def test_proactive_compression_rejects_zero(): + with pytest.raises(ValueError, match="compression_threshold must be between 0"): + _MinimalManager(proactive_compression={"compression_threshold": 0}) + + +def test_proactive_compression_rejects_negative(): + with pytest.raises(ValueError, match="compression_threshold must be between 0"): + _MinimalManager(proactive_compression={"compression_threshold": -0.5}) + + +def test_proactive_compression_rejects_greater_than_one(): + with pytest.raises(ValueError, match="compression_threshold must be between 0"): + _MinimalManager(proactive_compression={"compression_threshold": 1.5}) + + +def test_proactive_compression_accepts_exactly_one(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 1.0}) + assert manager._compression_threshold == 1.0 + + +def test_proactive_compression_none_by_default(): + manager = _MinimalManager() + assert manager._compression_threshold is None + + +def test_proactive_compression_true_uses_default_threshold(): + """proactive_compression=True uses default threshold of 0.7.""" + manager = _MinimalManager(proactive_compression=True) + assert manager._compression_threshold == 0.7 + + +def test_proactive_compression_false_disables(): + """proactive_compression=False means no compression.""" + manager = _MinimalManager(proactive_compression=False) + assert manager._compression_threshold is None + + +def test_proactive_compression_always_registers_hook(): + """Hook is always registered regardless of proactive_compression setting.""" + manager = _MinimalManager() + registry = HookRegistry() + manager.register_hooks(registry) + # Always registers the hook + assert registry.has_callbacks() + + +def test_proactive_compression_hook_is_noop_when_not_configured(): + """BeforeModelCallEvent handler is a no-op when proactive_compression is not set.""" + manager = _MinimalManager() + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=900) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 0 + + +def test_proactive_compression_calls_reduce_context_when_exceeded(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(messages=[{"role": "user", "content": [{"text": "msg"}]}], context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 1 + + +def test_proactive_compression_no_call_when_below(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=500) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 0 + + +def test_proactive_compression_no_call_when_projected_tokens_none(): + manager = _MinimalManager(proactive_compression=True) + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=None) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 0 + + +def test_proactive_compression_uses_default_when_context_window_limit_not_set(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=None) + registry = HookRegistry() + manager.register_hooks(registry) + + # projected_input_tokens=150_000 is 75% of the 200k default, exceeding 0.7 threshold + event = _make_threshold_event(agent, projected_input_tokens=150_000) + with patch("strands.agent.conversation_manager.conversation_manager.logger") as mock_logger: + registry.invoke_callbacks(event) + mock_logger.warning.assert_called_once() + assert "using default" in mock_logger.warning.call_args[0][0] + + assert manager.reduce_context_call_count == 1 + + +def test_proactive_compression_warns_only_once_per_instance(): + """Second invocation on the same manager instance suppresses the context_window_limit warning.""" + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=None) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=150_000) + with patch("strands.agent.conversation_manager.conversation_manager.logger") as mock_logger: + registry.invoke_callbacks(event) + registry.invoke_callbacks(event) + assert mock_logger.warning.call_count == 1 + + +def test_proactive_compression_exception_swallowed(): + """Exceptions in reduce_context during proactive compression should not propagate.""" + + class _FailingManager(ConversationManager): + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + raise RuntimeError("boom") + + manager = _FailingManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event) + + +def test_proactive_compression_true_default_threshold_behavior(): + """proactive_compression=True uses 0.7 — triggered at 0.7+ but not below.""" + manager = _MinimalManager(proactive_compression=True) + agent = _make_mock_agent( + messages=[{"role": "user", "content": [{"text": "msg"}]}], context_window_limit=1000 + ) + registry = HookRegistry() + manager.register_hooks(registry) + + # 650/1000 = 0.65 < 0.7 — should NOT trigger + event = _make_threshold_event(agent, projected_input_tokens=650) + registry.invoke_callbacks(event) + assert manager.reduce_context_call_count == 0 + + # 800/1000 = 0.8 >= 0.7 — should trigger + event2 = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event2) + assert manager.reduce_context_call_count == 1 + + +def test_sliding_window_proactive_compression_trims(): + manager = SlidingWindowConversationManager( + window_size=4, should_truncate_results=False, proactive_compression={"compression_threshold": 0.7} + ) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(6) + ] + agent = _make_mock_agent(messages=messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event) + + assert len(agent.messages) == 4 + + +def test_sliding_window_proactive_compression_no_trim_below(): + manager = SlidingWindowConversationManager( + window_size=4, should_truncate_results=False, proactive_compression={"compression_threshold": 0.7} + ) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + ] + agent = _make_mock_agent(messages=messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=500) + registry.invoke_callbacks(event) + + assert len(agent.messages) == 2 diff --git a/tests/strands/agent/test_retry.py b/tests/strands/agent/test_retry.py new file mode 100644 index 000000000..830c1b5b8 --- /dev/null +++ b/tests/strands/agent/test_retry.py @@ -0,0 +1,328 @@ +"""Unit tests for retry strategy implementations.""" + +from unittest.mock import Mock + +import pytest + +from strands import ModelRetryStrategy +from strands.hooks import AfterInvocationEvent, AfterModelCallEvent, HookRegistry +from strands.types._events import EventLoopThrottleEvent +from strands.types.exceptions import ModelThrottledException + +# ModelRetryStrategy Tests + + +def test_model_retry_strategy_init_with_defaults(): + """Test ModelRetryStrategy initialization with default parameters.""" + strategy = ModelRetryStrategy() + assert strategy._max_attempts == 6 + assert strategy._initial_delay == 4 + assert strategy._max_delay == 240 + assert strategy._current_attempt == 0 + + +def test_model_retry_strategy_init_with_custom_parameters(): + """Test ModelRetryStrategy initialization with custom parameters.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + assert strategy._max_attempts == 3 + assert strategy._initial_delay == 2 + assert strategy._max_delay == 60 + assert strategy._current_attempt == 0 + + +def test_model_retry_strategy_calculate_delay_with_different_attempts(): + """Test _calculate_delay returns correct exponential backoff for different attempt numbers.""" + strategy = ModelRetryStrategy(initial_delay=2, max_delay=32) + + # Test exponential backoff: 2 * (2^attempt) + assert strategy._calculate_delay(0) == 2 # 2 * 2^0 = 2 + assert strategy._calculate_delay(1) == 4 # 2 * 2^1 = 4 + assert strategy._calculate_delay(2) == 8 # 2 * 2^2 = 8 + assert strategy._calculate_delay(3) == 16 # 2 * 2^3 = 16 + assert strategy._calculate_delay(4) == 32 # 2 * 2^4 = 32 (at max) + assert strategy._calculate_delay(5) == 32 # 2 * 2^5 = 64, capped at 32 + assert strategy._calculate_delay(10) == 32 # Large attempt, still capped + + +def test_model_retry_strategy_calculate_delay_respects_max_delay(): + """Test _calculate_delay respects max_delay cap.""" + strategy = ModelRetryStrategy(initial_delay=10, max_delay=50) + + assert strategy._calculate_delay(0) == 10 # 10 * 2^0 = 10 + assert strategy._calculate_delay(1) == 20 # 10 * 2^1 = 20 + assert strategy._calculate_delay(2) == 40 # 10 * 2^2 = 40 + assert strategy._calculate_delay(3) == 50 # 10 * 2^3 = 80, capped at 50 + assert strategy._calculate_delay(4) == 50 # 10 * 2^4 = 160, capped at 50 + + +def test_model_retry_strategy_register_hooks(): + """Test that ModelRetryStrategy registers AfterModelCallEvent and AfterInvocationEvent callbacks.""" + strategy = ModelRetryStrategy() + registry = HookRegistry() + + strategy.register_hooks(registry) + + # Verify AfterModelCallEvent callback was registered + assert AfterModelCallEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + # Verify AfterInvocationEvent callback was registered + assert AfterInvocationEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterInvocationEvent]) == 1 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_retry_on_throttle_exception_first_attempt(mock_sleep): + """Test retry behavior on first ModelThrottledException.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should set retry to True + assert event.retry is True + # Should sleep for initial_delay (attempt 0: 2 * 2^0 = 2) + assert mock_sleep.sleep_calls == [2] + assert mock_sleep.sleep_calls[0] == strategy._calculate_delay(0) + # Should increment attempt + assert strategy._current_attempt == 1 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_exponential_backoff(mock_sleep): + """Test exponential backoff calculation.""" + strategy = ModelRetryStrategy(max_attempts=5, initial_delay=2, max_delay=16) + mock_agent = Mock() + + # Simulate multiple retries + for _ in range(4): + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event) + assert event.retry is True + + # Verify exponential backoff with max_delay cap + # attempt 0: 2*2^0=2, attempt 1: 2*2^1=4, attempt 2: 2*2^2=8, attempt 3: 2*2^3=16 (capped) + assert mock_sleep.sleep_calls == [2, 4, 8, 16] + for i, sleep_delay in enumerate(mock_sleep.sleep_calls): + assert sleep_delay == strategy._calculate_delay(i) + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_after_max_attempts(mock_sleep): + """Test that retry is not set after reaching max_attempts.""" + strategy = ModelRetryStrategy(max_attempts=2, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First attempt + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + + # Second attempt (at max_attempts) + event2 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event2) + # Should NOT retry after reaching max_attempts + assert event2.retry is False + assert strategy._current_attempt == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_non_throttle_exception(): + """Test that retry is not set for non-throttling exceptions.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ValueError("Some other error"), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on non-throttling exceptions + assert event.retry is False + assert strategy._current_attempt == 0 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_success(): + """Test that retry is not set when model call succeeds.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on success + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_retry_strategy_reset_on_success(mock_sleep): + """Test that strategy resets attempt counter on successful call.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First failure + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + # Should sleep for initial_delay (attempt 0: 2 * 2^0 = 2) + assert mock_sleep.sleep_calls == [2] + assert mock_sleep.sleep_calls[0] == strategy._calculate_delay(0) + + # Success - should reset + event2 = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + await strategy._handle_after_model_call(event2) + assert event2.retry is False + # Should reset to initial state + assert strategy._current_attempt == 0 + assert strategy._calculate_delay(0) == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_skips_if_already_retrying(): + """Test that strategy skips processing if event.retry is already True.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + # Simulate another hook already set retry to True + event.retry = True + + await strategy._handle_after_model_call(event) + + # Should not modify state since another hook already triggered retry + assert strategy._current_attempt == 0 + assert event.retry is True + + +@pytest.mark.asyncio +async def test_model_retry_strategy_reset_on_after_invocation(): + """Test that strategy resets state on AfterInvocationEvent.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # Simulate some retry attempts + strategy._current_attempt = 3 + + event = AfterInvocationEvent(agent=mock_agent, result=Mock()) + await strategy._handle_after_invocation(event) + + # Should reset to initial state + assert strategy._current_attempt == 0 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_set_on_retry(mock_sleep): + """Test that _backwards_compatible_event_to_yield is set when retrying.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should have set the backwards compatible event + assert strategy._backwards_compatible_event_to_yield is not None + assert isinstance(strategy._backwards_compatible_event_to_yield, EventLoopThrottleEvent) + assert strategy._backwards_compatible_event_to_yield["event_loop_throttled_delay"] == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_cleared_on_success(): + """Test that _backwards_compatible_event_to_yield is cleared on success.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # Set a previous backwards compatible event + strategy._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=2) + + event = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + + await strategy._handle_after_model_call(event) + + # Should have cleared the backwards compatible event + assert strategy._backwards_compatible_event_to_yield is None + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_not_set_on_max_attempts(mock_sleep): + """Test that _backwards_compatible_event_to_yield is not set when max attempts reached.""" + strategy = ModelRetryStrategy(max_attempts=1, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should not have set the backwards compatible event since max attempts reached + assert strategy._backwards_compatible_event_to_yield is None + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_when_no_exception_and_no_stop_response(): + """Test that retry is not set when there's no exception and no stop_response.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + # Event with neither exception nor stop_response + event = AfterModelCallEvent( + agent=mock_agent, + exception=None, + stop_response=None, + ) + + await strategy._handle_after_model_call(event) + + # Should not retry and should reset state + assert event.retry is False + assert strategy._current_attempt == 0 diff --git a/tests/strands/agent/test_snapshot.py b/tests/strands/agent/test_snapshot.py new file mode 100644 index 000000000..50e83a484 --- /dev/null +++ b/tests/strands/agent/test_snapshot.py @@ -0,0 +1,453 @@ +"""Tests for _snapshot.py — Snapshot dataclass and resolve_snapshot_fields.""" + +import json +import re +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from strands import Agent +from strands.types._snapshot import ( + ALL_SNAPSHOT_FIELDS, + SNAPSHOT_PRESETS, + SNAPSHOT_SCHEMA_VERSION, + VALID_SCOPES, + Snapshot, + resolve_snapshot_fields, +) +from strands.types.exceptions import SnapshotException + +# Helpers + +ISO_8601_UTC_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?Z$") + + +def _make_snapshot(**kwargs: object) -> Snapshot: + defaults: dict[str, Any] = { + "scope": "agent", + "schema_version": SNAPSHOT_SCHEMA_VERSION, + "created_at": "2025-01-15T12:00:00.000000Z", + "data": {}, + "app_data": {}, + } + defaults.update(kwargs) + return Snapshot(**defaults) + + +def _make_agent(**kwargs) -> Agent: + """Create a minimal Agent with a mock model for testing.""" + mock_model = MagicMock() + mock_model.get_config.return_value = {} + return Agent(model=mock_model, callback_handler=None, **kwargs) + + +def test_snapshot_from_dict_bad_version_raises(): + d = {"schema_version": "99.0", "created_at": "2025-01-15T12:00:00Z", "data": {}, "app_data": {}} + with pytest.raises(SnapshotException, match="Unsupported snapshot schema version"): + Snapshot.from_dict(d) + + +def test_snapshot_to_dict_round_trip(): + s = _make_snapshot(data={"messages": []}, app_data={"x": 1}) + assert Snapshot.from_dict(s.to_dict()) == s + + +def test_resolve_snapshot_fields_invalid_include_raises(): + with pytest.raises(SnapshotException, match="Invalid snapshot field"): + resolve_snapshot_fields(include=["not_a_field"]) # type: ignore[list-item] + + +def test_resolve_snapshot_fields_invalid_exclude_raises(): + with pytest.raises(SnapshotException, match="Invalid snapshot field"): + resolve_snapshot_fields(preset="session", exclude=["not_a_field"]) # type: ignore[list-item] + + +def test_resolve_snapshot_fields_no_preset_no_include_raises(): + with pytest.raises(SnapshotException, match="No snapshot fields resolved"): + resolve_snapshot_fields() + + +def test_resolve_snapshot_fields_session_preset(): + assert resolve_snapshot_fields(preset="session") == set(SNAPSHOT_PRESETS["session"]) + + +def test_resolve_snapshot_fields_include_adds_to_preset(): + fields = resolve_snapshot_fields(preset="session", include=["system_prompt"]) + assert fields == set(SNAPSHOT_PRESETS["session"]) | {"system_prompt"} + + +def test_resolve_snapshot_fields_exclude_removes_from_preset(): + fields = resolve_snapshot_fields(preset="session", exclude=["messages"]) + assert "messages" not in fields + + +def test_resolve_snapshot_fields_all_excluded_raises(): + with pytest.raises(SnapshotException): + resolve_snapshot_fields(exclude=list(ALL_SNAPSHOT_FIELDS)) # type: ignore[list-item] + + +_ORDERING_CASES = [ + # (preset, include, exclude) + ("session", [], []), + ("session", ["system_prompt"], []), + ("session", [], ["messages"]), + ("session", ["system_prompt"], ["messages", "state"]), + (None, ["messages", "state"], []), + (None, list(ALL_SNAPSHOT_FIELDS), []), + (None, list(ALL_SNAPSHOT_FIELDS), ["system_prompt"]), + ("session", ["system_prompt"], list(SNAPSHOT_PRESETS["session"])), # exclude all preset → only system_prompt +] + + +@pytest.mark.parametrize("preset,include,exclude", _ORDERING_CASES) +def test_resolve_snapshot_fields_ordering(preset, include, exclude): + expected = (set(SNAPSHOT_PRESETS[preset] if preset else []) | set(include)) - set(exclude) + + if not expected: + with pytest.raises(SnapshotException): + resolve_snapshot_fields(preset=preset, include=include or None, exclude=exclude or None) + else: + assert resolve_snapshot_fields(preset=preset, include=include or None, exclude=exclude or None) == expected + + +_STRUCTURAL_CASES = [ + ([], {}, None), + ([{"role": "user", "content": [{"text": "hi"}]}], {"k": "v"}, "system prompt"), + ([{"role": "user", "content": [{"text": "a"}]}, {"role": "user", "content": [{"text": "b"}]}], {}, None), + ([], {"num": 42, "flag": True}, "another prompt"), +] + + +@pytest.mark.parametrize("messages,state_dict,system_prompt", _STRUCTURAL_CASES) +def test_snapshot_structural_invariants(messages, state_dict, system_prompt): + agent = _make_agent(messages=messages, state=state_dict, system_prompt=system_prompt) + snapshot = agent.take_snapshot(preset="session") + + assert snapshot.schema_version == "1.0" + assert ISO_8601_UTC_RE.match(snapshot.created_at), f"created_at={snapshot.created_at!r} not ISO 8601 UTC" + assert isinstance(snapshot.data, dict) + assert isinstance(snapshot.app_data, dict) + for field in ("messages", "state", "conversation_manager_state", "interrupt_state"): + assert field in snapshot.data + assert "system_prompt" not in snapshot.data + + +_APP_DATA_CASES = [ + {"key": "value"}, + {"num": 42, "flag": True, "nothing": None}, + {"nested_str": "hello", "count": 0}, +] + + +@pytest.mark.parametrize("app_data", _APP_DATA_CASES) +def test_app_data_stored_verbatim(app_data): + agent = _make_agent() + snapshot = agent.take_snapshot(preset="session", app_data=app_data) + assert snapshot.app_data == app_data + + +_ROUND_TRIP_AGENT_CASES = [ + ([], {}), + ([{"role": "user", "content": [{"text": "hi"}]}], {"k": "v"}), + ( + [{"role": "user", "content": [{"text": "a"}]}, {"role": "user", "content": [{"text": "b"}]}], + {"num": 1, "flag": None}, + ), +] + + +@pytest.mark.parametrize("messages,state_dict", _ROUND_TRIP_AGENT_CASES) +def test_agent_state_round_trip(messages, state_dict): + agent = _make_agent(messages=messages, state=state_dict, system_prompt="original prompt") + snapshot = agent.take_snapshot(preset="session") + + fresh_agent = _make_agent(system_prompt="original prompt") + fresh_agent.load_snapshot(snapshot) + + assert fresh_agent.messages == messages + assert fresh_agent.state.get() == state_dict + assert fresh_agent.system_prompt == "original prompt" + assert fresh_agent.conversation_manager.get_state() == agent.conversation_manager.get_state() + assert fresh_agent._interrupt_state.to_dict() == agent._interrupt_state.to_dict() + + +@pytest.mark.parametrize("omitted_field", list(ALL_SNAPSHOT_FIELDS)) +def test_missing_fields_leave_agent_unchanged(omitted_field): + agent = _make_agent( + messages=[{"role": "user", "content": [{"text": "original"}]}], + state={"key": "original"}, + system_prompt="original prompt", + ) + + include_fields = [f for f in ALL_SNAPSHOT_FIELDS if f != omitted_field] + snapshot = agent.take_snapshot(include=include_fields) + # system_prompt field is stored under the key "system_prompt" in snapshot.data + data_key = "system_prompt" if omitted_field == "system_prompt" else omitted_field + assert data_key not in snapshot.data + + fresh_agent = _make_agent( + messages=list(agent.messages), + state=agent.state.get(), + system_prompt="original prompt", + ) + + if omitted_field == "messages": + before = list(fresh_agent.messages) + elif omitted_field == "state": + before = fresh_agent.state.get() + elif omitted_field == "system_prompt": + before = fresh_agent.system_prompt + elif omitted_field == "conversation_manager_state": + before = fresh_agent.conversation_manager.get_state() + elif omitted_field == "interrupt_state": + before = fresh_agent._interrupt_state.to_dict() + else: + pytest.fail(f"Unhandled field in test: {omitted_field!r}. Update this test when adding new snapshot fields.") + + fresh_agent.load_snapshot(snapshot) + + if omitted_field == "messages": + assert fresh_agent.messages == before + elif omitted_field == "state": + assert fresh_agent.state.get() == before + elif omitted_field == "system_prompt": + assert fresh_agent.system_prompt == before + elif omitted_field == "conversation_manager_state": + assert fresh_agent.conversation_manager.get_state() == before + elif omitted_field == "interrupt_state": + assert fresh_agent._interrupt_state.to_dict() == before + else: + pytest.fail(f"Unhandled field in test: {omitted_field!r}. Update this test when adding new snapshot fields.") + + +def test_snapshot_no_system_prompt_clears_target_agent_prompt(): + """Snapshot from agent with no system_prompt (field included) clears prompt on restore.""" + source_agent = _make_agent() # no system_prompt + snapshot = source_agent.take_snapshot(include=["system_prompt"]) + + assert "system_prompt" in snapshot.data + assert snapshot.data["system_prompt"] is None + + target_agent = _make_agent(system_prompt="existing prompt") + target_agent.load_snapshot(snapshot) + + assert target_agent.system_prompt is None + + +def test_snapshot_without_system_prompt_field_preserves_target_agent_prompt(): + """Snapshot taken without system_prompt field does not override target agent's prompt.""" + source_agent = _make_agent(system_prompt="source prompt") + snapshot = source_agent.take_snapshot(include=["messages"]) # system_prompt field excluded + + assert "system_prompt" not in snapshot.data + + target_agent = _make_agent(system_prompt="target prompt") + target_agent.load_snapshot(snapshot) + + assert target_agent.system_prompt == "target prompt" + + +def test_load_snapshot_messages_are_independent_copy(): + """Messages restored from a snapshot are a copy — mutating snapshot.data after load doesn't affect the agent.""" + agent = _make_agent(messages=[{"role": "user", "content": [{"text": "hello"}]}]) + snapshot = agent.take_snapshot(preset="session") + + fresh_agent = _make_agent() + fresh_agent.load_snapshot(snapshot) + + snapshot.data["messages"].append({"role": "user", "content": [{"text": "injected"}]}) + assert len(fresh_agent.messages) == 1 + + +def test_take_snapshot_messages_are_independent_copy(): + """Mutating agent messages after take_snapshot doesn't corrupt the snapshot.""" + msg = {"role": "user", "content": [{"text": "original"}]} + agent = _make_agent(messages=[msg]) + snapshot = agent.take_snapshot(preset="session") + + agent.messages[0]["content"][0]["text"] = "mutated" + assert snapshot.data["messages"][0]["content"][0]["text"] == "original" + + +def test_take_snapshot_app_data_is_independent_copy(): + """Mutating app_data after take_snapshot doesn't corrupt the snapshot.""" + app_data = {"key": "original"} + agent = _make_agent() + snapshot = agent.take_snapshot(preset="session", app_data=app_data) + + app_data["key"] = "mutated" + assert snapshot.app_data["key"] == "original" + + +# Scope validation + + +def test_valid_scopes_constant_matches_scope_type(): + """VALID_SCOPES contains exactly the values from the Scope Literal type.""" + assert set(VALID_SCOPES) == {"agent"} + + +def test_snapshot_validate_accepts_valid_scopes(): + """validate() should not raise for each valid scope value.""" + for scope in VALID_SCOPES: + snap = _make_snapshot(scope=scope) + snap.validate() # should not raise + + +def test_snapshot_validate_rejects_invalid_scope(): + """validate() should raise SnapshotException for an unrecognised scope.""" + snap = _make_snapshot(scope="invalid_scope") + with pytest.raises(SnapshotException, match="Invalid snapshot scope"): + snap.validate() + + +def test_snapshot_from_dict_rejects_invalid_scope(): + """from_dict() calls validate(), so an invalid scope should raise.""" + d = { + "scope": "bad_scope", + "schema_version": SNAPSHOT_SCHEMA_VERSION, + "created_at": "2025-01-15T12:00:00Z", + "data": {}, + "app_data": {}, + } + with pytest.raises(SnapshotException, match="Invalid snapshot scope"): + Snapshot.from_dict(d) + + +def test_snapshot_from_dict_defaults_scope_to_agent(): + """from_dict() defaults scope to 'agent' when the key is missing.""" + d = { + "schema_version": SNAPSHOT_SCHEMA_VERSION, + "created_at": "2025-01-15T12:00:00Z", + "data": {}, + "app_data": {}, + } + snap = Snapshot.from_dict(d) + assert snap.scope == "agent" + + +def test_load_snapshot_rejects_invalid_scope(): + """Agent.load_snapshot() should reject a snapshot with an invalid scope.""" + agent = _make_agent() + snap = _make_snapshot(scope="unknown") + with pytest.raises(SnapshotException, match="Invalid snapshot scope"): + agent.load_snapshot(snap) + + +def test_take_snapshot_always_produces_agent_scope(): + """take_snapshot() should always set scope to 'agent'.""" + agent = _make_agent() + snapshot = agent.take_snapshot(preset="session") + assert snapshot.scope == "agent" + + +# Individual field restore from a raw snapshot + + +def test_load_snapshot_restores_messages_only(): + """A snapshot containing only messages restores them on a fresh agent.""" + agent = _make_agent(messages=[{"role": "user", "content": [{"text": "existing"}]}]) + snap = _make_snapshot(data={"messages": [{"role": "user", "content": [{"text": "restored"}]}]}) + + agent.load_snapshot(snap) + + assert len(agent.messages) == 1 + assert agent.messages[0]["content"][0]["text"] == "restored" + + +def test_load_snapshot_restores_state_only(): + """A snapshot containing only state restores it on a fresh agent.""" + agent = _make_agent(state={"old": "value"}) + snap = _make_snapshot(data={"state": {"new_key": "new_value"}}) + + agent.load_snapshot(snap) + + assert agent.state.get() == {"new_key": "new_value"} + + +def test_load_snapshot_restores_system_prompt_only(): + """A snapshot containing only system_prompt restores it on a fresh agent.""" + agent = _make_agent(system_prompt="old prompt") + snap = _make_snapshot(data={"system_prompt": "restored prompt"}) + + agent.load_snapshot(snap) + + assert agent.system_prompt == "restored prompt" + + +def test_snapshot_json_string_round_trip(): + """Snapshot survives json.dumps / json.loads round-trip.""" + agent = _make_agent( + messages=[{"role": "user", "content": [{"text": "hello"}]}], + state={"k": "v"}, + system_prompt="test prompt", + ) + snapshot = agent.take_snapshot(preset="session", include=["system_prompt"]) + + json_str = json.dumps(snapshot.to_dict()) + restored = Snapshot.from_dict(json.loads(json_str)) + + assert restored == snapshot + + +def test_snapshot_json_store_and_restore_into_new_agent(): + """Simulate persisting a snapshot as JSON and restoring into a new agent.""" + agent = _make_agent( + messages=[{"role": "user", "content": [{"text": "test message"}]}], + state={"key": "value"}, + ) + snapshot = agent.take_snapshot(preset="session") + + stored = json.dumps(snapshot.to_dict()) + retrieved = Snapshot.from_dict(json.loads(stored)) + + new_agent = _make_agent() + new_agent.load_snapshot(retrieved) + + assert new_agent.messages == [{"role": "user", "content": [{"text": "test message"}]}] + assert new_agent.state.get() == {"key": "value"} + + +def test_snapshot_round_trip_with_tool_use_messages(): + """Snapshot preserves toolUse and toolResult content blocks through a round-trip.""" + tool_use_msg = { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "tool-123", "name": "calculator", "input": {"op": "add"}}}], + } + tool_result_msg = { + "role": "user", + "content": [{"toolResult": {"toolUseId": "tool-123", "status": "success", "content": [{"text": "6"}]}}], + } + agent = _make_agent(messages=[tool_use_msg, tool_result_msg]) + snapshot = agent.take_snapshot(include=["messages"]) + + fresh_agent = _make_agent() + fresh_agent.load_snapshot(snapshot) + + assert fresh_agent.messages == [tool_use_msg, tool_result_msg] + + +def test_take_snapshot_exclude_removes_field_from_data(): + """Excluding a field from take_snapshot omits it from snapshot.data while keeping others.""" + agent = _make_agent( + messages=[{"role": "user", "content": [{"text": "hi"}]}], + state={"k": "v"}, + ) + snapshot = agent.take_snapshot(preset="session", exclude=["messages"]) + + assert "messages" not in snapshot.data + assert "state" in snapshot.data + assert "conversation_manager_state" in snapshot.data + assert "interrupt_state" in snapshot.data + + +def test_take_snapshot_system_prompt_is_independent_copy(): + """Mutating agent system_prompt after take_snapshot doesn't corrupt the snapshot.""" + agent = _make_agent(system_prompt="original prompt") + snapshot = agent.take_snapshot(include=["system_prompt"]) + + original_content = snapshot.data["system_prompt"] + agent.system_prompt = "mutated prompt" + assert snapshot.data["system_prompt"] == original_content + assert snapshot.data["system_prompt"] != agent._system_prompt_content diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index 4b69e6653..dbd225e9b 100644 --- a/tests/strands/agent/test_summarizing_conversation_manager.py +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -1,26 +1,58 @@ from typing import cast -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from strands.agent.agent import Agent -from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.agent.conversation_manager.summarizing_conversation_manager import ( + DEFAULT_SUMMARIZATION_PROMPT, + SummarizingConversationManager, +) +from strands.hooks.events import BeforeModelCallEvent +from strands.hooks.registry import HookRegistry from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException from tests.fixtures.mocked_model_provider import MockedModelProvider +async def _mock_model_stream(response_text): + """Create an async generator that yields stream events for a text response. + + This simulates what a real Model.stream() returns so that process_stream() can + reconstruct the assistant message. + """ + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + +async def _mock_model_stream_error(error): + """Async generator that raises an exception, simulating a model failure.""" + raise error + yield # pragma: no cover – makes this a generator + + class MockAgent: - """Mock agent for testing summarization.""" + """Mock agent for testing summarization. + + In the default path (no summarization_agent) the manager now calls + ``agent.model.stream()`` directly, so the model attribute must return a + proper async iterable. When used as a *summarization_agent* the manager + still calls ``agent("…")``, so the ``__call__`` interface is kept. + """ def __init__(self, summary_response="This is a summary of the conversation."): self.summary_response = summary_response self.system_prompt = None self.messages = [] self.model = Mock() + self.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream(self.summary_response)) self.call_tracker = Mock() self.tool_registry = Mock() self.tool_names = [] + self._default_structured_output_model = None def __call__(self, prompt): """Mock agent call that returns a summary.""" @@ -71,7 +103,7 @@ def test_init_clamps_summary_ratio(): def test_reduce_context_raises_when_no_agent(): - """Test that reduce_context raises exception when agent has no messages.""" + """Test that reduce_context raises exception when agent has no messages (reactive mode).""" manager = SummarizingConversationManager() # Create a mock agent with no messages @@ -79,8 +111,9 @@ def test_reduce_context_raises_when_no_agent(): empty_messages: Messages = [] mock_agent.messages = empty_messages + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_reduce_context_with_summarization(summarizing_manager, mock_agent): @@ -125,8 +158,9 @@ def test_reduce_context_too_few_messages_raises_exception(summarizing_manager, m ] mock_agent.messages = insufficient_test_messages # 5 messages, preserve_recent_messages=5, so nothing to summarize + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_reduce_context_insufficient_messages_for_summarization(mock_agent): @@ -143,17 +177,16 @@ def test_reduce_context_insufficient_messages_for_summarization(mock_agent): ] mock_agent.messages = insufficient_messages - # This should raise an exception since there aren't enough messages to summarize + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_reduce_context_raises_on_summarization_failure(): - """Test that reduce_context raises exception when summarization fails.""" - # Create an agent that will fail + """Test that reduce_context raises exception when model.stream() fails.""" failing_agent = Mock() - failing_agent.side_effect = Exception("Agent failed") - failing_agent.system_prompt = None + failing_agent.model = Mock() + failing_agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Agent failed"))) failing_agent_messages: Messages = [ {"role": "user", "content": [{"text": "Message 1"}]}, {"role": "assistant", "content": [{"text": "Response 1"}]}, @@ -168,8 +201,9 @@ def test_reduce_context_raises_on_summarization_failure(): ) with patch("strands.agent.conversation_manager.summarizing_conversation_manager.logger") as mock_logger: + # Reactive mode (e is set) should raise with pytest.raises(Exception, match="Agent failed"): - manager.reduce_context(failing_agent) + manager.reduce_context(failing_agent, e=RuntimeError("overflow")) # Should log the error mock_logger.error.assert_called_once() @@ -207,13 +241,11 @@ def test_generate_summary_with_tool_content(summarizing_manager, mock_agent): assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." -def test_generate_summary_raises_on_agent_failure(): - """Test that _generate_summary raises exception when agent fails.""" +def test_generate_summary_raises_on_model_failure(): + """Test that _generate_summary raises exception when model.stream() fails.""" failing_agent = Mock() - failing_agent.side_effect = Exception("Agent failed") - failing_agent.system_prompt = None - empty_failing_messages: Messages = [] - failing_agent.messages = empty_failing_messages + failing_agent.model = Mock() + failing_agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Agent failed"))) manager = SummarizingConversationManager() @@ -222,7 +254,7 @@ def test_generate_summary_raises_on_agent_failure(): {"role": "assistant", "content": [{"text": "Hi there"}]}, ] - # Should raise the exception from the agent + # Should raise the exception from the model with pytest.raises(Exception, match="Agent failed"): manager._generate_summary(messages, failing_agent) @@ -325,8 +357,8 @@ def test_uses_summarization_agent_when_provided(): summary_agent.call_tracker.assert_called_once() -def test_uses_parent_agent_when_no_summarization_agent(): - """Test that parent agent is used when no summarization_agent is provided.""" +def test_default_path_calls_model_directly(): + """Test that the default path (no summarization_agent) calls model.stream() directly.""" manager = SummarizingConversationManager() messages: Messages = [ @@ -337,16 +369,36 @@ def test_uses_parent_agent_when_no_summarization_agent(): parent_agent = create_mock_agent("Parent agent summary") summary = manager._generate_summary(messages, parent_agent) - # Should use the parent agent + # Should use the model directly (via model.stream) summary_content = summary["content"][0] assert "text" in summary_content and summary_content["text"] == "Parent agent summary" - # Assert that the parent agent was called - parent_agent.call_tracker.assert_called_once() + # model.stream() should have been called + parent_agent.model.stream.assert_called_once() + + # The agent itself should NOT have been called (no re-entrant invocation) + parent_agent.call_tracker.assert_not_called() + + +def test_default_path_passes_correct_system_prompt(): + """Test that the default path passes the correct system prompt to model.stream().""" + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent() + manager._generate_summary(messages, parent_agent) + + # Verify model.stream() was called with the default summarization system prompt + call_kwargs = parent_agent.model.stream.call_args + assert call_kwargs.kwargs["system_prompt"] == DEFAULT_SUMMARIZATION_PROMPT -def test_uses_custom_system_prompt(): - """Test that custom system prompt is used when provided.""" +def test_default_path_uses_custom_system_prompt(): + """Test that custom system prompt is passed to model.stream() in default path.""" custom_prompt = "Custom system prompt for summarization" manager = SummarizingConversationManager(summarization_system_prompt=custom_prompt) mock_agent = create_mock_agent() @@ -356,16 +408,15 @@ def test_uses_custom_system_prompt(): {"role": "assistant", "content": [{"text": "Hi there"}]}, ] - # Capture the agent's system prompt changes - original_prompt = mock_agent.system_prompt manager._generate_summary(messages, mock_agent) - # The agent's system prompt should be restored after summarization - assert mock_agent.system_prompt == original_prompt + # Verify model.stream() was called with the custom system prompt + call_kwargs = mock_agent.model.stream.call_args + assert call_kwargs.kwargs["system_prompt"] == custom_prompt -def test_agent_state_restoration(): - """Test that agent state is properly restored after summarization.""" +def test_default_path_does_not_modify_agent_state(): + """Test that the default path does not modify any agent state.""" manager = SummarizingConversationManager() mock_agent = create_mock_agent() @@ -374,6 +425,7 @@ def test_agent_state_restoration(): original_messages: Messages = [{"role": "user", "content": [{"text": "Original message"}]}] mock_agent.system_prompt = original_system_prompt mock_agent.messages = original_messages.copy() + original_tool_registry = mock_agent.tool_registry messages: Messages = [ {"role": "user", "content": [{"text": "Hello"}]}, @@ -382,33 +434,99 @@ def test_agent_state_restoration(): manager._generate_summary(messages, mock_agent) - # State should be restored + # Agent state should be completely untouched assert mock_agent.system_prompt == original_system_prompt assert mock_agent.messages == original_messages + assert mock_agent.tool_registry is original_tool_registry -def test_agent_state_restoration_on_exception(): - """Test that agent state is restored even when summarization fails.""" +def test_default_path_does_not_modify_agent_state_on_exception(): + """Test that agent state is untouched when model.stream() fails in default path.""" manager = SummarizingConversationManager() - # Create an agent that fails during summarization mock_agent = Mock() mock_agent.system_prompt = "Original prompt" agent_messages: Messages = [{"role": "user", "content": [{"text": "Original"}]}] mock_agent.messages = agent_messages - mock_agent.side_effect = Exception("Summarization failed") + mock_agent.model = Mock() + mock_agent.model.stream = Mock( + side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Summarization failed")) + ) messages: Messages = [ {"role": "user", "content": [{"text": "Hello"}]}, {"role": "assistant", "content": [{"text": "Hi there"}]}, ] - # Should restore state even on exception with pytest.raises(Exception, match="Summarization failed"): manager._generate_summary(messages, mock_agent) - # State should still be restored + # Agent state should be untouched (default path never modifies it) assert mock_agent.system_prompt == "Original prompt" + assert mock_agent.messages == agent_messages + + +def test_default_path_passes_no_tool_specs(): + """Test that model.stream() is called with tool_specs=None in default path.""" + manager = SummarizingConversationManager() + + messages: Messages = [{"role": "user", "content": [{"text": "test"}]}] + agent = create_mock_agent() + + manager._generate_summary(messages, agent) + + # model.stream() should be called with tool_specs=None + call_kwargs = agent.model.stream.call_args + assert call_kwargs.kwargs.get("tool_specs") is None or call_kwargs[0][1] is None + + +def test_agent_path_state_restoration_with_summarization_agent(): + """Test that summarization_agent state is properly restored after summarization.""" + summary_agent = create_mock_agent("Summary from dedicated agent") + manager = SummarizingConversationManager(summarization_agent=summary_agent) + + # Set initial state on the summarization agent + original_system_prompt = "Agent original prompt" + original_messages: Messages = [{"role": "user", "content": [{"text": "Agent original message"}]}] + summary_agent.system_prompt = original_system_prompt + summary_agent.messages = original_messages.copy() + original_tool_registry = summary_agent.tool_registry + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Should not be used") + manager._generate_summary(messages, parent_agent) + + # Summarization agent state should be restored + assert summary_agent.system_prompt == original_system_prompt + assert summary_agent.messages == original_messages + assert summary_agent.tool_registry is original_tool_registry + + +def test_agent_path_state_restoration_on_exception(): + """Test that summarization_agent state is restored even when it fails.""" + summary_agent = Mock() + summary_agent.system_prompt = "Original prompt" + agent_messages: Messages = [{"role": "user", "content": [{"text": "Original"}]}] + summary_agent.messages = agent_messages + summary_agent.side_effect = Exception("Summarization failed") + summary_agent.tool_names = [] + + manager = SummarizingConversationManager(summarization_agent=cast("Agent", summary_agent)) + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + with pytest.raises(Exception, match="Summarization failed"): + manager._generate_summary(messages, cast("Agent", Mock())) + + # State should still be restored + assert summary_agent.system_prompt == "Original prompt" def test_reduce_context_tool_pair_adjustment_works_with_forward_search(): @@ -562,9 +680,10 @@ def mock_adjust(messages, split_point): ] mock_agent.messages = simple_messages - # The adjustment method will return 0, which should trigger line 122-123 + # The adjustment method will return 0, which should trigger the <= 0 check + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_summarizing_conversation_manager_properly_records_removed_message_count(): @@ -613,27 +732,162 @@ def test_summarizing_conversation_manager_properly_records_removed_message_count @patch("strands.agent.conversation_manager.summarizing_conversation_manager.ToolRegistry") -def test_summarizing_conversation_manager_generate_summary_with_noop_tool(mock_registry_cls, summarizing_manager): +def test_summarizing_conversation_manager_generate_summary_with_noop_tool_agent_path( + mock_registry_cls, +): + """Test noop tool registration when using the agent path (summarization_agent provided).""" mock_registry = mock_registry_cls.return_value + summary_agent = create_mock_agent() + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summary_agent, + ) + messages = [{"role": "user", "content": [{"text": "test"}]}] - agent = create_mock_agent() + parent_agent = create_mock_agent() - original_tool_registry = agent.tool_registry - summarizing_manager._generate_summary(messages, agent) + original_tool_registry = summary_agent.tool_registry + manager._generate_summary(messages, parent_agent) - assert original_tool_registry == agent.tool_registry + assert original_tool_registry == summary_agent.tool_registry mock_registry.register_tool.assert_called_once() @patch("strands.agent.conversation_manager.summarizing_conversation_manager.ToolRegistry") -def test_summarizing_conversation_manager_generate_summary_with_tools(mock_registry_cls, summarizing_manager): +def test_summarizing_conversation_manager_generate_summary_with_tools_agent_path( + mock_registry_cls, +): + """Test no noop tool registration when summarization_agent has tools.""" mock_registry = mock_registry_cls.return_value + summary_agent = create_mock_agent() + summary_agent.tool_names = ["test_tool"] + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summary_agent, + ) + messages = [{"role": "user", "content": [{"text": "test"}]}] - agent = create_mock_agent() - agent.tool_names = ["test_tool"] + parent_agent = create_mock_agent() - summarizing_manager._generate_summary(messages, agent) + manager._generate_summary(messages, parent_agent) mock_registry.register_tool.assert_not_called() + + +def test_generate_summary_disables_structured_output_on_summarization_agent(): + """Test that structured output is disabled during summarization to avoid toolUse in user messages. + + When a summarization agent has structured_output_model configured, the response contains toolUse blocks. + Since the summary is converted to a user message, toolUse blocks would violate the model API constraint + that user messages cannot contain tool uses. The fix disables structured output during summarization. + """ + summary_agent = create_mock_agent() + structured_output_model = Mock() + summary_agent._default_structured_output_model = structured_output_model + + original_call = summary_agent.__class__.__call__ + observed_values = [] + + def tracking_call(self, prompt): + observed_values.append(self._default_structured_output_model) + return original_call(self, prompt) + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + manager = SummarizingConversationManager(summarization_agent=summary_agent) + + with patch.object(MockAgent, "__call__", tracking_call): + manager._generate_summary(messages, create_mock_agent()) + + assert observed_values == [None], "structured output should be disabled during summarization" + assert summary_agent._default_structured_output_model is structured_output_model, "should be restored after" + + +# ============================================================================== +# Compression Threshold Tests +# ============================================================================== + + +def _make_summarizing_threshold_agent(messages, summary_response="Summary of conversation", context_window_limit=1000): + agent = MagicMock() + agent.messages = messages + agent.model = MagicMock() + agent.model.context_window_limit = context_window_limit + agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream(summary_response)) + return agent + + +def test_proactive_compression_summarizes_when_exceeded(): + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + proactive_compression={"compression_threshold": 0.7}, + ) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(20) + ] + agent = _make_summarizing_threshold_agent(messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = BeforeModelCallEvent(agent=agent, invocation_state={}, projected_input_tokens=800) + registry.invoke_callbacks(event) + + # 20 * 0.5 = 10 summarized → 1 summary + 10 remaining = 11 + assert len(agent.messages) == 11 + assert agent.messages[0]["role"] == "user" + + +def test_proactive_compression_no_summarize_when_below(): + manager = SummarizingConversationManager(proactive_compression={"compression_threshold": 0.7}) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(20) + ] + agent = _make_summarizing_threshold_agent(messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = BeforeModelCallEvent(agent=agent, invocation_state={}, projected_input_tokens=500) + registry.invoke_callbacks(event) + + assert len(agent.messages) == 20 + + +def test_proactive_compression_swallows_errors(): + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + proactive_compression={"compression_threshold": 0.7}, + ) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(20) + ] + agent = MagicMock() + agent.messages = messages + agent.model = MagicMock() + agent.model.context_window_limit = 1000 + agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream_error(RuntimeError("model failed"))) + + registry = HookRegistry() + manager.register_hooks(registry) + + event = BeforeModelCallEvent(agent=agent, invocation_state={}, projected_input_tokens=800) + # Should not throw — proactive compression is best-effort + registry.invoke_callbacks(event) + assert len(agent.messages) == 20 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 52980729c..f025a81ef 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,12 +1,18 @@ +import asyncio import concurrent +import threading import unittest.mock from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter import strands import strands.telemetry from strands import Agent +from strands.event_loop._retry import ModelRetryStrategy from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -16,6 +22,7 @@ ) from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics +from strands.telemetry.tracer import Tracer from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry from strands.types._events import EventLoopStopEvent @@ -31,9 +38,7 @@ @pytest.fixture def mock_sleep(): - with unittest.mock.patch.object( - strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock - ) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock @@ -116,7 +121,11 @@ def tool_stream(tool): @pytest.fixture def hook_registry(): - return HookRegistry() + registry = HookRegistry() + # Register default retry strategy + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(registry) + return registry @pytest.fixture @@ -142,10 +151,14 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.tool_registry = tool_registry mock.thread_pool = thread_pool mock.event_loop_metrics = EventLoopMetrics() + mock.event_loop_metrics.reset_usage_metrics() mock.hooks = hook_registry mock.tool_executor = tool_executor mock._interrupt_state = _InterruptState() + mock._cancel_signal = threading.Event() + mock._model_state = {} mock.trace_attributes = {} + mock.retry_strategy = ModelRetryStrategy() return mock @@ -180,7 +193,7 @@ async def test_event_loop_cycle_text_response( tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} + exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state @@ -212,7 +225,7 @@ async def test_event_loop_cycle_text_response_throttling( tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} + exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state @@ -251,7 +264,7 @@ async def test_event_loop_cycle_exponential_backoff( # Verify the final response assert tru_stop_reason == "end_turn" - assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} assert tru_request_state == {} # Verify that sleep was called with increasing delays @@ -341,7 +354,7 @@ async def test_event_loop_cycle_tool_result( tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} + exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state @@ -376,12 +389,13 @@ async def test_event_loop_cycle_tool_result( }, ], }, - {"role": "assistant", "content": [{"text": "test text"}]}, ], tool_registry.get_all_tool_specs(), "p1", tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ) @@ -469,6 +483,7 @@ async def test_event_loop_cycle_stop( } } ], + "metadata": ANY, } exp_request_state = {"stop_event_loop": True} @@ -538,6 +553,9 @@ async def test_event_loop_cycle_creates_spans( mock_get_tracer.assert_called_once() mock_tracer.start_event_loop_cycle_span.assert_called_once() mock_tracer.start_model_invoke_span.assert_called_once() + call_kwargs = mock_tracer.start_model_invoke_span.call_args[1] + assert call_kwargs["system_prompt"] == agent.system_prompt + assert call_kwargs["system_prompt_content"] == agent._system_prompt_content mock_tracer.end_model_invoke_span.assert_called_once() mock_tracer.end_event_loop_cycle_span.assert_called_once() @@ -569,8 +587,13 @@ async def test_event_loop_tracing_with_model_error( ) await alist(stream) - # Verify error handling span methods were called - mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect) + assert mock_tracer.end_span_with_error.call_count == 2 + mock_tracer.end_span_with_error.assert_has_calls( + [ + call(model_span, "Input too long", model.stream.side_effect), + call(cycle_span, "Input too long", model.stream.side_effect), + ] + ) @pytest.mark.asyncio @@ -662,6 +685,53 @@ async def test_event_loop_tracing_with_tool_execution( assert mock_tracer.end_model_invoke_span.call_count == 2 +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle( + agent, + model, + tool_stream, + agenerator, + alist, +): + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + tracer = Tracer() + tracer.tracer_provider = provider + tracer.tracer = provider.get_tracer(tracer.service_name) + + async def delayed_text_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + await asyncio.sleep(0.05) + yield {"contentBlockStop": {}} + + agent.trace_span = None + agent._system_prompt_content = None + model.config = {"model_id": "test-model"} + model.stream.side_effect = [ + agenerator(tool_stream), + delayed_text_stream(), + ] + + with patch("strands.event_loop.event_loop.get_tracer", return_value=tracer): + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + provider.force_flush() + cycle_spans = sorted( + [span for span in exporter.get_finished_spans() if span.name == "execute_event_loop_cycle"], + key=lambda span: span.start_time, + ) + + assert len(cycle_spans) == 2 + assert cycle_spans[0].end_time <= cycle_spans[1].start_time + assert cycle_spans[0].end_time < cycle_spans[1].end_time + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_throttling_exception( @@ -691,14 +761,13 @@ async def test_event_loop_tracing_with_throttling_exception( ] # Mock the time.sleep function to speed up the test - with patch("strands.event_loop.event_loop.asyncio.sleep", new_callable=unittest.mock.AsyncMock): + with patch.object(asyncio, "sleep", new_callable=unittest.mock.AsyncMock): stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, ) await alist(stream) - # Verify error span was created for the throttling exception assert mock_tracer.end_span_with_error.call_count == 1 # Verify span was created for the successful retry assert mock_tracer.start_model_invoke_span.call_count == 2 @@ -754,6 +823,7 @@ async def test_request_state_initialization(alist): mock_agent = MagicMock() # not setting this to False results in endless recursion mock_agent._interrupt_state.activated = False + mock_agent._cancel_signal = threading.Event() mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock()) mock_agent.hooks.invoke_callbacks_async = AsyncMock() @@ -853,30 +923,37 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, assert count == 9 # 1st call - throttled - assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 2nd call - throttled - assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 3rd call - throttled - assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 4th call - successful - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( - message={"content": [{"text": "test text"}], "role": "assistant"}, stop_reason="end_turn" + message={"content": [{"text": "test text"}], "role": "assistant", "metadata": ANY}, stop_reason="end_turn" ), exception=None, ) # Final message assert next(events) == MessageAddedEvent( - agent=agent, message={"content": [{"text": "test text"}], "role": "assistant"} + agent=agent, message={"content": [{"text": "test text"}], "role": "assistant", "metadata": ANY} ) @@ -920,6 +997,7 @@ def interrupt_callback(event): }, ], "role": "assistant", + "metadata": ANY, }, }, "interrupts": { @@ -1054,7 +1132,7 @@ async def test_invalid_tool_names_adds_tool_uses(agent, model, alist): # ensure that we got end_turn and not tool_use assert events[-1] == EventLoopStopEvent( stop_reason="end_turn", - message={"content": [{"text": "I invoked a tool!"}], "role": "assistant"}, + message={"content": [{"text": "I invoked a tool!"}], "role": "assistant", "metadata": ANY}, metrics=ANY, request_state={}, ) @@ -1072,3 +1150,132 @@ async def test_invalid_tool_names_adds_tool_uses(agent, model, alist): ], "role": "user", } + + +@pytest.mark.asyncio +async def test_event_loop_metrics_recorded_before_recursion( + agent, + model, + tool, + agenerator, + alist, +): + model.stream.side_effect = [ + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": tool.tool_spec["name"], + }, + }, + }, + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ), + ] + + with unittest.mock.patch.object(agent.event_loop_metrics, "end_cycle") as mock_end_cycle: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={"request_state": {}}, + ) + events = await alist(stream) + + # Verify end_cycle was called once for tool cycle, once for text cycle + assert mock_end_cycle.call_count == 2 + + # Verify the event loop completed successfully + tru_stop_reason, _, _, _, _, _ = events[-1]["stop"] + assert tru_stop_reason == "end_turn" + + +class TestEstimateInputTokens: + """Tests for _estimate_input_tokens helper.""" + + @pytest.mark.asyncio + async def test_cold_start_estimates_all_messages(self): + """On cold start (no prior usage metadata), estimates all messages with lazily resolved tool specs.""" + agent = unittest.mock.AsyncMock() + agent.messages = [{"role": "user", "content": [{"text": "Hi"}]}] + agent.system_prompt = "You are helpful" + agent._system_prompt_content = None + agent.tool_registry = unittest.mock.MagicMock() + agent.tool_registry.get_all_tool_specs.return_value = [{"name": "tool1"}] + agent.model.count_tokens = AsyncMock(return_value=42) + + result = await strands.event_loop.event_loop._estimate_input_tokens(agent) + + assert result == 42 + agent.tool_registry.get_all_tool_specs.assert_called_once() + agent.model.count_tokens.assert_called_once_with( + agent.messages, + tool_specs=[{"name": "tool1"}], + system_prompt="You are helpful", + system_prompt_content=None, + ) + + @pytest.mark.asyncio + async def test_baseline_only_no_new_messages(self): + """When last message is assistant with usage and no new messages after, returns baseline.""" + agent = unittest.mock.AsyncMock() + agent.messages = [ + {"role": "user", "content": [{"text": "Hi"}]}, + { + "role": "assistant", + "content": [{"text": "Hello"}], + "metadata": {"usage": {"inputTokens": 100, "outputTokens": 20, "totalTokens": 120}}, + }, + ] + agent.system_prompt = "You are helpful" + + result = await strands.event_loop.event_loop._estimate_input_tokens(agent) + + assert result == 120 + agent.model.count_tokens.assert_not_called() + + @pytest.mark.asyncio + async def test_baseline_plus_delta(self): + """When new messages exist after last assistant, adds estimated delta to baseline.""" + agent = unittest.mock.AsyncMock() + agent.messages = [ + {"role": "user", "content": [{"text": "Hi"}]}, + { + "role": "assistant", + "content": [{"text": "Hello"}], + "metadata": {"usage": {"inputTokens": 100, "outputTokens": 30, "totalTokens": 130}}, + }, + {"role": "user", "content": [{"text": "tool result"}]}, + ] + agent.system_prompt = "You are helpful" + agent.model.count_tokens = AsyncMock(return_value=50) + + result = await strands.event_loop.event_loop._estimate_input_tokens(agent) + + # baseline (100+30) + delta (50) = 180 + assert result == 180 + agent.model.count_tokens.assert_called_once() + + @pytest.mark.asyncio + async def test_error_fallback_returns_none_at_call_site(self): + """When count_tokens raises, the caller catches and sets projected_input_tokens to None.""" + agent = unittest.mock.AsyncMock() + agent.messages = [{"role": "user", "content": [{"text": "Hi"}]}] + agent.system_prompt = "You are helpful" + agent._system_prompt_content = None + agent.tool_registry = unittest.mock.MagicMock() + agent.tool_registry.get_all_tool_specs.return_value = [] + agent.model.count_tokens = AsyncMock(side_effect=Exception("API unavailable")) + + with pytest.raises(Exception, match="API unavailable"): + await strands.event_loop.event_loop._estimate_input_tokens(agent) diff --git a/tests/strands/event_loop/test_event_loop_metadata.py b/tests/strands/event_loop/test_event_loop_metadata.py new file mode 100644 index 000000000..e6fe97f39 --- /dev/null +++ b/tests/strands/event_loop/test_event_loop_metadata.py @@ -0,0 +1,141 @@ +"""Tests for metadata population on assistant messages in the event loop.""" + +import threading +import unittest.mock + +import pytest + +import strands +import strands.event_loop.event_loop +from strands import Agent +from strands.event_loop._retry import ModelRetryStrategy +from strands.hooks import HookRegistry +from strands.interrupt import _InterruptState +from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.executors import SequentialToolExecutor +from strands.tools.registry import ToolRegistry + + +@pytest.fixture +def model(): + return unittest.mock.Mock() + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +@pytest.fixture +def hook_registry(): + registry = HookRegistry() + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(registry) + return registry + + +@pytest.fixture +def tool_registry(): + return ToolRegistry() + + +@pytest.fixture +def agent(model, messages, tool_registry, hook_registry): + mock = unittest.mock.Mock(name="agent") + mock.__class__ = Agent + mock.config.cache_points = [] + mock.model = model + mock.system_prompt = "test" + mock.messages = messages + mock.tool_registry = tool_registry + mock.thread_pool = None + mock.event_loop_metrics = EventLoopMetrics() + mock.event_loop_metrics.reset_usage_metrics() + mock.hooks = hook_registry + mock.tool_executor = SequentialToolExecutor() + mock._interrupt_state = _InterruptState() + mock._cancel_signal = threading.Event() + mock.trace_attributes = {} + mock.retry_strategy = ModelRetryStrategy() + return mock + + +@pytest.mark.asyncio +async def test_metadata_populated_on_assistant_message(agent, model, agenerator, alist): + """After a model response, the assistant message should have metadata with usage and metrics.""" + model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "response"}}}, + {"contentBlockStop": {}}, + { + "metadata": { + "usage": {"inputTokens": 42, "outputTokens": 10, "totalTokens": 52}, + "metrics": {"latencyMs": 200}, + } + }, + ] + ) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent=agent, invocation_state={}) + await alist(stream) + + # The assistant message should be in agent.messages + assistant_msg = agent.messages[-1] + assert assistant_msg["role"] == "assistant" + assert "metadata" in assistant_msg + + meta = assistant_msg["metadata"] + assert meta["usage"]["inputTokens"] == 42 + assert meta["usage"]["outputTokens"] == 10 + assert meta["usage"]["totalTokens"] == 52 + assert meta["metrics"]["latencyMs"] == 200 + + +@pytest.mark.asyncio +async def test_metadata_has_default_usage_when_no_metadata_event(agent, model, agenerator, alist): + """When no metadata event is in the stream, metadata should still be set with defaults.""" + model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "response"}}}, + {"contentBlockStop": {}}, + ] + ) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent=agent, invocation_state={}) + await alist(stream) + + assistant_msg = agent.messages[-1] + assert "metadata" in assistant_msg + assert assistant_msg["metadata"]["usage"]["inputTokens"] == 0 + assert assistant_msg["metadata"]["usage"]["outputTokens"] == 0 + assert assistant_msg["metadata"]["metrics"]["latencyMs"] == 0 + + +@pytest.mark.asyncio +async def test_metadata_stripped_before_model_call(agent, model, agenerator, alist): + """Metadata from previous messages should be stripped before sending to the model.""" + # Pre-populate a message with metadata (simulating a previous turn) + agent.messages.append( + { + "role": "assistant", + "content": [{"text": "previous response"}], + "metadata": {"usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}}, + } + ) + agent.messages.append({"role": "user", "content": [{"text": "follow up"}]}) + + model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "response"}}}, + {"contentBlockStop": {}}, + ] + ) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent=agent, invocation_state={}) + await alist(stream) + + # Verify that messages passed to model.stream() have no metadata key + call_args = model.stream.call_args + messages_sent = call_args[0][0] + for msg in messages_sent: + assert "metadata" not in msg, f"metadata leaked to model: {msg}" diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 508042af0..2d1150712 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -1,15 +1,25 @@ """Tests for structured output integration in the event loop.""" +import threading from unittest.mock import AsyncMock, Mock, patch import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import StatusCode from pydantic import BaseModel from strands.event_loop.event_loop import event_loop_cycle, recurse_event_loop from strands.telemetry.metrics import EventLoopMetrics +from strands.telemetry.tracer import Tracer from strands.tools.registry import ToolRegistry -from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output._structured_output_context import ( + DEFAULT_STRUCTURED_OUTPUT_PROMPT, + StructuredOutputContext, +) from strands.types._events import EventLoopStopEvent, StructuredOutputEvent +from strands.types.exceptions import EventLoopException, StructuredOutputException class UserModel(BaseModel): @@ -37,6 +47,7 @@ def mock_agent(): agent.messages = [] agent.tool_registry = ToolRegistry() agent.event_loop_metrics = EventLoopMetrics() + agent.event_loop_metrics.reset_usage_metrics() agent.hooks = Mock() agent.hooks.invoke_callbacks_async = AsyncMock() agent.trace_span = None @@ -48,6 +59,7 @@ def mock_agent(): agent._interrupt_state = Mock() agent._interrupt_state.activated = False agent._interrupt_state.context = {} + agent._cancel_signal = threading.Event() return agent @@ -189,6 +201,8 @@ async def test_event_loop_forces_structured_output_on_end_turn( mock_agent._append_messages.assert_called_once() args = mock_agent._append_messages.call_args[0][0] assert args["role"] == "user" + # Should use the default prompt + assert args["content"][0]["text"] == DEFAULT_STRUCTURED_OUTPUT_PROMPT # Should have called recurse_event_loop with the context mock_recurse.assert_called_once() @@ -196,6 +210,134 @@ async def test_event_loop_forces_structured_output_on_end_turn( assert call_kwargs["structured_output_context"] == structured_output_context +@pytest.mark.asyncio +async def test_event_loop_forces_structured_output_with_custom_prompt(mock_agent, agenerator, alist): + """Test that event loop uses custom prompt when forcing structured output.""" + custom_prompt = "Please format your response as structured data using the output schema." + structured_output_context = StructuredOutputContext( + structured_output_model=UserModel, + structured_output_prompt=custom_prompt, + ) + + # First call returns end_turn without using structured output tool + mock_agent.model.stream.side_effect = [ + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user info"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ), + ] + + # Mock recurse_event_loop to return final result + with patch("strands.event_loop.event_loop.recurse_event_loop") as mock_recurse: + mock_stop_event = Mock() + mock_stop_event.stop = ( + "end_turn", + {"role": "assistant", "content": [{"text": "Done"}]}, + mock_agent.event_loop_metrics, + {}, + None, + UserModel(name="John", age=30, email="john@example.com"), + ) + mock_stop_event.__getitem__ = lambda self, key: {"stop": self.stop}[key] + + mock_recurse.return_value = agenerator([mock_stop_event]) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + # Should have appended a message with the custom prompt + mock_agent._append_messages.assert_called_once() + args = mock_agent._append_messages.call_args[0][0] + assert args["role"] == "user" + assert args["content"][0]["text"] == custom_prompt + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_structured_output_failure_closes_cycle_span_with_error( + mock_get_tracer, + mock_agent, + structured_output_context, + agenerator, + alist, +): + mock_tracer = Mock() + cycle_span = Mock() + model_span = Mock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + mock_tracer.start_model_invoke_span.return_value = model_span + mock_get_tracer.return_value = mock_tracer + + structured_output_context.set_forced_mode() + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Still not structured"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + expected_message = "The model failed to invoke the structured output tool even after it was forced." + with pytest.raises(StructuredOutputException, match=expected_message): + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + mock_tracer.end_model_invoke_span.assert_called_once() + mock_tracer.end_event_loop_cycle_span.assert_not_called() + mock_tracer.end_span_with_error.assert_called_once() + assert mock_tracer.end_span_with_error.call_args.args[0] == cycle_span + assert mock_tracer.end_span_with_error.call_args.args[1] == expected_message + assert isinstance(mock_tracer.end_span_with_error.call_args.args[2], StructuredOutputException) + + +@pytest.mark.asyncio +async def test_event_loop_forced_structured_output_append_failure_records_error_span( + mock_agent, structured_output_context, agenerator, alist +): + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + tracer = Tracer() + tracer.tracer_provider = provider + tracer.tracer = provider.get_tracer(tracer.service_name) + + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user info"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + mock_agent._append_messages = AsyncMock(side_effect=RuntimeError("append failed")) + + with patch("strands.event_loop.event_loop.get_tracer", return_value=tracer): + with pytest.raises(EventLoopException, match="append failed"): + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + finished_cycle_spans = [span for span in exporter.get_finished_spans() if span.name == "execute_event_loop_cycle"] + + assert len(finished_cycle_spans) == 1 + assert finished_cycle_spans[0].status.status_code == StatusCode.ERROR + + @pytest.mark.asyncio async def test_structured_output_tool_execution_extracts_result( mock_agent, structured_output_context, agenerator, alist diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py index 402e90966..6dff0fc29 100644 --- a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py +++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py @@ -224,6 +224,34 @@ def test_recover_message_on_max_tokens_reached_multiple_incomplete_tools(): assert "incomplete due to maximum token limits" in result["content"][2]["text"] +def test_recover_message_on_max_tokens_reached_preserves_metadata(): + """Test that metadata is preserved through recovery.""" + message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": "123"}}, + ], + "metadata": {"usage": {"inputTokens": 42, "outputTokens": 10, "totalTokens": 52}}, + } + + result = recover_message_on_max_tokens_reached(message) + + assert "metadata" in result + assert result["metadata"]["usage"]["inputTokens"] == 42 + + +def test_recover_message_on_max_tokens_reached_without_metadata(): + """Test that recovery works fine when no metadata is present.""" + message: Message = { + "role": "assistant", + "content": [{"text": "some text"}], + } + + result = recover_message_on_max_tokens_reached(message) + + assert "metadata" not in result + + def test_recover_message_on_max_tokens_reached_preserves_user_role(): """Test that the function preserves the original message role.""" incomplete_message: Message = { diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 02be400b1..93f8d95f8 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -48,6 +48,7 @@ def moto_autouse(moto_env, moto_mock_aws): ), ], ) +@pytest.mark.filterwarnings("ignore:remove_blank_messages_content_text is deprecated:DeprecationWarning") def test_remove_blank_messages_content_text(messages, exp_result): tru_result = strands.event_loop.streaming.remove_blank_messages_content_text(messages) @@ -124,6 +125,10 @@ def test_handle_message_start(): {"start": {"toolUse": {"toolUseId": "test", "name": "test"}}}, {"toolUseId": "test", "name": "test", "input": ""}, ), + ( + {"start": {"toolUse": {"toolUseId": "test", "name": "test", "reasoningSignature": "YWJj"}}}, + {"toolUseId": "test", "name": "test", "input": "", "reasoningSignature": "YWJj"}, + ), ], ) def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use): @@ -215,6 +220,59 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, ), + # Citation - New + ( + { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + }, + {}, + {}, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ] + }, + { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + }, + ), + # Citation - Existing + ( + { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + }, + {}, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ] + }, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"}, + {"location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, "title": "Another Doc"}, + ] + }, + { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + }, + ), # Empty ( {"delta": {}}, @@ -256,6 +314,39 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "redactedContent": b"", }, ), + # Tool Use - With reasoningSignature + ( + { + "content": [], + "current_tool_use": { + "toolUseId": "123", + "name": "test", + "input": '{"key": "value"}', + "reasoningSignature": "YWJj", + }, + "text": "", + "reasoningText": "", + "citationsContent": [], + "redactedContent": b"", + }, + { + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "test", + "input": {"key": "value"}, + "reasoningSignature": "YWJj", + } + } + ], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [], + "redactedContent": b"", + }, + ), # Tool Use - Missing input ( { @@ -294,22 +385,59 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "redactedContent": b"", }, ), - # Citations + # Text with Citations ( { "content": [], "current_tool_use": {}, + "text": "This is cited text", + "reasoningText": "", + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], + "redactedContent": b"", + }, + { + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + ], + "content": [{"text": "This is cited text"}], + } + } + ], + "current_tool_use": {}, "text": "", "reasoningText": "", - "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + "citationsContent": [], "redactedContent": b"", }, + ), + # Citations without text (should not create content block) + ( { "content": [], "current_tool_use": {}, "text": "", "reasoningText": "", - "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], + "redactedContent": b"", + }, + { + "content": [], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], "redactedContent": b"", }, ), @@ -402,12 +530,30 @@ def test_handle_content_block_stop(state, exp_updated_state): def test_handle_message_stop(): event: MessageStopEvent = {"stopReason": "end_turn"} - tru_reason = strands.event_loop.streaming.handle_message_stop(event) + tru_reason = strands.event_loop.streaming.handle_message_stop(event, []) exp_reason = "end_turn" assert tru_reason == exp_reason +def test_handle_message_stop_overrides_end_turn_when_tool_use_present(): + event: MessageStopEvent = {"stopReason": "end_turn"} + content = [{"toolUse": {"toolUseId": "t1", "name": "myTool", "input": {}}}] + + tru_reason = strands.event_loop.streaming.handle_message_stop(event, content) + + assert tru_reason == "tool_use" + + +def test_handle_message_stop_keeps_tool_use_unchanged(): + event: MessageStopEvent = {"stopReason": "tool_use"} + content = [{"toolUse": {"toolUseId": "t1", "name": "myTool", "input": {}}}] + + tru_reason = strands.event_loop.streaming.handle_message_stop(event, content) + + assert tru_reason == "tool_use" + + def test_extract_usage_metrics(): event = { "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, @@ -578,6 +724,137 @@ def test_extract_usage_metrics_empty_metadata(): }, ], ), + # Message with Citations + ( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "This is cited text"}}}, + { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + } + }, + { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + }, + ], + [ + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "This is cited text"}}}}, + {"data": "This is cited text", "delta": {"text": "This is cited text"}}, + { + "event": { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + } + } + }, + { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + }, + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + }, + }, + { + "event": { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + } + } + }, + { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + }, + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + }, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "end_turn"}}}, + { + "event": { + "metadata": { + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + } + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": { + "documentChar": {"documentIndex": 0, "start": 10, "end": 20} + }, + "title": "Test Doc", + }, + { + "location": { + "documentPage": {"documentIndex": 1, "start": 5, "end": 6} + }, + "title": "Another Doc", + }, + ], + "content": [{"text": "This is cited text"}], + } + } + ], + }, + {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + {"latencyMs": 100}, + ) + }, + ], + ), # Empty Message ( [{}], @@ -896,6 +1173,8 @@ async def test_stream_messages(agenerator, alist): "test prompt", tool_choice=None, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, + model_state=None, ) @@ -929,6 +1208,8 @@ async def test_stream_messages_with_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, + model_state=None, ) @@ -962,6 +1243,8 @@ async def test_stream_messages_single_text_block_backwards_compatibility(agenera "You are a helpful assistant.", tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, + model_state=None, ) @@ -993,6 +1276,8 @@ async def test_stream_messages_empty_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=[], + invocation_state=None, + model_state=None, ) @@ -1024,6 +1309,8 @@ async def test_stream_messages_none_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=None, + invocation_state=None, + model_state=None, ) # Ensure that we're getting typed events coming out of process_stream @@ -1070,3 +1357,68 @@ async def test_stream_messages_normalizes_messages(agenerator, alist): {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"}, {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"}, ] + + +@pytest.mark.asyncio +async def test_process_stream_overrides_end_turn_when_tool_use_present(agenerator, alist): + response = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"contentBlockIndex": 0, "start": {"toolUse": {"toolUseId": "t1", "name": "myTool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"key": "val"}'}}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + "metrics": {"latencyMs": 100}, + } + }, + ] + + stream = strands.event_loop.streaming.process_stream(agenerator(response)) + last_event = cast(ModelStopReason, (await alist(stream))[-1]) + + assert last_event["stop"][0] == "tool_use" + + +@pytest.mark.asyncio +async def test_process_stream_keeps_end_turn_when_no_tool_use(agenerator, alist): + response = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Hello!"}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + "metrics": {"latencyMs": 100}, + } + }, + ] + + stream = strands.event_loop.streaming.process_stream(agenerator(response)) + last_event = cast(ModelStopReason, (await alist(stream))[-1]) + + assert last_event["stop"][0] == "end_turn" + + +@pytest.mark.asyncio +async def test_process_stream_keeps_tool_use_stop_reason_unchanged(agenerator, alist): + response = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"contentBlockIndex": 0, "start": {"toolUse": {"toolUseId": "t1", "name": "myTool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + "metrics": {"latencyMs": 100}, + } + }, + ] + + stream = strands.event_loop.streaming.process_stream(agenerator(response)) + last_event = cast(ModelStopReason, (await alist(stream))[-1]) + + assert last_event["stop"][0] == "tool_use" diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py index 4645e1724..3c7358237 100644 --- a/tests/strands/event_loop/test_streaming_structured_output.py +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -66,6 +66,8 @@ async def test_stream_messages_with_tool_choice(agenerator, alist): "test prompt", tool_choice=tool_choice, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, + model_state=None, ) # Verify we get the expected events @@ -131,6 +133,8 @@ async def test_stream_messages_with_forced_structured_output(agenerator, alist): "Extract user information", tool_choice=tool_choice, system_prompt_content=[{"text": "Extract user information"}], + invocation_state=None, + model_state=None, ) assert len(tru_events) > 0 diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py index f8df25e14..a121ddecc 100644 --- a/tests/strands/experimental/bidi/_async/test__init__.py +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -10,17 +10,19 @@ async def test_stop_exception(): func1 = AsyncMock() func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) func3 = AsyncMock() + func4 = AsyncMock(side_effect=ValueError("stop 4 failed")) - with pytest.raises(ExceptionGroup) as exc_info: - await stop_all(func1, func2, func3) + with pytest.raises(Exception, match=r"failed stop sequence") as exc_info: + await stop_all(func1, func2, func3, func4) func1.assert_called_once() func2.assert_called_once() func3.assert_called_once() + func4.assert_called_once() - assert len(exc_info.value.exceptions) == 1 - with pytest.raises(ValueError, match=r"stop 2 failed"): - raise exc_info.value.exceptions[0] + tru_message = str(exc_info.value) + assert "ValueError('stop 2 failed')" in tru_message + assert "ValueError('stop 4 failed')" in tru_message @pytest.mark.asyncio diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py new file mode 100644 index 000000000..255ead15e --- /dev/null +++ b/tests/strands/experimental/bidi/_async/test_task_group.py @@ -0,0 +1,74 @@ +import asyncio +import unittest.mock + +import pytest + +from strands.experimental.bidi._async._task_group import _TaskGroup + + +@pytest.mark.asyncio +async def test_task_group__aexit__(): + coro = unittest.mock.AsyncMock() + + async with _TaskGroup() as task_group: + task_group.create_task(coro()) + + coro.assert_called_once() + + +@pytest.mark.asyncio +async def test_task_group__aexit__task_exception(): + wait_event = asyncio.Event() + + async def wait(): + await wait_event.wait() + + async def fail(): + raise ValueError("test error") + + with pytest.raises(ValueError, match=r"test error"): + async with _TaskGroup() as task_group: + wait_task = task_group.create_task(wait()) + fail_task = task_group.create_task(fail()) + + assert wait_task.cancelled() + assert not fail_task.cancelled() + + +@pytest.mark.asyncio +async def test_task_group__aexit__task_cancelled(): + async def wait(): + asyncio.current_task().cancel() + await asyncio.sleep(0) + + async with _TaskGroup() as task_group: + wait_task = task_group.create_task(wait()) + + assert wait_task.cancelled() + + +@pytest.mark.asyncio +async def test_task_group__aexit__context_cancelled(): + wait_event = asyncio.Event() + + async def wait(): + await wait_event.wait() + + tasks = [] + + run_event = asyncio.Event() + + async def run(): + async with _TaskGroup() as task_group: + tasks.append(task_group.create_task(wait())) + run_event.set() + + run_task = asyncio.create_task(run()) + await run_event.wait() + run_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await run_task + + wait_task = tasks[0] + assert wait_task.cancelled() diff --git a/tests/strands/experimental/bidi/agent/__init__.py b/tests/strands/experimental/bidi/agent/__init__.py index 3359c6565..dd401a83d 100644 --- a/tests/strands/experimental/bidi/agent/__init__.py +++ b/tests/strands/experimental/bidi/agent/__init__.py @@ -1 +1 @@ -"""Bidirectional streaming agent tests.""" \ No newline at end of file +"""Bidirectional streaming agent tests.""" diff --git a/tests/strands/experimental/bidi/agent/test_agent.py b/tests/strands/experimental/bidi/agent/test_agent.py index 19d3525d7..50c9afef9 100644 --- a/tests/strands/experimental/bidi/agent/test_agent.py +++ b/tests/strands/experimental/bidi/agent/test_agent.py @@ -1,21 +1,23 @@ """Unit tests for BidiAgent.""" -import unittest.mock import asyncio -import pytest +import sys +import unittest.mock from uuid import uuid4 +import pytest + from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel from strands.experimental.bidi.types.events import ( - BidiTextInputEvent, BidiAudioInputEvent, BidiAudioStreamEvent, - BidiTranscriptStreamEvent, - BidiConnectionStartEvent, BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, ) + class MockBidiModel: """Mock bidirectional model for testing.""" @@ -46,14 +48,14 @@ async def receive(self): """Async generator yielding mock events.""" if not self._started: raise RuntimeError("model not started | call start before sending/receiving") - + # Yield connection start event yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) - + # Yield any configured events for event in self._events_to_yield: yield event - + # Yield connection end event yield BidiConnectionCloseEvent(connection_id=self._connection_id, reason="complete") @@ -61,11 +63,13 @@ def set_events(self, events): """Helper to set events this mock model will yield.""" self._events_to_yield = events + @pytest.fixture def mock_model(): """Create a mock BidiModel instance.""" return MockBidiModel() + @pytest.fixture def mock_tool_registry(): """Mock tool registry with some basic tools.""" @@ -73,15 +77,15 @@ def mock_tool_registry(): registry.get_all_tool_specs.return_value = [ { "name": "calculator", - "description": "Perform calculations", - "inputSchema": {"json": {"type": "object", "properties": {}}} + "description": "Perform calculations", + "inputSchema": {"json": {"type": "object", "properties": {}}}, } ] registry.get_all_tools_config.return_value = {"calculator": {}} return registry -@pytest.fixture +@pytest.fixture def mock_tool_caller(): """Mock tool caller for testing tool execution.""" caller = unittest.mock.AsyncMock() @@ -94,203 +98,198 @@ def agent(mock_model, mock_tool_registry, mock_tool_caller): """Create a BidiAgent instance for testing.""" with unittest.mock.patch("strands.experimental.bidi.agent.agent.ToolRegistry") as mock_registry_class: mock_registry_class.return_value = mock_tool_registry - + with unittest.mock.patch("strands.experimental.bidi.agent.agent._ToolCaller") as mock_caller_class: mock_caller_class.return_value = mock_tool_caller - + # Don't pass tools to avoid real tool loading agent = BidiAgent(model=mock_model) return agent + def test_bidi_agent_init_with_various_configurations(): """Test agent initialization with various configurations.""" # Test default initialization mock_model = MockBidiModel() agent = BidiAgent(model=mock_model) - + assert agent.model == mock_model assert agent.system_prompt is None assert not agent._started assert agent.model._connection_id is None - + # Test with configuration system_prompt = "You are a helpful assistant." - agent_with_config = BidiAgent( - model=mock_model, - system_prompt=system_prompt, - agent_id="test_agent" - ) - + agent_with_config = BidiAgent(model=mock_model, system_prompt=system_prompt, agent_id="test_agent") + assert agent_with_config.system_prompt == system_prompt assert agent_with_config.agent_id == "test_agent" - - # Test with string model ID - model_id = "amazon.nova-sonic-v1:0" - agent_with_string = BidiAgent(model=model_id) - - assert isinstance(agent_with_string.model, BidiNovaSonicModel) - assert agent_with_string.model.model_id == model_id - + # Test model config access config = agent.model.config assert config["audio"]["input_rate"] == 16000 assert config["audio"]["output_rate"] == 24000 assert config["audio"]["channels"] == 1 + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="BidiNovaSonicModel is only supported for Python 3.12+") +def test_bidi_agent_init_with_model_id(): + from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel + + model_id = "amazon.nova-sonic-v1:0" + agent = BidiAgent(model=model_id) + + assert isinstance(agent.model, BidiNovaSonicModel) + assert agent.model.model_id == model_id + + @pytest.mark.asyncio async def test_bidi_agent_start_stop_lifecycle(agent): """Test agent start/stop lifecycle and state management.""" # Initial state assert not agent._started assert agent.model._connection_id is None - + # Start agent await agent.start() assert agent._started assert agent.model._connection_id is not None connection_id = agent.model._connection_id - + # Double start should error with pytest.raises(RuntimeError, match="agent already started"): await agent.start() - + # Stop agent await agent.stop() assert not agent._started assert agent.model._connection_id is None - + # Multiple stops should be safe await agent.stop() await agent.stop() - + # Restart should work with new connection ID await agent.start() assert agent._started assert agent.model._connection_id != connection_id + @pytest.mark.asyncio async def test_bidi_agent_send_with_input_types(agent): """Test sending various input types through agent.send().""" await agent.start() - + # Test text input with TypedEvent text_input = BidiTextInputEvent(text="Hello", role="user") await agent.send(text_input) assert len(agent.messages) == 1 assert agent.messages[0]["content"][0]["text"] == "Hello" - + # Test string input (shorthand) await agent.send("World") assert len(agent.messages) == 2 assert agent.messages[1]["content"][0]["text"] == "World" - + # Test audio input (doesn't add to messages) audio_input = BidiAudioInputEvent( audio="dGVzdA==", # base64 "test" format="pcm", sample_rate=16000, - channels=1 + channels=1, ) await agent.send(audio_input) assert len(agent.messages) == 2 # Still 2, audio doesn't add - + # Test concurrent sends - sends = [ - agent.send(BidiTextInputEvent(text=f"Message {i}", role="user")) - for i in range(3) - ] + sends = [agent.send(BidiTextInputEvent(text=f"Message {i}", role="user")) for i in range(3)] await asyncio.gather(*sends) assert len(agent.messages) == 5 # 2 + 3 new messages + @pytest.mark.asyncio async def test_bidi_agent_receive_events_from_model(agent): """Test receiving events from model.""" # Configure mock model to yield events events = [ - BidiAudioStreamEvent( - audio="dGVzdA==", - format="pcm", - sample_rate=24000, - channels=1 - ), + BidiAudioStreamEvent(audio="dGVzdA==", format="pcm", sample_rate=24000, channels=1), BidiTranscriptStreamEvent( text="Hello world", role="assistant", is_final=True, delta={"text": "Hello world"}, - current_transcript="Hello world" - ) + current_transcript="Hello world", + ), ] agent.model.set_events(events) - + await agent.start() - + received_events = [] async for event in agent.receive(): received_events.append(event) if len(received_events) >= 4: # Stop after getting expected events break - + # Verify event types and order assert len(received_events) >= 3 assert isinstance(received_events[0], BidiConnectionStartEvent) assert isinstance(received_events[1], BidiAudioStreamEvent) assert isinstance(received_events[2], BidiTranscriptStreamEvent) - + # Test empty events agent.model.set_events([]) await agent.stop() await agent.start() - + empty_events = [] async for event in agent.receive(): empty_events.append(event) if len(empty_events) >= 2: break - + assert len(empty_events) >= 1 assert isinstance(empty_events[0], BidiConnectionStartEvent) + def test_bidi_agent_tool_integration(agent, mock_tool_registry): """Test agent tool integration and properties.""" # Test tool property access - assert hasattr(agent, 'tool') + assert hasattr(agent, "tool") assert agent.tool is not None assert agent.tool == agent._tool_caller - + # Test tool names property - mock_tool_registry.get_all_tools_config.return_value = { - "calculator": {}, - "weather": {} - } - + mock_tool_registry.get_all_tools_config.return_value = {"calculator": {}, "weather": {}} + tool_names = agent.tool_names assert isinstance(tool_names, list) assert len(tool_names) == 2 assert "calculator" in tool_names assert "weather" in tool_names + @pytest.mark.asyncio async def test_bidi_agent_send_receive_error_before_start(agent): """Test error handling in various scenarios.""" # Test send before start with pytest.raises(RuntimeError, match="call start before"): await agent.send(BidiTextInputEvent(text="Hello", role="user")) - + # Test receive before start with pytest.raises(RuntimeError, match="call start before"): - async for event in agent.receive(): + async for _ in agent.receive(): pass - + # Test send after stop await agent.start() await agent.stop() with pytest.raises(RuntimeError, match="call start before"): await agent.send(BidiTextInputEvent(text="Hello", role="user")) - + # Test receive after stop with pytest.raises(RuntimeError, match="call start before"): - async for event in agent.receive(): + async for _ in agent.receive(): pass @@ -301,43 +300,44 @@ async def test_bidi_agent_start_receive_propagates_model_errors(): mock_model = MockBidiModel() mock_model.start = unittest.mock.AsyncMock(side_effect=Exception("Connection failed")) error_agent = BidiAgent(model=mock_model) - + with pytest.raises(Exception, match="Connection failed"): await error_agent.start() - + # Test model receive error mock_model2 = MockBidiModel() agent2 = BidiAgent(model=mock_model2) await agent2.start() - + async def failing_receive(): yield BidiConnectionStartEvent(connection_id="test", model="test-model") raise Exception("Receive failed") - + agent2.model.receive = failing_receive with pytest.raises(Exception, match="Receive failed"): - async for event in agent2.receive(): + async for _ in agent2.receive(): pass + @pytest.mark.asyncio async def test_bidi_agent_state_consistency(agent): """Test that agent state remains consistent across operations.""" # Initial state assert not agent._started assert agent.model._connection_id is None - + # Start await agent.start() assert agent._started assert agent.model._connection_id is not None connection_id = agent.model._connection_id - + # Send operations shouldn't change connection state await agent.send(BidiTextInputEvent(text="Hello", role="user")) assert agent._started assert agent.model._connection_id == connection_id - + # Stop await agent.stop() assert not agent._started - assert agent.model._connection_id is None \ No newline at end of file + assert agent.model._connection_id is None diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index 0ce8d6658..a8efd9a93 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -1,13 +1,13 @@ import unittest.mock +import warnings import pytest import pytest_asyncio from strands import tool from strands.experimental.bidi import BidiAgent -from strands.experimental.bidi.agent.loop import _BidiAgentLoop -from strands.experimental.bidi.models import BidiModelTimeoutError -from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent +from strands.experimental.bidi.models import BidiModel, BidiModelTimeoutError +from strands.experimental.bidi.types.events import BidiConnectionCloseEvent, BidiConnectionRestartEvent, BidiTextInputEvent from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent @@ -22,7 +22,7 @@ async def func(): @pytest.fixture def agent(time_tool): - return BidiAgent(model=unittest.mock.AsyncMock(), tools=[time_tool]) + return BidiAgent(model=unittest.mock.AsyncMock(spec=BidiModel), tools=[time_tool]) @pytest_asyncio.fixture @@ -38,19 +38,19 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato agent.model.receive = unittest.mock.Mock(side_effect=[timeout_error, agenerator([text_event])]) await loop.start() - + tru_events = [] async for event in loop.receive(): tru_events.append(event) if len(tru_events) >= 2: break - + exp_events = [ BidiConnectionRestartEvent(timeout_error), text_event, ] assert tru_events == exp_events - + agent.model.stop.assert_called_once() assert agent.model.start.call_count == 2 agent.model.start.assert_called_with( @@ -63,7 +63,6 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato @pytest.mark.asyncio async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): - tool_use = {"toolUseId": "t1", "name": "time_tool", "input": {}} tool_result = {"toolUseId": "t1", "status": "success", "content": [{"text": "12:00"}]} @@ -71,9 +70,9 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): tool_result_event = ToolResultEvent(tool_result) agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) - + await loop.start() - + tru_events = [] async for event in loop.receive(): tru_events.append(event) @@ -86,7 +85,7 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): ToolResultMessageEvent({"role": "user", "content": [{"toolResult": tool_result}]}), ] assert tru_events == exp_events - + tru_messages = agent.messages exp_messages = [ {"role": "assistant", "content": [{"toolUse": tool_use}]}, @@ -95,3 +94,157 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): assert tru_messages == exp_messages agent.model.send.assert_called_with(tool_result_event) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_request_state_initialized_for_tools(loop, agent, agenerator): + """Test that request_state is initialized in invocation_state before tool execution. + + This ensures request_state exists for tools that may need it via invocation_state, + even when invocation_state is not provided by the user. + """ + tool_use = {"toolUseId": "t2", "name": "time_tool", "input": {}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + # Start without providing invocation_state + await loop.start() + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 3: + break + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + + # Verify request_state was initialized in invocation_state + assert "request_state" in loop._invocation_state + assert isinstance(loop._invocation_state["request_state"], dict) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_stop_event_loop_flag(agent, agenerator): + """Test that the stop_event_loop flag in request_state gracefully closes the connection. + + This simulates a tool (like strands_tools.stop) setting the flag via invocation_state. + """ + # Use a tool that modifies invocation_state to set the stop flag + # We'll mock the tool executor to simulate this behavior + loop = agent._loop + + tool_use = {"toolUseId": "t3", "name": "time_tool", "input": {}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + tool_result = {"toolUseId": "t3", "status": "success", "content": [{"text": "12:00"}]} + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + # Start with request_state that already has stop_event_loop=True + # This simulates a tool having set it during execution + await loop.start(invocation_state={"request_state": {"stop_event_loop": True}}) + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + + # Should receive: tool_use_event, tool_result_event, tool_result_message, connection_close + assert len(tru_events) == 4 + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + + # Verify connection close event was emitted + connection_close_event = tru_events[3] + assert isinstance(connection_close_event, BidiConnectionCloseEvent) + assert connection_close_event["reason"] == "user_request" + + # Verify model.send was NOT called (tool result not sent to model) + agent.model.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_stop_conversation_deprecated_but_works(loop, agent, agenerator): + """Test that stop_conversation tool still works but emits a deprecation warning. + + The stop_conversation tool is deprecated in favor of request_state["stop_event_loop"], + but should continue to work for backward compatibility via the name-based check. + """ + from strands.experimental.bidi.tools import stop_conversation + + agent.tool_registry.register_tool(stop_conversation) + + tool_use = {"toolUseId": "t5", "name": "stop_conversation", "input": {}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + await loop.start() + + tru_events = [] + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + async for event in loop.receive(): + tru_events.append(event) + + # Should receive: tool_use_event, tool_result_event, tool_result_message, connection_close + assert len(tru_events) == 4 + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + assert "Ending conversation" in tool_result_event.tool_result["content"][0]["text"] + + # Verify connection close event was emitted + connection_close_event = tru_events[3] + assert isinstance(connection_close_event, BidiConnectionCloseEvent) + assert connection_close_event["reason"] == "user_request" + + # Verify model.send was NOT called (tool result not sent to model) + agent.model.send.assert_not_called() + + # Verify deprecation warnings were emitted (from both the tool itself and the loop name check) + deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1 + assert any("stop_conversation" in str(w.message).lower() for w in deprecation_warnings) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_request_state_preserved_with_invocation_state(agent, agenerator): + """Test that existing invocation_state is preserved when request_state is initialized.""" + + @tool(name="check_invocation_state") + async def check_invocation_state(custom_key: str) -> str: + return f"custom_key: {custom_key}" + + agent.tool_registry.register_tool(check_invocation_state) + + tool_use = {"toolUseId": "t4", "name": "check_invocation_state", "input": {"custom_key": "from_state"}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + loop = agent._loop + # Start with custom invocation_state but no request_state + await loop.start(invocation_state={"custom_data": "preserved"}) + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 3: + break + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + + # Verify request_state was added without removing custom_data + assert "request_state" in loop._invocation_state + assert loop._invocation_state.get("custom_data") == "preserved" diff --git a/tests/strands/experimental/bidi/io/test_audio.py b/tests/strands/experimental/bidi/io/test_audio.py index 459faa78a..9b502700b 100644 --- a/tests/strands/experimental/bidi/io/test_audio.py +++ b/tests/strands/experimental/bidi/io/test_audio.py @@ -29,7 +29,7 @@ def agent(): "voice": "test-voice", }, } - return mock + return mock @pytest.fixture @@ -49,6 +49,7 @@ def config(): "output_frames_per_buffer": 2048, } + @pytest.fixture def audio_io(py_audio, config): _ = py_audio diff --git a/tests/strands/experimental/bidi/io/test_text.py b/tests/strands/experimental/bidi/io/test_text.py index 5507a8c0f..e21e149bd 100644 --- a/tests/strands/experimental/bidi/io/test_text.py +++ b/tests/strands/experimental/bidi/io/test_text.py @@ -42,7 +42,7 @@ async def test_bidi_text_io_input(prompt_session, text_input): (BidiInterruptionEvent(reason="user_speech"), "interrupted"), (BidiTranscriptStreamEvent(text="test text", delta="", is_final=False, role="user"), "Preview: test text"), (BidiTranscriptStreamEvent(text="test text", delta="", is_final=True, role="user"), "test text"), - ] + ], ) @pytest.mark.asyncio async def test_bidi_text_io_output(event, exp_print, text_output, capsys): diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index da516d4a0..3a9d7e3dc 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,8 +13,8 @@ import pytest from google.genai import types as genai_types -from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop() @@ -572,7 +572,6 @@ def test_tool_formatting(model, tool_spec): assert formatted_empty == [] - # Tool Result Content Tests @@ -601,7 +600,7 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key assert isinstance(audio_event, BidiAudioStreamEvent) # Should use configured rates, not constants assert audio_event.sample_rate == 48000 # Custom config - assert audio_event.channels == 2 # Custom config + assert audio_event.channels == 2 # Custom config assert audio_event.format == "pcm" await model.stop() @@ -631,7 +630,7 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke assert isinstance(audio_event, BidiAudioStreamEvent) # Should use default rates assert audio_event.sample_rate == 24000 # Default output rate - assert audio_event.channels == 1 # Default channels + assert audio_event.channels == 1 # Default channels assert audio_event.format == "pcm" await model.stop() diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py index 04f8043be..14630875b 100644 --- a/tests/strands/experimental/bidi/models/test_nova_sonic.py +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -4,19 +4,28 @@ covering connection lifecycle, event conversion, audio streaming, and tool execution. """ +import sys + +if sys.version_info < (3, 12): + import pytest + + pytest.skip(reason="BidiNovaSonicModel is only supported for Python 3.12+", allow_module_level=True) + import asyncio import base64 import json -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio from aws_sdk_bedrock_runtime.models import ModelTimeoutException, ValidationException +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.nova_sonic import ( BidiNovaSonicModel, + NOVA_SONIC_V1_MODEL_ID, + NOVA_SONIC_V2_MODEL_ID, ) -from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -39,9 +48,8 @@ def model_id(): @pytest.fixture -def region(): - """AWS region.""" - return "us-east-1" +def boto_session(): + return Mock(region_name="us-east-1") @pytest.fixture @@ -67,11 +75,11 @@ def mock_client(mock_stream): @pytest_asyncio.fixture -def nova_model(model_id, region, mock_client): +def nova_model(model_id, boto_session, mock_client): """Create Nova Sonic model instance.""" _ = mock_client - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) yield model @@ -79,12 +87,12 @@ def nova_model(model_id, region, mock_client): @pytest.mark.asyncio -async def test_model_initialization(model_id, region): +async def test_model_initialization(model_id, boto_session): """Test model initialization with configuration.""" - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) assert model.model_id == model_id - assert model.region == region + assert model.region == "us-east-1" assert model._connection_id is None @@ -92,9 +100,9 @@ async def test_model_initialization(model_id, region): @pytest.mark.asyncio -async def test_audio_config_defaults(model_id, region): +async def test_audio_config_defaults(model_id, boto_session): """Test default audio configuration.""" - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) assert model.config["audio"]["input_rate"] == 16000 assert model.config["audio"]["output_rate"] == 16000 @@ -104,10 +112,12 @@ async def test_audio_config_defaults(model_id, region): @pytest.mark.asyncio -async def test_audio_config_partial_override(model_id, region): +async def test_audio_config_partial_override(model_id, boto_session): """Test partial audio configuration override.""" provider_config = {"audio": {"output_rate": 24000, "voice": "ruth"}} - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + model = BidiNovaSonicModel( + model_id=model_id, client_config={"boto_session": boto_session}, provider_config=provider_config + ) # Overridden values assert model.config["audio"]["output_rate"] == 24000 @@ -120,7 +130,7 @@ async def test_audio_config_partial_override(model_id, region): @pytest.mark.asyncio -async def test_audio_config_full_override(model_id, region): +async def test_audio_config_full_override(model_id, boto_session): """Test full audio configuration override.""" provider_config = { "audio": { @@ -131,7 +141,9 @@ async def test_audio_config_full_override(model_id, region): "voice": "stephen", } } - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + model = BidiNovaSonicModel( + model_id=model_id, client_config={"boto_session": boto_session}, provider_config=provider_config + ) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -527,55 +539,220 @@ async def test_message_history_empty_and_edge_cases(nova_model): @pytest.mark.asyncio -async def test_custom_audio_rates_in_events(model_id, region): +async def test_custom_audio_rates_in_events(model_id, boto_session): """Test that audio events use configured sample rates.""" # Create model with custom audio configuration provider_config = {"audio": {"output_rate": 48000, "channels": 2}} - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + model = BidiNovaSonicModel( + model_id=model_id, client_config={"boto_session": boto_session}, provider_config=provider_config + ) # Test audio output event uses custom configuration audio_bytes = b"test audio data" audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} result = model._convert_nova_event(nova_event) - + assert result is not None assert isinstance(result, BidiAudioStreamEvent) # Should use configured rates, not constants assert result.sample_rate == 48000 # Custom config - assert result.channels == 2 # Custom config + assert result.channels == 2 # Custom config assert result.format == "pcm" @pytest.mark.asyncio -async def test_default_audio_rates_in_events(model_id, region): +async def test_default_audio_rates_in_events(model_id, boto_session): """Test that audio events use default sample rates when no custom config.""" # Create model without custom audio configuration - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) # Test audio output event uses defaults audio_bytes = b"test audio data" audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} result = model._convert_nova_event(nova_event) - + assert result is not None assert isinstance(result, BidiAudioStreamEvent) # Should use default rates assert result.sample_rate == 16000 # Default output rate - assert result.channels == 1 # Default channels + assert result.channels == 1 # Default channels assert result.format == "pcm" +# Nova Sonic v2 Support Tests + + +def test_nova_sonic_model_constants(): + """Test that Nova Sonic model ID constants are correctly defined.""" + assert NOVA_SONIC_V1_MODEL_ID == "amazon.nova-sonic-v1:0" + assert NOVA_SONIC_V2_MODEL_ID == "amazon.nova-2-sonic-v1:0" + + +@pytest.mark.asyncio +async def test_nova_sonic_v1_instantiation(boto_session, mock_client): + """Test direct instantiation with Nova Sonic v1 model ID.""" + _ = mock_client # Ensure mock is active + + # Test default creation + model = BidiNovaSonicModel(model_id=NOVA_SONIC_V1_MODEL_ID, client_config={"boto_session": boto_session}) + assert model.model_id == NOVA_SONIC_V1_MODEL_ID + assert model.region == "us-east-1" + + # Test with custom config + provider_config = {"audio": {"voice": "joanna", "output_rate": 24000}} + client_config = {"boto_session": boto_session} + model_custom = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + + assert model_custom.model_id == NOVA_SONIC_V1_MODEL_ID + assert model_custom.config["audio"]["voice"] == "joanna" + assert model_custom.config["audio"]["output_rate"] == 24000 + + +@pytest.mark.asyncio +async def test_nova_sonic_v2_instantiation(boto_session, mock_client): + """Test direct instantiation with Nova Sonic v2 model ID.""" + _ = mock_client # Ensure mock is active + + # Test default creation + model = BidiNovaSonicModel(model_id=NOVA_SONIC_V2_MODEL_ID, client_config={"boto_session": boto_session}) + assert model.model_id == NOVA_SONIC_V2_MODEL_ID + assert model.region == "us-east-1" + + # Test with custom config + provider_config = {"audio": {"voice": "ruth", "input_rate": 48000}, "inference": {"temperature": 0.8}} + client_config = {"boto_session": boto_session} + model_custom = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + + assert model_custom.model_id == NOVA_SONIC_V2_MODEL_ID + assert model_custom.config["audio"]["voice"] == "ruth" + assert model_custom.config["audio"]["input_rate"] == 48000 + assert model_custom.config["inference"]["temperature"] == 0.8 + + +@pytest.mark.asyncio +async def test_nova_sonic_v1_v2_compatibility(boto_session, mock_client): + """Test that v1 and v2 models have the same config structure and behavior.""" + _ = mock_client # Ensure mock is active + + # Create both models with same config + provider_config = {"audio": {"voice": "matthew"}} + client_config = {"boto_session": boto_session} + + model_v1 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + model_v2 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + + # Both should have the same config structure + assert model_v1.config["audio"]["voice"] == model_v2.config["audio"]["voice"] + assert model_v1.region == model_v2.region + + # Only model_id should differ + assert model_v1.model_id != model_v2.model_id + assert model_v1.model_id == NOVA_SONIC_V1_MODEL_ID + assert model_v2.model_id == NOVA_SONIC_V2_MODEL_ID + + +@pytest.mark.asyncio +async def test_backward_compatibility(boto_session, mock_client): + """Test that existing code continues to work (backward compatibility).""" + _ = mock_client # Ensure mock is active + + # Test that default behavior now uses v2 (updated default) + model_default = BidiNovaSonicModel(client_config={"boto_session": boto_session}) + assert model_default.model_id == NOVA_SONIC_V2_MODEL_ID + + # Test that existing explicit v1 usage still works + model_explicit_v1 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, client_config={"boto_session": boto_session} + ) + assert model_explicit_v1.model_id == NOVA_SONIC_V1_MODEL_ID + + # Test that explicit v2 usage works + model_explicit_v2 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, client_config={"boto_session": boto_session} + ) + assert model_explicit_v2.model_id == NOVA_SONIC_V2_MODEL_ID + + +@pytest.mark.asyncio +async def test_turn_detection_v1_validation(boto_session, mock_client): + """Test that turn_detection raises error when used with v1 model.""" + _ = mock_client # Ensure mock is active + + # Test that turn_detection with v1 raises ValueError + with pytest.raises(ValueError, match="turn_detection is only supported in Nova Sonic v2"): + BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": "MEDIUM"}}, + client_config={"boto_session": boto_session}, + ) + + # Test that turn_detection with v2 works fine + model_v2 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": "MEDIUM"}}, + client_config={"boto_session": boto_session}, + ) + assert model_v2.config["turn_detection"]["endpointingSensitivity"] == "MEDIUM" + + # Test that empty turn_detection dict doesn't raise error for v1 + model_v1_empty = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, + provider_config={"turn_detection": {}}, + client_config={"boto_session": boto_session}, + ) + assert model_v1_empty.model_id == NOVA_SONIC_V1_MODEL_ID + + +@pytest.mark.asyncio +async def test_turn_detection_sensitivity_validation(boto_session, mock_client): + """Test that endpointingSensitivity is validated at initialization.""" + _ = mock_client # Ensure mock is active + + # Test invalid sensitivity value raises ValueError at init + with pytest.raises(ValueError, match="Invalid endpointingSensitivity.*Must be HIGH, MEDIUM, or LOW"): + BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": "INVALID"}}, + client_config={"boto_session": boto_session}, + ) + + # Test valid sensitivity values work + for sensitivity in ["HIGH", "MEDIUM", "LOW"]: + model = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": sensitivity}}, + client_config={"boto_session": boto_session}, + ) + assert model.config["turn_detection"]["endpointingSensitivity"] == sensitivity + + # Test that turn_detection without sensitivity works (sensitivity is optional) + model_no_sensitivity = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {}}, + client_config={"boto_session": boto_session}, + ) + assert "endpointingSensitivity" not in model_no_sensitivity.config["turn_detection"] + + # Error Handling Tests @pytest.mark.asyncio async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): mock_output = AsyncMock() mock_output.receive.side_effect = ModelTimeoutException("Connection timeout") mock_stream.await_output.return_value = (None, mock_output) - + await nova_model.start() - + with pytest.raises(BidiModelTimeoutError, match=r"Connection timeout"): async for _ in nova_model.receive(): pass @@ -586,9 +763,9 @@ async def test_bidi_nova_sonic_model_receive_timeout_validation(nova_model, mock mock_output = AsyncMock() mock_output.receive.side_effect = ValidationException("InternalErrorCode=531: Request timeout") mock_stream.await_output.return_value = (None, mock_output) - + await nova_model.start() - + with pytest.raises(BidiModelTimeoutError, match=r"InternalErrorCode=531"): async for _ in nova_model.receive(): pass diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 5c9c0900d..09f4c8bc8 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -9,6 +9,7 @@ """ import base64 +import itertools import json import unittest.mock @@ -131,7 +132,9 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -154,7 +157,9 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -349,7 +354,7 @@ async def async_connect(*args, **kwargs): model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop() @@ -510,7 +515,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): format="pcm", sample_rate=24000, channels=1, - ) + ), ] assert tru_events == exp_events @@ -518,7 +523,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): @unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.time.time") @pytest.mark.asyncio async def test_receive_timeout(mock_time, model): - mock_time.side_effect = [1, 2] + mock_time.side_effect = itertools.count() model.timeout_s = 1 await model.start() diff --git a/tests/strands/experimental/hooks/multiagent/__init__.py b/tests/strands/experimental/checkpoint/__init__.py similarity index 100% rename from tests/strands/experimental/hooks/multiagent/__init__.py rename to tests/strands/experimental/checkpoint/__init__.py diff --git a/tests/strands/experimental/checkpoint/test_checkpoint.py b/tests/strands/experimental/checkpoint/test_checkpoint.py new file mode 100644 index 000000000..4435fb3db --- /dev/null +++ b/tests/strands/experimental/checkpoint/test_checkpoint.py @@ -0,0 +1,53 @@ +"""Tests for strands.experimental.checkpoint — Checkpoint serialization.""" + +import pytest + +from strands.experimental.checkpoint import CHECKPOINT_SCHEMA_VERSION, Checkpoint + + +class TestCheckpoint: + """Checkpoint dataclass serialization tests.""" + + def test_round_trip(self): + checkpoint = Checkpoint( + position="after_model", + cycle_index=1, + snapshot={"messages": []}, + app_data={"workflow_id": "wf-123"}, + ) + data = checkpoint.to_dict() + restored = Checkpoint.from_dict(data) + + assert restored.position == checkpoint.position + assert restored.cycle_index == checkpoint.cycle_index + assert restored.snapshot == checkpoint.snapshot + assert restored.app_data == checkpoint.app_data + assert restored.schema_version == CHECKPOINT_SCHEMA_VERSION + + def test_schema_version_immutable(self): + checkpoint = Checkpoint(position="after_tools") + assert checkpoint.schema_version == CHECKPOINT_SCHEMA_VERSION + + def test_schema_version_mismatch_raises(self): + data = Checkpoint(position="after_model").to_dict() + data["schema_version"] = "0.0" + with pytest.raises(ValueError, match="not compatible with current version"): + Checkpoint.from_dict(data) + + def test_defaults(self): + checkpoint = Checkpoint(position="after_model") + assert checkpoint.cycle_index == 0 + assert checkpoint.snapshot == {} + assert checkpoint.app_data == {} + + def test_from_dict_warns_on_unknown_fields(self, caplog): + data = Checkpoint(position="after_tools").to_dict() + data["unknown_future_field"] = "something" + restored = Checkpoint.from_dict(data) + assert restored.position == "after_tools" + assert "unknown_future_field" in caplog.text + + def test_from_dict_missing_schema_version_raises(self): + data = {"position": "after_model", "cycle_index": 0, "snapshot": {}, "app_data": {}} + with pytest.raises(ValueError, match="not compatible with current version"): + Checkpoint.from_dict(data) diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index f4899f2ab..ed7adba8a 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -7,16 +7,20 @@ import importlib import sys +import warnings from unittest.mock import Mock import pytest -from strands.experimental.hooks import ( - AfterModelInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - BeforeToolInvocationEvent, -) +# Suppress deprecation warnings from imports since we're testing the aliases themselves +with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from strands.experimental.hooks import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, + ) from strands.hooks import ( AfterModelCallEvent, AfterToolCallEvent, @@ -68,7 +72,7 @@ def test_after_tool_call_event_type_equality(): def test_before_model_call_event_type_equality(): """Verify that BeforeModelInvocationEvent alias has the same type identity.""" - before_model_event = BeforeModelCallEvent(agent=Mock()) + before_model_event = BeforeModelCallEvent(agent=Mock(), invocation_state={}) assert isinstance(before_model_event, BeforeModelInvocationEvent) assert isinstance(before_model_event, BeforeModelCallEvent) @@ -76,7 +80,7 @@ def test_before_model_call_event_type_equality(): def test_after_model_call_event_type_equality(): """Verify that AfterModelInvocationEvent alias has the same type identity.""" - after_model_event = AfterModelCallEvent(agent=Mock()) + after_model_event = AfterModelCallEvent(agent=Mock(), invocation_state={}) assert isinstance(after_model_event, AfterModelInvocationEvent) assert isinstance(after_model_event, AfterModelCallEvent) @@ -112,18 +116,20 @@ def experimental_callback(event: BeforeToolInvocationEvent): assert received_event is test_event -def test_deprecation_warning_on_import(captured_warnings): - """Verify that importing from experimental module emits deprecation warning.""" +def test_deprecation_warning_on_access(captured_warnings): + """Verify that accessing deprecated aliases emits deprecation warning.""" + import strands.experimental.hooks.events as events_module - module = sys.modules.get("strands.experimental.hooks.events") - if module: - importlib.reload(module) - else: - importlib.import_module("strands.experimental.hooks.events") + # Clear any existing warnings + captured_warnings.clear() + + # Access a deprecated alias - this should trigger the warning + _ = events_module.BeforeToolInvocationEvent assert len(captured_warnings) == 1 assert issubclass(captured_warnings[0].category, DeprecationWarning) - assert "are no longer experimental" in str(captured_warnings[0].message) + assert "BeforeToolInvocationEvent" in str(captured_warnings[0].message) + assert "BeforeToolCallEvent" in str(captured_warnings[0].message) def test_deprecation_warning_on_import_only_for_experimental(captured_warnings): diff --git a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py b/tests/strands/experimental/steering/context_providers/test_ledger_provider.py deleted file mode 100644 index 4356b3ea8..000000000 --- a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Unit tests for ledger context providers.""" - -from unittest.mock import Mock, patch - -from strands.experimental.steering.context_providers.ledger_provider import ( - LedgerAfterToolCall, - LedgerBeforeToolCall, - LedgerProvider, -) -from strands.experimental.steering.core.context import SteeringContext -from strands.hooks.events import AfterToolCallEvent, BeforeToolCallEvent - - -def test_context_providers_method(): - """Test context_providers method returns correct callbacks.""" - provider = LedgerProvider() - - callbacks = provider.context_providers() - - assert len(callbacks) == 2 - assert isinstance(callbacks[0], LedgerBeforeToolCall) - assert isinstance(callbacks[1], LedgerAfterToolCall) - - -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") -def test_ledger_before_tool_call_new_ledger(mock_datetime): - """Test LedgerBeforeToolCall with new ledger.""" - mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" - - callback = LedgerBeforeToolCall() - steering_context = SteeringContext() - - tool_use = {"name": "test_tool", "arguments": {"param": "value"}} - event = Mock(spec=BeforeToolCallEvent) - event.tool_use = tool_use - - callback(event, steering_context) - - ledger = steering_context.data.get("ledger") - assert ledger is not None - assert "session_start" in ledger - assert "tool_calls" in ledger - assert len(ledger["tool_calls"]) == 1 - - tool_call = ledger["tool_calls"][0] - assert tool_call["tool_name"] == "test_tool" - assert tool_call["tool_args"] == {"param": "value"} - assert tool_call["status"] == "pending" - - -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") -def test_ledger_before_tool_call_existing_ledger(mock_datetime): - """Test LedgerBeforeToolCall with existing ledger.""" - mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" - - callback = LedgerBeforeToolCall() - steering_context = SteeringContext() - - # Set up existing ledger - existing_ledger = { - "session_start": "2024-01-01T10:00:00", - "tool_calls": [{"name": "previous_tool"}], - "conversation_history": [], - "session_metadata": {}, - } - steering_context.data.set("ledger", existing_ledger) - - tool_use = {"name": "new_tool", "arguments": {"param": "value"}} - event = Mock(spec=BeforeToolCallEvent) - event.tool_use = tool_use - - callback(event, steering_context) - - ledger = steering_context.data.get("ledger") - assert len(ledger["tool_calls"]) == 2 - assert ledger["tool_calls"][0]["name"] == "previous_tool" - assert ledger["tool_calls"][1]["tool_name"] == "new_tool" - - -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") -def test_ledger_after_tool_call_success(mock_datetime): - """Test LedgerAfterToolCall with successful completion.""" - mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:05:00" - - callback = LedgerAfterToolCall() - steering_context = SteeringContext() - - # Set up existing ledger with pending call - existing_ledger = { - "tool_calls": [{"tool_name": "test_tool", "status": "pending", "timestamp": "2024-01-01T12:00:00"}] - } - steering_context.data.set("ledger", existing_ledger) - - event = Mock(spec=AfterToolCallEvent) - event.result = {"status": "success", "content": ["success_result"]} - event.exception = None - - callback(event, steering_context) - - ledger = steering_context.data.get("ledger") - tool_call = ledger["tool_calls"][0] - assert tool_call["status"] == "success" - assert tool_call["result"] == ["success_result"] - assert tool_call["error"] is None - assert tool_call["completion_timestamp"] == "2024-01-01T12:05:00" - - -def test_ledger_after_tool_call_no_calls(): - """Test LedgerAfterToolCall when no tool calls exist.""" - callback = LedgerAfterToolCall() - steering_context = SteeringContext() - - # Set up ledger with no tool calls - existing_ledger = {"tool_calls": []} - steering_context.data.set("ledger", existing_ledger) - - event = Mock(spec=AfterToolCallEvent) - event.result = {"status": "success", "content": ["test"]} - event.exception = None - - # Should not crash when no tool calls exist - callback(event, steering_context) - - ledger = steering_context.data.get("ledger") - assert ledger["tool_calls"] == [] - - -def test_session_start_persistence(): - """Test that session_start is set during initialization and persists.""" - with patch("strands.experimental.steering.context_providers.ledger_provider.datetime") as mock_datetime: - mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T10:00:00" - - callback = LedgerBeforeToolCall() - - assert callback.session_start == "2024-01-01T10:00:00" diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py deleted file mode 100644 index 8d5ef6884..000000000 --- a/tests/strands/experimental/steering/core/test_handler.py +++ /dev/null @@ -1,278 +0,0 @@ -"""Unit tests for steering handler base class.""" - -from unittest.mock import Mock - -import pytest - -from strands.experimental.steering.core.action import Guide, Interrupt, Proceed -from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider -from strands.experimental.steering.core.handler import SteeringHandler -from strands.hooks.events import BeforeToolCallEvent -from strands.hooks.registry import HookRegistry - - -class TestSteeringHandler(SteeringHandler): - """Test implementation of SteeringHandler.""" - - async def steer(self, agent, tool_use, **kwargs): - return Proceed(reason="Test proceed") - - -def test_steering_handler_initialization(): - """Test SteeringHandler initialization.""" - handler = TestSteeringHandler() - assert handler is not None - - -def test_register_hooks(): - """Test hook registration.""" - handler = TestSteeringHandler() - registry = Mock(spec=HookRegistry) - - handler.register_hooks(registry) - - # Verify hooks were registered - assert registry.add_callback.call_count >= 1 - registry.add_callback.assert_any_call(BeforeToolCallEvent, handler._provide_steering_guidance) - - -def test_steering_context_initialization(): - """Test steering context is initialized.""" - handler = TestSteeringHandler() - - assert handler.steering_context is not None - assert isinstance(handler.steering_context, SteeringContext) - - -def test_steering_context_persistence(): - """Test steering context persists across calls.""" - handler = TestSteeringHandler() - - handler.steering_context.data.set("test", "value") - assert handler.steering_context.data.get("test") == "value" - - -def test_steering_context_access(): - """Test steering context can be accessed and modified.""" - handler = TestSteeringHandler() - - handler.steering_context.data.set("key", "value") - assert handler.steering_context.data.get("key") == "value" - - -@pytest.mark.asyncio -async def test_proceed_action_flow(): - """Test complete flow with Proceed action.""" - - class ProceedHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): - return Proceed(reason="Test proceed") - - handler = ProceedHandler() - agent = Mock() - tool_use = {"name": "test_tool"} - event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - - await handler._provide_steering_guidance(event) - - # Should not modify event for Proceed - assert not event.cancel_tool - - -@pytest.mark.asyncio -async def test_guide_action_flow(): - """Test complete flow with Guide action.""" - - class GuideHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): - return Guide(reason="Test guidance") - - handler = GuideHandler() - agent = Mock() - tool_use = {"name": "test_tool"} - event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - - await handler._provide_steering_guidance(event) - - # Should set cancel_tool with guidance message - expected_message = "Tool call cancelled given new guidance. Test guidance. Consider this approach and continue" - assert event.cancel_tool == expected_message - - -@pytest.mark.asyncio -async def test_interrupt_action_approved_flow(): - """Test complete flow with Interrupt action when approved.""" - - class InterruptHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): - return Interrupt(reason="Need approval") - - handler = InterruptHandler() - tool_use = {"name": "test_tool"} - event = Mock() - event.tool_use = tool_use - event.interrupt = Mock(return_value=True) # Approved - - await handler._provide_steering_guidance(event) - - event.interrupt.assert_called_once() - - -@pytest.mark.asyncio -async def test_interrupt_action_denied_flow(): - """Test complete flow with Interrupt action when denied.""" - - class InterruptHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): - return Interrupt(reason="Need approval") - - handler = InterruptHandler() - tool_use = {"name": "test_tool"} - event = Mock() - event.tool_use = tool_use - event.interrupt = Mock(return_value=False) # Denied - - await handler._provide_steering_guidance(event) - - event.interrupt.assert_called_once() - assert event.cancel_tool.startswith("Manual approval denied:") - - -@pytest.mark.asyncio -async def test_unknown_action_flow(): - """Test complete flow with unknown action type raises error.""" - - class UnknownActionHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): - return Mock() # Not a valid SteeringAction - - handler = UnknownActionHandler() - agent = Mock() - tool_use = {"name": "test_tool"} - event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - - with pytest.raises(ValueError, match="Unknown steering action type"): - await handler._provide_steering_guidance(event) - - -def test_register_steering_hooks_override(): - """Test that _register_steering_hooks can be overridden.""" - - class CustomHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): - return Proceed(reason="Custom") - - def register_hooks(self, registry, **kwargs): - # Custom hook registration - don't call parent - pass - - handler = CustomHandler() - registry = Mock(spec=HookRegistry) - - handler.register_hooks(registry) - - # Should not register any hooks - assert registry.add_callback.call_count == 0 - - -# Integration tests with context providers -class MockContextCallback(SteeringContextCallback[BeforeToolCallEvent]): - """Mock context callback for testing.""" - - def __call__(self, event: BeforeToolCallEvent, steering_context, **kwargs) -> None: - steering_context.data.set("test_key", "test_value") - - -class MockContextProvider(SteeringContextProvider): - """Mock context provider for testing.""" - - def __init__(self, callbacks): - self.callbacks = callbacks - - def context_providers(self): - return self.callbacks - - -class TestSteeringHandlerWithProvider(SteeringHandler): - """Test implementation with context callbacks.""" - - def __init__(self, context_callbacks=None): - providers = [MockContextProvider(context_callbacks)] if context_callbacks else None - super().__init__(context_providers=providers) - - async def steer(self, agent, tool_use, **kwargs): - return Proceed(reason="Test proceed") - - -def test_handler_registers_context_provider_hooks(): - """Test that handler registers hooks from context callbacks.""" - mock_callback = MockContextCallback() - handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) - registry = Mock(spec=HookRegistry) - - handler.register_hooks(registry) - - # Should register hooks for context callback and steering guidance - assert registry.add_callback.call_count >= 2 - - # Check that BeforeToolCallEvent was registered - call_args = [call[0] for call in registry.add_callback.call_args_list] - event_types = [args[0] for args in call_args] - - assert BeforeToolCallEvent in event_types - - -def test_context_callbacks_receive_steering_context(): - """Test that context callbacks receive the handler's steering context.""" - mock_callback = MockContextCallback() - handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) - registry = Mock(spec=HookRegistry) - - handler.register_hooks(registry) - - # Get the registered callback for BeforeToolCallEvent - before_callback = None - for call in registry.add_callback.call_args_list: - if call[0][0] == BeforeToolCallEvent: - before_callback = call[0][1] - break - - assert before_callback is not None - - # Create a mock event and call the callback - event = Mock(spec=BeforeToolCallEvent) - event.tool_use = {"name": "test_tool", "arguments": {}} - - # The callback should execute without error and update the steering context - before_callback(event) - - # Verify the steering context was updated - assert handler.steering_context.data.get("test_key") == "test_value" - - -def test_multiple_context_callbacks_registered(): - """Test that multiple context callbacks are registered.""" - callback1 = MockContextCallback() - callback2 = MockContextCallback() - - handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) - registry = Mock(spec=HookRegistry) - - handler.register_hooks(registry) - - # Should register one callback for each context provider plus steering guidance - expected_calls = 2 + 1 # 2 callbacks + 1 for steering guidance - assert registry.add_callback.call_count >= expected_calls - - -def test_handler_initialization_with_callbacks(): - """Test handler initialization stores context callbacks.""" - callback1 = MockContextCallback() - callback2 = MockContextCallback() - - handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) - - # Should have stored the callbacks - assert len(handler._context_callbacks) == 2 - assert callback1 in handler._context_callbacks - assert callback2 in handler._context_callbacks diff --git a/tests/strands/experimental/steering/test_steering_aliases.py b/tests/strands/experimental/steering/test_steering_aliases.py new file mode 100644 index 000000000..25fd86eb4 --- /dev/null +++ b/tests/strands/experimental/steering/test_steering_aliases.py @@ -0,0 +1,176 @@ +"""Tests to verify that experimental steering aliases work with deprecation warning. + +This test module ensures that the experimental steering aliases maintain +backwards compatibility and can be used interchangeably with the actual +types from strands.vended_plugins.steering. +""" + +import importlib +import sys +import warnings + +import pytest + +from strands.vended_plugins.steering import ( + Guide, + Interrupt, + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, + LLMPromptMapper, + LLMSteeringHandler, + ModelSteeringAction, + Proceed, + SteeringContextCallback, + SteeringContextProvider, + SteeringHandler, + ToolSteeringAction, +) + +_ALL_NAMES = [ + "ToolSteeringAction", + "ModelSteeringAction", + "Proceed", + "Guide", + "Interrupt", + "SteeringHandler", + "SteeringContextCallback", + "SteeringContextProvider", + "LedgerBeforeToolCall", + "LedgerAfterToolCall", + "LedgerProvider", + "LLMSteeringHandler", + "LLMPromptMapper", +] + +_PRODUCTION_TYPES = { + "ToolSteeringAction": ToolSteeringAction, + "ModelSteeringAction": ModelSteeringAction, + "Proceed": Proceed, + "Guide": Guide, + "Interrupt": Interrupt, + "SteeringHandler": SteeringHandler, + "SteeringContextCallback": SteeringContextCallback, + "SteeringContextProvider": SteeringContextProvider, + "LedgerBeforeToolCall": LedgerBeforeToolCall, + "LedgerAfterToolCall": LedgerAfterToolCall, + "LedgerProvider": LedgerProvider, + "LLMSteeringHandler": LLMSteeringHandler, + "LLMPromptMapper": LLMPromptMapper, +} + + +@pytest.mark.parametrize("name", _ALL_NAMES) +def test_experimental_alias_is_same_type(name): + """Verify that experimental steering alias is identical to the actual type.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from strands.experimental import steering + + experimental_type = getattr(steering, name) + + assert experimental_type is _PRODUCTION_TYPES[name] + + +@pytest.mark.parametrize("name", _ALL_NAMES) +def test_deprecation_warning_on_access(name, captured_warnings): + """Verify that accessing deprecated aliases emits deprecation warning.""" + # Clear the module from cache to trigger fresh import + if "strands.experimental.steering" in sys.modules: + del sys.modules["strands.experimental.steering"] + + # Clear any existing warnings + captured_warnings.clear() + + # Access from experimental - this should trigger the warning + from strands.experimental import steering + + _ = getattr(steering, name) + + assert len(captured_warnings) >= 1 + warning = captured_warnings[0] + assert issubclass(warning.category, DeprecationWarning) + assert name in str(warning.message) + assert "strands.vended_plugins.steering" in str(warning.message) + + +def test_attribute_error_on_unknown_attribute(): + """Verify that accessing unknown attributes raises AttributeError.""" + import strands.experimental.steering as steering_module + + with pytest.raises(AttributeError, match="has no attribute"): + _ = steering_module.NonExistentClass + + +def test_no_warning_on_production_import(captured_warnings): + """Verify that importing from strands.vended_plugins.steering does not emit deprecation warning.""" + # Clear any existing warnings + captured_warnings.clear() + + # Import from production - should NOT trigger warning + from strands.vended_plugins.steering import Proceed as _ # noqa: F401 + + # Filter for steering-related deprecation warnings + steering_warnings = [ + w + for w in captured_warnings + if "has been moved" in str(w.message) and issubclass(w.category, DeprecationWarning) + ] + + assert len(steering_warnings) == 0 + + +# Submodule import tests - verify deep import paths still work with deprecation warnings + +_SUBMODULE_IMPORTS = [ + ("strands.experimental.steering.core.action", "Guide", Guide), + ("strands.experimental.steering.core.action", "Interrupt", Interrupt), + ("strands.experimental.steering.core.action", "Proceed", Proceed), + ("strands.experimental.steering.core.context", "SteeringContextCallback", SteeringContextCallback), + ("strands.experimental.steering.core.context", "SteeringContextProvider", SteeringContextProvider), + ("strands.experimental.steering.core.handler", "SteeringHandler", SteeringHandler), + ("strands.experimental.steering.context_providers.ledger_provider", "LedgerProvider", LedgerProvider), + ("strands.experimental.steering.context_providers.ledger_provider", "LedgerBeforeToolCall", LedgerBeforeToolCall), + ("strands.experimental.steering.context_providers.ledger_provider", "LedgerAfterToolCall", LedgerAfterToolCall), + ("strands.experimental.steering.handlers.llm.llm_handler", "LLMSteeringHandler", LLMSteeringHandler), + ("strands.experimental.steering.handlers.llm.mappers", "DefaultPromptMapper", None), +] + + +@pytest.mark.parametrize( + "module_path,attr_name,expected_type", + _SUBMODULE_IMPORTS, + ids=[f"{m}.{a}" for m, a, _ in _SUBMODULE_IMPORTS], +) +def test_submodule_import_resolves_correctly(module_path, attr_name, expected_type): + """Verify that submodule imports resolve to the correct production types.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + mod = importlib.import_module(module_path) + obj = getattr(mod, attr_name) + + if expected_type is not None: + assert obj is expected_type + + +@pytest.mark.parametrize( + "module_path,attr_name,expected_type", + _SUBMODULE_IMPORTS, + ids=[f"{m}.{a}" for m, a, _ in _SUBMODULE_IMPORTS], +) +def test_submodule_import_emits_deprecation_warning(module_path, attr_name, expected_type, captured_warnings): + """Verify that submodule imports emit deprecation warnings.""" + # Clear module from cache to trigger fresh import + if module_path in sys.modules: + del sys.modules[module_path] + + captured_warnings.clear() + + mod = importlib.import_module(module_path) + _ = getattr(mod, attr_name) + + assert len(captured_warnings) >= 1 + warning = captured_warnings[0] + assert issubclass(warning.category, DeprecationWarning) + assert attr_name in str(warning.message) + assert "has been moved to production" in str(warning.message) diff --git a/tests/strands/experimental/tools/test_tool_provider_alias.py b/tests/strands/experimental/tools/test_tool_provider_alias.py new file mode 100644 index 000000000..3b3055bc6 --- /dev/null +++ b/tests/strands/experimental/tools/test_tool_provider_alias.py @@ -0,0 +1,87 @@ +"""Tests to verify that experimental ToolProvider alias works with deprecation warning. + +This test module ensures that the experimental ToolProvider alias maintains +backwards compatibility and can be used interchangeably with the actual +ToolProvider type from strands.tools. +""" + +import sys +import warnings + +import pytest + +from strands.tools import ToolProvider + + +def test_experimental_alias_is_same_type(): + """Verify that experimental ToolProvider alias is identical to the actual type.""" + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from strands.experimental.tools import ToolProvider as ExperimentalToolProvider + + assert ExperimentalToolProvider is ToolProvider + + +def test_deprecation_warning_on_import(captured_warnings): + """Verify that importing ToolProvider from experimental emits deprecation warning.""" + # Clear the module from cache to trigger fresh import + if "strands.experimental.tools" in sys.modules: + del sys.modules["strands.experimental.tools"] + + # Clear any existing warnings + captured_warnings.clear() + + # Import from experimental - this should trigger the warning + from strands.experimental import tools + + _ = tools.ToolProvider + + assert len(captured_warnings) >= 1 + warning = captured_warnings[0] + assert issubclass(warning.category, DeprecationWarning) + assert "ToolProvider" in str(warning.message) + assert "strands.tools" in str(warning.message) + + +def test_deprecation_warning_on_direct_import(captured_warnings): + """Verify that direct import from experimental.tools emits deprecation warning.""" + # Clear the module from cache to trigger fresh import + if "strands.experimental.tools" in sys.modules: + del sys.modules["strands.experimental.tools"] + + # Clear any existing warnings + captured_warnings.clear() + + # Direct import - this should trigger the warning + from strands.experimental.tools import ToolProvider as _ # noqa: F401 + + assert len(captured_warnings) >= 1 + warning = captured_warnings[0] + assert issubclass(warning.category, DeprecationWarning) + assert "ToolProvider" in str(warning.message) + assert "strands.tools" in str(warning.message) + + +def test_attribute_error_on_unknown_attribute(): + """Verify that accessing unknown attributes raises AttributeError.""" + import strands.experimental.tools as tools_module + + with pytest.raises(AttributeError, match="has no attribute"): + _ = tools_module.NonExistentClass + + +def test_no_warning_on_production_import(captured_warnings): + """Verify that importing from strands.tools does not emit deprecation warning.""" + # Clear any existing warnings + captured_warnings.clear() + + # Import from production - should NOT trigger warning + from strands.tools import ToolProvider as _ # noqa: F401 + + # Filter for ToolProvider-related deprecation warnings + tool_provider_warnings = [ + w for w in captured_warnings if "ToolProvider" in str(w.message) and issubclass(w.category, DeprecationWarning) + ] + + assert len(tool_provider_warnings) == 0 diff --git a/tests/strands/handlers/test_callback_handler.py b/tests/strands/handlers/test_callback_handler.py index 224823ef7..0d72c8563 100644 --- a/tests/strands/handlers/test_callback_handler.py +++ b/tests/strands/handlers/test_callback_handler.py @@ -72,56 +72,21 @@ def test_call_with_data_complete(handler, mock_print): mock_print.assert_any_call("\n") -def test_call_with_current_tool_use_new(handler, mock_print): - """Test calling the handler with a new tool use.""" - current_tool_use = {"name": "test_tool", "input": {"param": "value"}} - - handler(current_tool_use=current_tool_use) - - # Should print tool information - mock_print.assert_called_once_with("\nTool #1: test_tool") - - # Should update the handler state - assert handler.tool_count == 1 - assert handler.previous_tool_use == current_tool_use - - -def test_call_with_current_tool_use_same(handler, mock_print): - """Test calling the handler with the same tool use twice.""" - current_tool_use = {"name": "test_tool", "input": {"param": "value"}} - - # First call - handler(current_tool_use=current_tool_use) - mock_print.reset_mock() - - # Second call with same tool use - handler(current_tool_use=current_tool_use) - - # Should not print tool information again - mock_print.assert_not_called() - - # Tool count should not increase - assert handler.tool_count == 1 - - -def test_call_with_current_tool_use_different(handler, mock_print): +def test_call_with_tool_uses(handler, mock_print): """Test calling the handler with different tool uses.""" - first_tool_use = {"name": "first_tool", "input": {"param": "value1"}} - second_tool_use = {"name": "second_tool", "input": {"param": "value2"}} - - # First call - handler(current_tool_use=first_tool_use) - mock_print.reset_mock() + first_event = {"contentBlockStart": {"start": {"toolUse": {"name": "first_tool"}}}} + second_event = {"contentBlockStart": {"start": {"toolUse": {"name": "second_tool"}}}} - # Second call with different tool use - handler(current_tool_use=second_tool_use) + handler(event=first_event) + handler(event=second_event) - # Should print info for the new tool - mock_print.assert_called_once_with("\nTool #2: second_tool") + assert mock_print.call_args_list == [ + unittest.mock.call("\nTool #1: first_tool"), + unittest.mock.call("\nTool #2: second_tool"), + ] # Tool count should increase assert handler.tool_count == 2 - assert handler.previous_tool_use == second_tool_use def test_call_with_data_and_complete_extra_newline(handler, mock_print): @@ -146,42 +111,30 @@ def test_call_with_message_no_effect(handler, mock_print): def test_call_with_multiple_parameters(handler, mock_print): """Test calling handler with multiple parameters.""" - current_tool_use = {"name": "test_tool", "input": {"param": "value"}} + event = {"contentBlockStart": {"start": {"toolUse": {"name": "test_tool"}}}} - handler(data="Test output", complete=True, current_tool_use=current_tool_use) + handler(data="Test output", complete=True, event=event) - # Should print data with newline, an extra newline for completion, and tool information - assert mock_print.call_count == 3 - mock_print.assert_any_call("Test output", end="\n") - mock_print.assert_any_call("\n") - mock_print.assert_any_call("\nTool #1: test_tool") - - -def test_unknown_tool_name_handling(handler, mock_print): - """Test handling of a tool use without a name.""" - # The SDK implementation doesn't have a fallback for tool uses without a name field - # It checks for both presence of current_tool_use and current_tool_use.get("name") - current_tool_use = {"input": {"param": "value"}, "name": "Unknown tool"} - - handler(current_tool_use=current_tool_use) - - # Should print the tool information - mock_print.assert_called_once_with("\nTool #1: Unknown tool") + # Should print data with newline, tool information, and an extra newline for completion + assert mock_print.call_args_list == [ + unittest.mock.call("Test output", end="\n"), + unittest.mock.call("\nTool #1: test_tool"), + unittest.mock.call("\n"), + ] def test_tool_use_empty_object(handler, mock_print): - """Test handling of an empty tool use object.""" + """Test handling of an empty tool use object in event.""" # Tool use is an empty dict - current_tool_use = {} + event = {"contentBlockStart": {"start": {"toolUse": {}}}} - handler(current_tool_use=current_tool_use) + handler(event=event) # Should not print anything mock_print.assert_not_called() # Should not update state assert handler.tool_count == 0 - assert handler.previous_tool_use is None def test_composite_handler_forwards_to_all_handlers(): @@ -193,7 +146,7 @@ def test_composite_handler_forwards_to_all_handlers(): kwargs = { "data": "Test output", "complete": True, - "current_tool_use": {"name": "test_tool", "input": {"param": "value"}}, + "event": {"contentBlockStart": {"start": {"toolUse": {"name": "test_tool"}}}}, } # Call the composite handler @@ -215,12 +168,11 @@ def test_verbose_tool_use_disabled(mock_print): handler = PrintingCallbackHandler(verbose_tool_use=False) assert handler._verbose_tool_use is False - current_tool_use = {"name": "test_tool", "input": {"param": "value"}} - handler(current_tool_use=current_tool_use) + event = {"contentBlockStart": {"start": {"toolUse": {"name": "test_tool"}}}} + handler(event=event) # Should not print tool information when verbose_tool_use is False mock_print.assert_not_called() - # Should still update tool count and previous_tool_use + # Should still update tool count assert handler.tool_count == 1 - assert handler.previous_tool_use == current_tool_use diff --git a/tests/strands/experimental/hooks/multiagent/test_events.py b/tests/strands/hooks/test_events.py similarity index 97% rename from tests/strands/experimental/hooks/multiagent/test_events.py rename to tests/strands/hooks/test_events.py index 6c4d7c4e7..90ab205a9 100644 --- a/tests/strands/experimental/hooks/multiagent/test_events.py +++ b/tests/strands/hooks/test_events.py @@ -4,14 +4,14 @@ import pytest -from strands.experimental.hooks.multiagent.events import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + BaseHookEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from strands.hooks import BaseHookEvent @pytest.fixture diff --git a/tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py b/tests/strands/hooks/test_multi_agent_hooks.py similarity index 98% rename from tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py rename to tests/strands/hooks/test_multi_agent_hooks.py index 4e97a9217..3f6e0c940 100644 --- a/tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py +++ b/tests/strands/hooks/test_multi_agent_hooks.py @@ -1,7 +1,7 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent.events import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 3daf41734..5b0f3c574 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -1,8 +1,16 @@ import unittest.mock +from typing import Union import pytest -from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry +from strands.hooks import ( + AfterModelCallEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + HookRegistry, +) from strands.interrupt import Interrupt, _InterruptState @@ -87,3 +95,217 @@ def test_hook_registry_invoke_callbacks_coroutine(registry, agent): with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"): registry.invoke_callbacks(BeforeInvocationEvent(agent=agent)) + + +def test_hook_registry_add_callback_infers_event_type(registry): + """Test that add_callback infers event type from callback type hint.""" + + def typed_callback(event: BeforeInvocationEvent) -> None: + pass + + # Register without explicit event_type - should infer from type hint + registry.add_callback(None, typed_callback) + + # Verify callback was registered + assert BeforeInvocationEvent in registry._registered_callbacks + assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_raises_error_no_type_hint(registry): + """Test that add_callback raises error when type hint is missing.""" + + def untyped_callback(event): + pass + + with pytest.raises(ValueError, match="cannot infer event type"): + registry.add_callback(None, untyped_callback) + + +def test_hook_registry_add_callback_raises_error_invalid_type_hint(registry): + """Test that add_callback raises error when type hint is not a BaseHookEvent subclass.""" + + def invalid_callback(event: str) -> None: + pass + + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + registry.add_callback(None, invalid_callback) + + +def test_hook_registry_add_callback_raises_error_no_parameters(registry): + """Test that add_callback raises error when callback has no parameters.""" + + def no_param_callback() -> None: + pass + + with pytest.raises(ValueError, match="callback has no parameters"): + registry.add_callback(None, no_param_callback) + + +def test_hook_registry_add_callback_infers_event_type_when_callback_provided_without_event_type(registry): + """Test that add_callback infers event type when callback is provided but event_type is None.""" + + def typed_callback(event: BeforeInvocationEvent) -> None: + pass + + registry.add_callback(None, typed_callback) + + assert BeforeInvocationEvent in registry._registered_callbacks + assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_with_explicit_event_type_and_callback(registry): + """Test that add_callback works with explicit event_type and callback.""" + + def callback(event: BeforeInvocationEvent) -> None: + pass + + registry.add_callback(BeforeInvocationEvent, callback) + + assert BeforeInvocationEvent in registry._registered_callbacks + assert callback in registry._registered_callbacks[BeforeInvocationEvent] + + +# ========== Tests for union type support ========== + + +def test_hook_registry_add_callback_infers_union_types_pipe_syntax(registry): + """Test that add_callback registers callback for each type in A | B union.""" + + def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + pass + + registry.add_callback(None, union_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert union_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_infers_union_types_union_syntax(registry): + """Test that add_callback registers callback for each type in Union[A, B].""" + + def union_callback(event: Union[BeforeModelCallEvent, AfterModelCallEvent]) -> None: # noqa: UP007 + pass + + registry.add_callback(None, union_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert union_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_union_with_none_raises_error(registry): + """Test that add_callback raises error when union contains None.""" + + def callback_with_none(event: BeforeModelCallEvent | None) -> None: + pass + + with pytest.raises(ValueError, match="None is not a valid event type"): + registry.add_callback(None, callback_with_none) + + +def test_hook_registry_add_callback_union_with_invalid_type_raises_error(registry): + """Test that add_callback raises error when union contains non-BaseHookEvent type.""" + + def callback_with_invalid_type(event: BeforeModelCallEvent | str) -> None: + pass + + with pytest.raises(ValueError, match="Invalid type in union"): + registry.add_callback(None, callback_with_invalid_type) + + +def test_hook_registry_add_callback_union_multiple_types(registry): + """Test that add_callback handles union with more than two types.""" + + def multi_union_callback(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent) -> None: + pass + + registry.add_callback(None, multi_union_callback) + + # Callback should be registered for all three event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert BeforeInvocationEvent in registry._registered_callbacks + assert multi_union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert multi_union_callback in registry._registered_callbacks[AfterModelCallEvent] + assert multi_union_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +# ========== Tests for list of types support ========== + + +def test_hook_registry_add_callback_with_list_of_types(registry): + """Test that add_callback registers callback for each type in a list.""" + + def my_callback(event) -> None: + pass + + registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], my_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert my_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert my_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_with_list_deduplicates(registry): + """Test that add_callback deduplicates event types in a list.""" + + def my_callback(event) -> None: + pass + + # Same type appears multiple times + registry.add_callback([BeforeModelCallEvent, BeforeModelCallEvent, AfterModelCallEvent], my_callback) + + # Callback should be registered only once per event type + assert len(registry._registered_callbacks[BeforeModelCallEvent]) == 1 + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + +def test_hook_registry_add_callback_with_list_validates_types(registry): + """Test that add_callback validates all types in a list are BaseHookEvent subclasses.""" + + def my_callback(event) -> None: + pass + + with pytest.raises(ValueError, match="Invalid event type"): + registry.add_callback([BeforeModelCallEvent, str], my_callback) + + +def test_hook_registry_add_callback_with_empty_list_raises_error(registry): + """Test that add_callback raises error when given an empty list.""" + + def my_callback(event) -> None: + pass + + with pytest.raises(ValueError, match="event_type list cannot be empty"): + registry.add_callback([], my_callback) + + +@pytest.mark.asyncio +async def test_hook_registry_union_callback_invoked_for_each_type(registry, agent): + """Test that a union-registered callback is invoked correctly for each event type.""" + call_count = {"before": 0, "after": 0} + + def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + if isinstance(event, BeforeModelCallEvent): + call_count["before"] += 1 + elif isinstance(event, AfterModelCallEvent): + call_count["after"] += 1 + + registry.add_callback(None, union_callback) + + # Invoke BeforeModelCallEvent + before_event = BeforeModelCallEvent(agent=agent) + await registry.invoke_callbacks_async(before_event) + assert call_count["before"] == 1 + + # Invoke AfterModelCallEvent + after_event = AfterModelCallEvent(agent=agent) + await registry.invoke_callbacks_async(after_event) + assert call_count["after"] == 1 diff --git a/tests/strands/models/conftest.py b/tests/strands/models/conftest.py new file mode 100644 index 000000000..aaf01a047 --- /dev/null +++ b/tests/strands/models/conftest.py @@ -0,0 +1,25 @@ +"""Pytest configuration for model tests.""" + +import sys +import unittest.mock + +# Mock OpenAI version check before the openai_responses module is imported. +# This is necessary because the version check happens at module import time. +# We patch importlib.metadata.version directly since that's where get_package_version comes from. +if "strands.models.openai_responses" not in sys.modules: + _original_version = None + try: + from importlib.metadata import version as _original_version_func + + _original_version = _original_version_func + except ImportError: + pass + + def _mock_version(package_name: str) -> str: + if package_name == "openai": + return "2.0.0" + if _original_version: + return _original_version(package_name) + raise Exception(f"Package {package_name} not found") + + unittest.mock.patch("importlib.metadata.version", _mock_version).start() diff --git a/tests/strands/models/test__validation.py b/tests/strands/models/test__validation.py new file mode 100644 index 000000000..e8a451494 --- /dev/null +++ b/tests/strands/models/test__validation.py @@ -0,0 +1,67 @@ +"""Tests for model validation helper functions.""" + +from strands.models._validation import _has_location_source + + +class TestHasLocationSource: + """Tests for _has_location_source helper function.""" + + def test_image_with_location_source(self): + """Test detection of location source in image content.""" + content = {"image": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}} + assert _has_location_source(content) + + def test_image_with_bytes_source(self): + """Test that bytes source is not detected as location.""" + content = {"image": {"source": {"bytes": b"data"}}} + assert not _has_location_source(content) + + def test_document_with_location_source(self): + """Test detection of location source in document content.""" + content = {"document": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}} + assert _has_location_source(content) + + def test_document_with_bytes_source(self): + """Test that bytes source is not detected as location.""" + content = {"document": {"source": {"bytes": b"data"}}} + assert not _has_location_source(content) + + def test_video_with_location_source(self): + """Test detection of location source in video content.""" + content = {"video": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}} + assert _has_location_source(content) + + def test_video_with_bytes_source(self): + """Test that bytes source is not detected as location.""" + content = {"video": {"source": {"bytes": b"data"}}} + assert not _has_location_source(content) + + def test_text_content(self): + """Test that text content is not detected as location source.""" + content = {"text": "hello"} + assert not _has_location_source(content) + + def test_tool_use_content(self): + """Test that toolUse content is not detected as location source.""" + content = {"toolUse": {"name": "test", "input": {}, "toolUseId": "123"}} + assert not _has_location_source(content) + + def test_tool_result_content(self): + """Test that toolResult content is not detected as location source.""" + content = {"toolResult": {"toolUseId": "123", "content": [{"text": "result"}]}} + assert not _has_location_source(content) + + def test_image_without_source(self): + """Test that image without source is not detected as location.""" + content = {"image": {"format": "png"}} + assert not _has_location_source(content) + + def test_document_without_source(self): + """Test that document without source is not detected as location.""" + content = {"document": {"format": "pdf", "name": "test.pdf"}} + assert not _has_location_source(content) + + def test_video_without_source(self): + """Test that video without source is not detected as location.""" + content = {"video": {"format": "mp4"}} + assert not _has_location_source(content) diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 74bbb8d45..0ebdb161c 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1,4 +1,6 @@ +import logging import unittest.mock +import warnings import anthropic import pydantic @@ -51,6 +53,24 @@ class TestOutputModel(pydantic.BaseModel): return TestOutputModel +def generate_mock_stream_context(events, final_message=None): + mock_stream = unittest.mock.AsyncMock() + + async def mock_aiter(self): + for event in events: + yield event + + mock_stream.__aiter__ = mock_aiter + if isinstance(final_message, Exception): + mock_stream.get_final_message.side_effect = final_message + elif final_message: + mock_stream.get_final_message.return_value = final_message + + mock_context = unittest.mock.AsyncMock() + mock_context.__aenter__.return_value = mock_stream + return mock_context + + def test__init__model_configs(anthropic_client, model_id, max_tokens): _ = anthropic_client @@ -62,6 +82,30 @@ def test__init__model_configs(anthropic_client, model_id, max_tokens): assert tru_temperature == exp_temperature +def test__init__auto_populates_context_window_limit(anthropic_client): + _ = anthropic_client + + model = AnthropicModel(model_id="claude-sonnet-4-20250514", max_tokens=1) + + assert model.get_config().get("context_window_limit") == 1_000_000 + + +def test__init__explicit_context_window_limit_not_overridden(anthropic_client): + _ = anthropic_client + + model = AnthropicModel(model_id="claude-sonnet-4-20250514", max_tokens=1, context_window_limit=100_000) + + assert model.get_config().get("context_window_limit") == 100_000 + + +def test__init__unknown_model_no_context_window_limit(anthropic_client): + _ = anthropic_client + + model = AnthropicModel(model_id="unknown-model", max_tokens=1) + + assert model.get_config().get("context_window_limit") is None + + def test_update_config(model, model_id): model.update_config(model_id=model_id) @@ -691,7 +735,7 @@ def test_format_chunk_unknown(model): @pytest.mark.asyncio -async def test_stream(anthropic_client, model, agenerator, alist): +async def test_stream(anthropic_client, model, alist): mock_event_1 = unittest.mock.Mock( type="message_start", dict=lambda: {"type": "message_start"}, @@ -712,9 +756,14 @@ async def test_stream(anthropic_client, model, agenerator, alist): ), ) - mock_context = unittest.mock.AsyncMock() - mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3]) - anthropic_client.messages.stream.return_value = mock_context + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + [mock_event_1, mock_event_2, mock_event_3], + final_message=unittest.mock.Mock( + usage=unittest.mock.Mock( + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), + ) messages = [{"role": "user", "content": [{"text": "hello"}]}] response = model.stream(messages, None, None) @@ -737,6 +786,42 @@ async def test_stream(anthropic_client, model, agenerator, alist): anthropic_client.messages.stream.assert_called_once_with(**expected_request) +@pytest.mark.asyncio +async def test_stream_early_termination(anthropic_client, model, alist, caplog): + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + mock_event = unittest.mock.Mock( + type="message_start", + model_dump=lambda: {"type": "message_start"}, + ) + + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + [mock_event], + final_message=AssertionError("message snapshot is not available"), + ) + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + tru_events = await alist(model.stream(messages, None, None)) + + assert len(tru_events) == 1 + assert "messageStart" in tru_events[0] + assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text + + +@pytest.mark.asyncio +async def test_stream_empty(anthropic_client, model, alist, caplog): + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + [], + final_message=AssertionError("message snapshot is not available"), + ) + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + tru_events = await alist(model.stream(messages, None, None)) + + assert tru_events == [] + assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text + + @pytest.mark.asyncio async def test_stream_rate_limit_error(anthropic_client, model, alist): anthropic_client.messages.stream.side_effect = anthropic.RateLimitError( @@ -779,7 +864,7 @@ async def test_stream_bad_request_error(anthropic_client, model): @pytest.mark.asyncio -async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist): +async def test_structured_output(anthropic_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] events = [ @@ -810,22 +895,21 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls, ), unittest.mock.Mock( type="message_stop", + message=unittest.mock.Mock(stop_reason="tool_use"), model_dump=unittest.mock.Mock( return_value={"type": "message_stop", "message": {"stop_reason": "tool_use"}} ), ), - unittest.mock.Mock( - message=unittest.mock.Mock( - usage=unittest.mock.Mock( - model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0}) - ), - ), - ), ] - mock_context = unittest.mock.AsyncMock() - mock_context.__aenter__.return_value = agenerator(events) - anthropic_client.messages.stream.return_value = mock_context + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + events, + final_message=unittest.mock.Mock( + usage=unittest.mock.Mock( + model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0}) + ), + ), + ) stream = model.structured_output(test_output_model_cls, messages) events = await alist(stream) @@ -866,3 +950,239 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): model.format_request(messages, tool_choice=None) assert len(captured_warnings) == 0 + + +def test_format_request_filters_s3_source_image(model, model_id, max_tokens, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + exp_messages = [ + {"role": "user", "content": [{"type": "text", "text": "look at this image"}]}, + ] + assert tru_request["messages"] == exp_messages + assert "Location sources are not supported by Anthropic" in caplog.text + + +def test_format_request_filters_location_source_document(model, model_id, max_tokens, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + exp_messages = [ + {"role": "user", "content": [{"type": "text", "text": "analyze this document"}]}, + ] + assert tru_request["messages"] == exp_messages + assert "Location sources are not supported by Anthropic" in caplog.text + + +@pytest.mark.asyncio +async def test_stream_message_stop_no_pydantic_warnings(anthropic_client, model, alist): + """Verify no Pydantic serialization warnings are emitted for message_stop events. + + Regression test for https://github.com/strands-agents/sdk-python/issues/1746. + """ + # Create a mock message_stop event where model_dump() would emit warnings + # The key is that the event has a .message attribute with .stop_reason + mock_message_stop = unittest.mock.Mock() + mock_message_stop.type = "message_stop" + mock_message_stop.message = unittest.mock.Mock() + mock_message_stop.message.stop_reason = "end_turn" + + # Make model_dump() emit a warning to simulate the problematic behavior + def model_dump_with_warning(): + warnings.warn( + "PydanticSerializationUnexpectedValue(Expected `ParsedTextBlock[TypeVar]`)", + UserWarning, + stacklevel=2, + ) + return {"type": mock_message_stop.type, "message": {"stop_reason": mock_message_stop.message.stop_reason}} + + mock_message_stop.model_dump = model_dump_with_warning + + final_message = unittest.mock.Mock() + final_message.usage = unittest.mock.Mock( + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + + mock_context = generate_mock_stream_context([mock_message_stop], final_message=final_message) + anthropic_client.messages.stream.return_value = mock_context + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + # Capture warnings during streaming + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + response = model.stream(messages, None, None) + events = await alist(response) + + # Verify no Pydantic serialization warnings were emitted + pydantic_warnings = [w for w in caught_warnings if "PydanticSerializationUnexpectedValue" in str(w.message)] + assert len(pydantic_warnings) == 0, f"Unexpected Pydantic warnings: {pydantic_warnings}" + + # Verify the message_stop event was still processed correctly + assert {"messageStop": {"stopReason": mock_message_stop.message.stop_reason}} in events + + +class TestCountTokens: + """Tests for AnthropicModel.count_tokens native token counting.""" + + @pytest.fixture + def model_with_client(self, anthropic_client, model_id, max_tokens): + _ = anthropic_client + return AnthropicModel(model_id=model_id, max_tokens=max_tokens, use_native_token_count=True) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model_with_client, anthropic_client, messages): + mock_response = unittest.mock.MagicMock() + mock_response.input_tokens = 42 + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(return_value=mock_response) + + result = await model_with_client.count_tokens(messages=messages) + + assert result == 42 + anthropic_client.messages.count_tokens.assert_called_once() + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model_with_client, anthropic_client, messages): + mock_response = unittest.mock.MagicMock() + mock_response.input_tokens = 55 + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(return_value=mock_response) + + result = await model_with_client.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result == 55 + call_kwargs = anthropic_client.messages.count_tokens.call_args[1] + assert call_kwargs["system"] == "Be helpful." + + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model_with_client, anthropic_client, messages, tool_specs): + mock_response = unittest.mock.MagicMock() + mock_response.input_tokens = 100 + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(return_value=mock_response) + + result = await model_with_client.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result == 100 + call_kwargs = anthropic_client.messages.count_tokens.call_args[1] + assert "tools" in call_kwargs + + @pytest.mark.asyncio + async def test_max_tokens_stripped_from_request(self, model_with_client, anthropic_client, messages): + mock_response = unittest.mock.MagicMock() + mock_response.input_tokens = 10 + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(return_value=mock_response) + + await model_with_client.count_tokens(messages=messages) + + call_kwargs = anthropic_client.messages.count_tokens.call_args[1] + assert "max_tokens" not in call_kwargs + + @pytest.mark.asyncio + async def test_fallback_on_api_error(self, model_with_client, anthropic_client, messages): + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock( + side_effect=anthropic.APIError(message="Unsupported", request=unittest.mock.MagicMock(), body=None) + ) + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_generic_exception(self, model_with_client, anthropic_client, messages): + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(side_effect=RuntimeError("Connection failed")) + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_debug(self, model_with_client, anthropic_client, messages, caplog): + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(side_effect=RuntimeError("API down")) + + with caplog.at_level(logging.DEBUG, logger="strands.models.anthropic"): + await model_with_client.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false( + self, anthropic_client, model_id, max_tokens, messages + ): + _ = anthropic_client + model = AnthropicModel(model_id=model_id, max_tokens=max_tokens, use_native_token_count=False) + + result = await model.count_tokens(messages=messages) + + anthropic_client.messages.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, anthropic_client, model_id, max_tokens, messages): + _ = anthropic_client + model = AnthropicModel(model_id=model_id, max_tokens=max_tokens) + + result = await model.count_tokens(messages=messages) + + anthropic_client.messages.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2809e8a72..319b5574f 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,3 +1,5 @@ +import copy +import logging import os import sys import traceback @@ -12,24 +14,26 @@ import strands from strands import _exception_notes -from strands.models import BedrockModel +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.models.bedrock import ( - _DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, + _clear_skip_count_tokens_cache, ) -from strands.types.exceptions import ModelThrottledException +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec -FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") +FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID @pytest.fixture def session_cls(): # Mock the creation of a Session so that we don't depend on environment variables or profiles with unittest.mock.patch.object(strands.models.bedrock.boto3, "Session") as mock_session_cls: - mock_session_cls.return_value.region_name = None + mock_session = unittest.mock.Mock() + mock_session.region_name = None + mock_session_cls.return_value = mock_session yield mock_session_cls @@ -201,10 +205,11 @@ def test__init__region_precedence(mock_client_method, session_cls): def test__init__with_endpoint_url(mock_client_method): """Test that BedrockModel uses the provided endpoint_url for VPC endpoints.""" custom_endpoint = "https://vpce-12345-abcde.bedrock-runtime.us-west-2.vpce.amazonaws.com" - BedrockModel(endpoint_url=custom_endpoint) - mock_client_method.assert_called_with( - region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint - ) + with unittest.mock.patch.object(os, "environ", {}): + BedrockModel(endpoint_url=custom_endpoint) + mock_client_method.assert_called_with( + region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint + ) def test__init__with_region_and_session_raises_value_error(): @@ -213,66 +218,63 @@ def test__init__with_region_and_session_raises_value_error(): _ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1")) -def test__init__default_user_agent(bedrock_client): +def test__init__default_user_agent(session_cls, bedrock_client): """Set user agent when no boto_client_config is provided.""" - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel() + _ = BedrockModel() - # Verify the client was created with the correct config - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert kwargs["service_name"] == "bedrock-runtime" - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].user_agent_extra == "strands-agents" - assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT + # Verify the client was created with the correct config + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT -def test__init__default_read_timeout(bedrock_client): +def test__init__default_read_timeout(session_cls, bedrock_client): """Set default read timeout when no boto_client_config is provided.""" - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel() - # Verify the client was created with the correct read timeout - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT + _ = BedrockModel() + # Verify the client was created with the correct read timeout + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT -def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client): + +def test__init__with_custom_boto_client_config_no_user_agent(session_cls, bedrock_client): """Set user agent when boto_client_config is provided without user_agent_extra.""" custom_config = BotocoreConfig(read_timeout=900) - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel(boto_client_config=custom_config) + _ = BedrockModel(boto_client_config=custom_config) - # Verify the client was created with the correct config - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert kwargs["service_name"] == "bedrock-runtime" - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].user_agent_extra == "strands-agents" - assert kwargs["config"].read_timeout == 900 + # Verify the client was created with the correct config + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == 900 -def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client): +def test__init__with_custom_boto_client_config_with_user_agent(session_cls, bedrock_client): """Append to existing user agent when boto_client_config is provided with user_agent_extra.""" custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900) - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel(boto_client_config=custom_config) + _ = BedrockModel(boto_client_config=custom_config) - # Verify the client was created with the correct config - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert kwargs["service_name"] == "bedrock-runtime" - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].user_agent_extra == "existing-agent strands-agents" - assert kwargs["config"].read_timeout == 900 + # Verify the client was created with the correct config + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "existing-agent strands-agents" + assert kwargs["config"].read_timeout == 900 def test__init__model_config(bedrock_client): @@ -286,6 +288,55 @@ def test__init__model_config(bedrock_client): assert tru_max_tokens == exp_max_tokens +def test__init__context_window_limit(bedrock_client): + _ = bedrock_client + + model = BedrockModel(context_window_limit=200_000) + + assert model.get_config().get("context_window_limit") == 200_000 + assert model.context_window_limit == 200_000 + + +def test__init__auto_populates_context_window_limit(bedrock_client): + _ = bedrock_client + + model = BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0") + + assert model.get_config().get("context_window_limit") == 1_000_000 + + +def test__init__auto_populates_context_window_limit_cross_region(bedrock_client): + _ = bedrock_client + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-6") + + assert model.get_config().get("context_window_limit") == 1_000_000 + + +def test__init__auto_populates_context_window_limit_default_model(bedrock_client): + _ = bedrock_client + + model = BedrockModel() + + assert model.get_config().get("context_window_limit") == 1_000_000 + + +def test__init__explicit_context_window_limit_not_overridden(bedrock_client): + _ = bedrock_client + + model = BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0", context_window_limit=100_000) + + assert model.get_config().get("context_window_limit") == 100_000 + + +def test__init__unknown_model_no_context_window_limit(bedrock_client): + _ = bedrock_client + + model = BedrockModel(model_id="unknown.model-v1:0") + + assert model.get_config().get("context_window_limit") is None + + def test_update_config(model, model_id): model.update_config(model_id=model_id) @@ -377,6 +428,20 @@ def test_format_request_guardrail_config_without_trace_or_stream_processing_mode assert tru_request == exp_request +def test_format_request_with_service_tier(model, messages, model_id): + model.update_config(service_tier="flex") + tru_request = model._format_request(messages) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "serviceTier": {"type": "flex"}, + "system": [], + } + + assert tru_request == exp_request + + def test_format_request_inference_config(model, messages, model_id, inference_config): model.update_config(**inference_config) tru_request = model._format_request(messages) @@ -469,6 +534,188 @@ def test_format_request_tool_specs(model, messages, model_id, tool_spec): assert tru_request == exp_request +def test_format_request_strict_tools_injects_strict_and_closes_schema(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + } + }, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + request = model._format_request(messages, tool_specs=tool_specs) + tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"] + + assert tool_spec_result == { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + "additionalProperties": False, + } + }, + "strict": True, + } + + +def test_format_request_strict_tools_does_not_mutate_original(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + } + }, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + model._format_request(messages, tool_specs=tool_specs) + + assert "additionalProperties" not in tool_specs[0]["inputSchema"]["json"] + + +def test_format_request_strict_tools_preserves_additional_properties_true(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + "additionalProperties": True, + } + }, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + request = model._format_request(messages, tool_specs=tool_specs) + schema = request["toolConfig"]["tools"][0]["toolSpec"]["inputSchema"]["json"] + + assert schema["additionalProperties"] is True + + +def test_format_request_strict_tools_nested_objects(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": {"value": {"type": "integer"}}, + } + }, + "required": ["config"], + } + }, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + request = model._format_request(messages, tool_specs=tool_specs) + schema = request["toolConfig"]["tools"][0]["toolSpec"]["inputSchema"]["json"] + + assert schema == { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": {"value": {"type": "integer"}}, + "additionalProperties": False, + } + }, + "required": ["config"], + "additionalProperties": False, + } + + +def test_format_request_strict_tools_default_no_strict(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + } + }, + } + ] + model = BedrockModel(model_id=model_id) + request = model._format_request(messages, tool_specs=tool_specs) + tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"] + + assert "strict" not in tool_spec_result + assert tool_spec_result["inputSchema"]["json"] == { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + } + + +def test_format_request_strict_tools_false_no_strict(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": {"json": {"type": "object", "properties": {"x": {"type": "string"}}}}, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=False) + request = model._format_request(messages, tool_specs=tool_specs) + tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"] + + assert "strict" not in tool_spec_result + + +def test_format_request_strict_tools_none_no_strict(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": {"json": {"type": "object", "properties": {"x": {"type": "string"}}}}, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=None) + request = model._format_request(messages, tool_specs=tool_specs) + tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"] + + assert "strict" not in tool_spec_result + + +def test_format_request_strict_tools_applies_to_all_tools(bedrock_client, model_id, messages): + tool_specs = [ + {"name": "tool_a", "description": "Tool A", "inputSchema": {"json": {"type": "object", "properties": {}}}}, + {"name": "tool_b", "description": "Tool B", "inputSchema": {"json": {"type": "object", "properties": {}}}}, + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + request = model._format_request(messages, tool_specs=tool_specs) + + for tool in request["toolConfig"]["tools"]: + if "toolSpec" in tool: + assert tool["toolSpec"]["strict"] is True + assert tool["toolSpec"]["inputSchema"]["json"]["additionalProperties"] is False + + def test_format_request_tool_choice_auto(model, messages, model_id, tool_spec): tool_choice = {"auto": {}} tru_request = model._format_request(messages, [tool_spec], tool_choice=tool_choice) @@ -1515,10 +1762,34 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model ] +@pytest.mark.parametrize( + "overflow_message", + [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", + "prompt is too long: 903884 tokens > 200000 maximum", + ], +) +@pytest.mark.asyncio +async def test_stream_context_window_overflow(overflow_message, bedrock_client, model, alist, messages): + """Test that ClientError with overflow messages raises ContextWindowOverflowException.""" + error_response = { + "Error": { + "Code": "ValidationException", + "Message": f"An error occurred (ValidationException) when calling the ConverseStream operation: " + f"The model returned the following errors: {overflow_message}", + } + } + bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConverseStream") + + with pytest.raises(ContextWindowOverflowException): + await alist(model.stream(messages)) + + @pytest.mark.asyncio async def test_stream_logging(bedrock_client, model, messages, caplog, alist): """Test that stream method logs debug messages at the expected stages.""" - import logging # Set the logger to debug level to capture debug messages caplog.set_level(logging.DEBUG, logger="strands.models.bedrock") @@ -1539,53 +1810,6 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist): assert "finished streaming response from model" in log_text -@pytest.mark.asyncio -async def test_stream_stop_reason_override_streaming(bedrock_client, model, messages, alist): - """Test that stopReason is overridden from end_turn to tool_use in streaming mode when tool use is detected.""" - bedrock_client.converse_stream.return_value = { - "stream": [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test_tool"}}}}, - {"contentBlockDelta": {"delta": {"test": {"input": '{"param": "value"}'}}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn"}}, - ] - } - - response = model.stream(messages) - events = await alist(response) - - # Find the messageStop event - message_stop_event = next(event for event in events if "messageStop" in event) - - # Verify stopReason was overridden to tool_use - assert message_stop_event["messageStop"]["stopReason"] == "tool_use" - - -@pytest.mark.asyncio -async def test_stream_stop_reason_override_non_streaming(bedrock_client, alist, messages): - """Test that stopReason is overridden from end_turn to tool_use in non-streaming mode when tool use is detected.""" - bedrock_client.converse.return_value = { - "output": { - "message": { - "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"param": "value"}}}], - } - }, - "stopReason": "end_turn", - } - - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - events = await alist(response) - - # Find the messageStop event - message_stop_event = next(event for event in events if "messageStop" in event) - - # Verify stopReason was overridden to tool_use - assert message_stop_event["messageStop"]["stopReason"] == "tool_use" - - def test_format_request_cleans_tool_result_content_blocks(model, model_id): messages = [ { @@ -1614,6 +1838,87 @@ def test_format_request_cleans_tool_result_content_blocks(model, model_id): assert "status" not in tool_result +def test_format_request_message_content_normalizes_empty_tool_result_content(model, model_id): + """Test that _format_request_message_content replaces empty toolResult content with a minimal text block. + + Some model providers (e.g., Nemotron) reject toolResult blocks with content: [] via the + Converse API, while others (e.g., Claude) accept them. The SDK should normalize empty + content arrays to ensure cross-model compatibility. + + See: https://github.com/strands-agents/sdk-python/issues/2122 + """ + messages = [ + {"role": "user", "content": [{"text": "List tables"}]}, + { + "role": "assistant", + "content": [ + {"text": "Querying...\n"}, + {"toolUse": {"toolUseId": "tool_001", "name": "run_query", "input": {"sql": "SELECT 1"}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "tool_001", "content": []}}, + ], + }, + ] + + formatted_request = model._format_request(messages) + + tool_result = formatted_request["messages"][2]["content"][0]["toolResult"] + assert tool_result["content"] == [{"text": ""}], "Empty toolResult content should be normalized to [{'text': ''}]" + + +def test_format_request_message_content_does_not_mutate_empty_tool_result(model, model_id): + """Test that normalizing empty toolResult content does not mutate the original messages.""" + messages = [ + {"role": "user", "content": [{"text": "List tables"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "tool_001", "name": "run_query", "input": {"sql": "SELECT 1"}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "tool_001", "content": []}}, + ], + }, + ] + + original_content = messages[2]["content"][0]["toolResult"]["content"] + model._format_request(messages) + + assert original_content == [], "Original empty content list should not be mutated" + + +def test_format_request_message_content_preserves_nonempty_tool_result_content(model, model_id): + """Test that _format_request_message_content does not modify non-empty toolResult content.""" + messages = [ + {"role": "user", "content": [{"text": "List tables"}]}, + { + "role": "assistant", + "content": [ + {"text": "Querying...\n"}, + {"toolUse": {"toolUseId": "tool_001", "name": "run_query", "input": {"sql": "SELECT 1"}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "tool_001", "content": [{"text": "some result"}]}}, + ], + }, + ] + + formatted_request = model._format_request(messages) + + tool_result = formatted_request["messages"][2]["content"][0]["toolResult"] + assert tool_result["content"] == [{"text": "some result"}] + + def test_format_request_removes_status_field_when_configured(model, model_id): model.update_config(include_tool_result_status=False) @@ -1786,8 +2091,8 @@ def test_format_request_filters_image_content_blocks(model, model_id): assert "metadata" not in image_block -def test_format_request_filters_nested_image_s3_fields(model, model_id): - """Test that s3Location is filtered out and only bytes source is preserved.""" +def test_format_request_image_s3_location_only(model, model_id): + """Test that image with only s3Location is properly formatted.""" messages = [ { "role": "user", @@ -1796,8 +2101,7 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id): "image": { "format": "png", "source": { - "bytes": b"image_data", - "s3Location": {"bucket": "my-bucket", "key": "image.png", "extraField": "filtered"}, + "location": {"type": "s3", "uri": "s3://my-bucket/image.png"}, }, } } @@ -1808,61 +2112,199 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id): formatted_request = model._format_request(messages) image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] - assert image_source == {"bytes": b"image_data"} - assert "s3Location" not in image_source + assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png"}} -def test_format_request_filters_document_content_blocks(model, model_id): - """Test that format_request filters extra fields from document content blocks.""" +def test_format_request_image_bytes_only(model, model_id): + """Test that image with only bytes source is properly formatted.""" messages = [ { "role": "user", "content": [ { - "document": { - "name": "test.pdf", - "source": {"bytes": b"pdf_data"}, - "format": "pdf", - "extraField": "should be removed", - "metadata": {"pages": 10}, + "image": { + "format": "png", + "source": {"bytes": b"image_data"}, } - }, + } ], } ] formatted_request = model._format_request(messages) + image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] - document_block = formatted_request["messages"][0]["content"][0]["document"] - expected = {"name": "test.pdf", "source": {"bytes": b"pdf_data"}, "format": "pdf"} - assert document_block == expected - assert "extraField" not in document_block - assert "metadata" not in document_block + assert image_source == {"bytes": b"image_data"} -def test_format_request_filters_nested_reasoning_content(model, model_id): - """Test deep filtering of nested reasoningText fields.""" +def test_format_request_document_s3_location(model, model_id): + """Test that document with s3Location is properly formatted.""" messages = [ { - "role": "assistant", + "role": "user", "content": [ { - "reasoningContent": { - "reasoningText": {"text": "thinking...", "signature": "abc123", "extraField": "filtered"} + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}, + }, } - } + }, + { + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "location": { + "type": "s3", + "uri": "s3://my-bucket/report.pdf", + "bucketOwner": "123456789012", + }, + }, + } + }, ], } ] formatted_request = model._format_request(messages) - reasoning_text = formatted_request["messages"][0]["content"][0]["reasoningContent"]["reasoningText"] + document = formatted_request["messages"][0]["content"][0]["document"] + document_with_bucket_owner = formatted_request["messages"][0]["content"][1]["document"] - assert reasoning_text == {"text": "thinking...", "signature": "abc123"} + assert document["source"] == {"s3Location": {"uri": "s3://my-bucket/report.pdf"}} + + assert document_with_bucket_owner["source"] == { + "s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"} + } -def test_format_request_filters_video_content_blocks(model, model_id): - """Test that format_request filters extra fields from video content blocks.""" +def test_format_request_unsupported_location(model, caplog): + """Test that document with s3Location is properly formatted.""" + + caplog.set_level(logging.WARNING, logger="strands.models.bedrock") + + messages = [ + { + "role": "user", + "content": [ + {"text": "Hello!"}, + { + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "location": { + "type": "other", + }, + }, + } + }, + { + "video": { + "format": "mp4", + "source": { + "location": { + "type": "other", + }, + }, + } + }, + { + "image": { + "format": "png", + "source": { + "location": { + "type": "other", + }, + }, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + assert len(formatted_request["messages"][0]["content"]) == 1 + assert "Non s3 location sources are not supported by Bedrock | skipping content block" in caplog.text + + +def test_format_request_video_s3_location(model, model_id): + """Test that video with s3Location is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "video": { + "format": "mp4", + "source": { + "location": {"type": "s3", "uri": "s3://my-bucket/video.mp4"}, + }, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + video_source = formatted_request["messages"][0]["content"][0]["video"]["source"] + + assert video_source == {"s3Location": {"uri": "s3://my-bucket/video.mp4"}} + + +def test_format_request_filters_document_content_blocks(model, model_id): + """Test that format_request filters extra fields from document content blocks.""" + messages = [ + { + "role": "user", + "content": [ + { + "document": { + "name": "test.pdf", + "source": {"bytes": b"pdf_data"}, + "format": "pdf", + "extraField": "should be removed", + "metadata": {"pages": 10}, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + document_block = formatted_request["messages"][0]["content"][0]["document"] + expected = {"name": "test.pdf", "source": {"bytes": b"pdf_data"}, "format": "pdf"} + assert document_block == expected + assert "extraField" not in document_block + assert "metadata" not in document_block + + +def test_format_request_filters_nested_reasoning_content(model, model_id): + """Test deep filtering of nested reasoningText fields.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "thinking...", "signature": "abc123", "extraField": "filtered"} + } + } + ], + } + ] + + formatted_request = model._format_request(messages) + reasoning_text = formatted_request["messages"][0]["content"][0]["reasoningContent"]["reasoningText"] + + assert reasoning_text == {"text": "thinking...", "signature": "abc123"} + + +def test_format_request_filters_video_content_blocks(model, model_id): + """Test that format_request filters extra fields from video content blocks.""" messages = [ { "role": "user", @@ -1912,6 +2354,53 @@ def test_format_request_filters_cache_point_content_blocks(model, model_id): assert "extraField" not in cache_point_block +def test_format_request_preserves_cache_point_ttl(model, model_id): + """Test that format_request preserves the ttl field in cachePoint content blocks.""" + messages = [ + { + "role": "user", + "content": [ + { + "cachePoint": { + "type": "default", + "ttl": "1h", + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + cache_point_block = formatted_request["messages"][0]["content"][0]["cachePoint"] + expected = {"type": "default", "ttl": "1h"} + assert cache_point_block == expected + assert cache_point_block["ttl"] == "1h" + + +def test_format_request_cache_point_without_ttl(model, model_id): + """Test that cache points work without ttl field (backward compatibility).""" + messages = [ + { + "role": "user", + "content": [ + { + "cachePoint": { + "type": "default", + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + cache_point_block = formatted_request["messages"][0]["content"][0]["cachePoint"] + expected = {"type": "default"} + assert cache_point_block == expected + assert "ttl" not in cache_point_block + + def test_config_validation_warns_on_unknown_keys(bedrock_client, captured_warnings): """Test that unknown config keys emit a warning.""" BedrockModel(model_id="test-model", invalid_param="test") @@ -1946,42 +2435,24 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): def test_get_default_model_with_warning_supported_regions_shows_no_warning(captured_warnings): - """Test get_model_prefix_with_warning doesn't warn for supported region prefixes.""" + """Test _get_default_model_with_warning doesn't warn for any region (global profile works everywhere).""" BedrockModel._get_default_model_with_warning("us-west-2") BedrockModel._get_default_model_with_warning("eu-west-2") - assert len(captured_warnings) == 0 - - -def test_get_default_model_for_supported_eu_region_returns_correct_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("eu-west-1") - assert model_id == "eu.anthropic.claude-sonnet-4-20250514-v1:0" - assert len(captured_warnings) == 0 - - -def test_get_default_model_for_supported_us_region_returns_correct_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("us-east-1") - assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0" - assert len(captured_warnings) == 0 - - -def test_get_default_model_for_supported_gov_region_returns_correct_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("us-gov-west-1") - assert model_id == "us-gov.anthropic.claude-sonnet-4-20250514-v1:0" - assert len(captured_warnings) == 0 + assert all("does not support" not in str(w.message) for w in captured_warnings) -def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warnings): - """Test _get_default_model_with_warning warns for APAC regions since 'ap' is not in supported prefixes.""" - model_id = BedrockModel._get_default_model_with_warning("ap-southeast-1") - assert model_id == "apac.anthropic.claude-sonnet-4-20250514-v1:0" +def test_get_default_model_returns_global_inference_profile(captured_warnings): + """Default model id is the global inference profile regardless of region.""" + for region in ("us-east-1", "eu-west-1", "us-gov-west-1", "ap-southeast-1", "ca-central-1"): + assert BedrockModel._get_default_model_with_warning(region) == DEFAULT_BEDROCK_MODEL_ID + assert all("does not support" not in str(w.message) for w in captured_warnings) -def test_get_default_model_with_warning_unsupported_region_warns(captured_warnings): - """Test _get_default_model_with_warning warns for unsupported regions.""" +def test_get_default_model_with_warning_unsupported_region_does_not_warn(captured_warnings): + """Global inference profile works across all regions, so no region-support warning is emitted.""" BedrockModel._get_default_model_with_warning("ca-central-1") - assert len(captured_warnings) == 1 - assert "This region ca-central-1 does not support" in str(captured_warnings[0].message) - assert "our default inference endpoint" in str(captured_warnings[0].message) + region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] + assert len(region_warnings) == 0 def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured_warnings): @@ -1993,12 +2464,12 @@ def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured assert len(captured_warnings) == 0 -def test_init_with_unsupported_region_warns(session_cls, captured_warnings): - """Test BedrockModel initialization warns for unsupported regions.""" +def test_init_with_unsupported_region_does_not_warn(session_cls, captured_warnings): + """BedrockModel initialization does not warn for 'unsupported' regions when using the global profile.""" BedrockModel(region_name="ca-central-1") - assert len(captured_warnings) == 1 - assert "This region ca-central-1 does not support" in str(captured_warnings[0].message) + region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] + assert len(region_warnings) == 0 def test_init_with_unsupported_region_custom_model_no_warning(session_cls, captured_warnings): @@ -2013,11 +2484,35 @@ def test_override_default_model_id_uses_the_overriden_value(captured_warnings): assert model_id == "custom-overridden-model" -def test_no_override_uses_formatted_default_model_id(captured_warnings): +def test_default_model_sentinel_triggers_region_prefix_fallback(captured_warnings): + """When DEFAULT_BEDROCK_MODEL_ID matches the sentinel template, the region-prefix fallback runs.""" + sentinel = "us.anthropic.claude-sonnet-4-6" + with unittest.mock.patch("strands.models.bedrock.DEFAULT_BEDROCK_MODEL_ID", sentinel): + model_id = BedrockModel._get_default_model_with_warning("eu-west-1") + assert model_id == "eu.anthropic.claude-sonnet-4-6" + + +def test_caller_supplied_model_id_wins_over_global_default(captured_warnings): + """Caller-supplied model_id in config takes precedence over the global default.""" + model_config = {"model_id": "caller-supplied-model"} + model_id = BedrockModel._get_default_model_with_warning("us-east-1", model_config) + assert model_id == "caller-supplied-model" + + +def test_default_model_sentinel_with_unsupported_region_warns(captured_warnings): + """When the sentinel matches and the region is unknown, the region-unsupported warning fires.""" + sentinel = "us.anthropic.claude-sonnet-4-6" + with unittest.mock.patch("strands.models.bedrock.DEFAULT_BEDROCK_MODEL_ID", sentinel): + BedrockModel._get_default_model_with_warning("ca-central-1") + region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] + assert len(region_warnings) == 1 + + +def test_default_model_id_is_global_inference_profile(captured_warnings): model_id = BedrockModel._get_default_model_with_warning("us-east-1") - assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0" - assert model_id != _DEFAULT_BEDROCK_MODEL_ID - assert len(captured_warnings) == 0 + assert model_id == "global.anthropic.claude-sonnet-4-6" + assert model_id == DEFAULT_BEDROCK_MODEL_ID + assert all("does not support" not in str(w.message) for w in captured_warnings) def test_custom_model_id_not_overridden_by_region_formatting(session_cls): @@ -2070,3 +2565,1058 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model "system": [{"text": system_prompt}], } bedrock_client.converse_stream.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_citations_content_preserves_tagged_union_structure(bedrock_client, model, alist): + """Test that citationsContent preserves AWS Bedrock's required tagged union structure for citation locations. + + This test verifies that when messages contain citationsContent with tagged union CitationLocation objects, + the structure is preserved when sent to AWS Bedrock API. AWS Bedrock expects CitationLocation to be a + tagged union with exactly one wrapper key (documentChar, documentPage, documentChunk, searchResultLocation, web) + containing the location fields. + """ + # Mock the Bedrock response + bedrock_client.converse_stream.return_value = {"stream": []} + + # Messages with citationsContent using all tagged union CitationLocation types + messages = [ + {"role": "user", "content": [{"text": "Analyze multiple sources"}]}, + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}}, + "sourceContent": [ + {"text": "Employee benefits include health insurance and retirement plans"} + ], + "title": "Benefits Section", + }, + { + "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}, + "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], + "title": "Vacation Policy", + }, + { + "location": {"documentChunk": {"documentIndex": 1, "start": 5, "end": 8}}, + "sourceContent": [{"text": "Company culture emphasizes work-life balance"}], + "title": "Culture Section", + }, + { + "location": { + "searchResultLocation": { + "searchResultIndex": 0, + "start": 25, + "end": 150, + } + }, + "sourceContent": [{"text": "Search results show industry best practices"}], + "title": "Search Results", + }, + { + "location": { + "web": { + "url": "https://example.com/hr-policies", + "domain": "example.com", + } + }, + "sourceContent": [{"text": "External HR policy guidelines"}], + "title": "External Reference", + }, + ], + "content": [{"text": "Based on multiple sources, the company offers comprehensive benefits."}], + } + } + ], + }, + ] + + # Call the public stream method + await alist(model.stream(messages)) + + # Verify the request sent to Bedrock preserves the tagged union structure + bedrock_client.converse_stream.assert_called_once() + call_args = bedrock_client.converse_stream.call_args[1] + + # Extract the citationsContent from the formatted messages + formatted_messages = call_args["messages"] + citations_content = formatted_messages[1]["content"][0]["citationsContent"] + + # Verify the tagged union structure is preserved for all location types + expected_citations = [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}}, + "sourceContent": [{"text": "Employee benefits include health insurance and retirement plans"}], + "title": "Benefits Section", + }, + { + "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}, + "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], + "title": "Vacation Policy", + }, + { + "location": {"documentChunk": {"documentIndex": 1, "start": 5, "end": 8}}, + "sourceContent": [{"text": "Company culture emphasizes work-life balance"}], + "title": "Culture Section", + }, + { + "location": { + "searchResultLocation": { + "searchResultIndex": 0, + "start": 25, + "end": 150, + } + }, + "sourceContent": [{"text": "Search results show industry best practices"}], + "title": "Search Results", + }, + { + "location": { + "web": { + "url": "https://example.com/hr-policies", + "domain": "example.com", + } + }, + "sourceContent": [{"text": "External HR policy guidelines"}], + "title": "External Reference", + }, + ] + + assert citations_content["citations"] == expected_citations, ( + "Citation location tagged union structure was not preserved. " + "AWS Bedrock requires CitationLocation to have exactly one wrapper key " + "(documentChar, documentPage, documentChunk, searchResultLocation, or web) " + "with the location fields nested inside." + ) + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_latest_message(model): + """Test that guardrail_latest_message wraps the latest user message with text and image.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + { + "role": "user", + "content": [ + {"text": "Look at this image"}, + {"image": {"format": "png", "source": {"bytes": b"fake_image_data"}}}, + ], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # All messages should be in the request + assert len(formatted_messages) == 3 + + # First user message should NOT be wrapped + assert "text" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["text"] == "First message" + + # Assistant message should NOT be wrapped + assert "text" in formatted_messages[1]["content"][0] + assert formatted_messages[1]["content"][0]["text"] == "First response" + + # Latest user message text should be wrapped + assert "guardContent" in formatted_messages[2]["content"][0] + assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "Look at this image" + + # Latest user message image should also be wrapped + assert "guardContent" in formatted_messages[2]["content"][1] + assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png" + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_latest_message_after_tool_use(model): + """Test that guardContent wraps the last user text message even when a toolResult follows it.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "user", "content": [{"text": "what is the standard deduction?"}]}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-1", + "name": "knowledge_base", + "input": {"query": "standard deduction"}, + } + } + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "tool-1", + "content": [{"text": "The standard deduction for 2024 is $14,600."}], + "status": "success", + } + } + ], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + assert len(formatted_messages) == 5 + + # Earlier user message should NOT be wrapped + assert "text" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["text"] == "First message" + + # Last user message with text content should be wrapped, even though a toolResult comes after + assert "guardContent" in formatted_messages[2]["content"][0] + assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "what is the standard deduction?" + + # toolResult-only user message should NOT be wrapped + assert "toolResult" in formatted_messages[4]["content"][0] + assert "guardContent" not in formatted_messages[4]["content"][0] + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_latest_message_wraps_final_user_text(model): + """Test that guardContent wraps the last user message when it contains text content.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "user", "content": [{"text": "Tell me about taxes"}]}, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + assert "guardContent" in formatted_messages[2]["content"][0] + assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "Tell me about taxes" + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_multiple_sequential_tool_calls(model): + """Test guardContent with multiple tool calls in sequence (no new user input between).""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First question"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result 1"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t2", "name": "tool2", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t2", "content": [{"text": "Result 2"}], "status": "success"}}], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # Should wrap the first user text message, not the toolResults + assert "guardContent" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["guardContent"]["text"]["text"] == "First question" + + # toolResults should not be wrapped + assert "toolResult" in formatted_messages[2]["content"][0] + assert "guardContent" not in formatted_messages[2]["content"][0] + assert "toolResult" in formatted_messages[4]["content"][0] + assert "guardContent" not in formatted_messages[4]["content"][0] + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_image_before_tool_result(model): + """Test guardContent wraps image content even when toolResult follows.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"fake"}}}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "vision", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "I see a cat"}], "status": "success"}}], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # Image should be wrapped even though toolResult comes after + assert "guardContent" in formatted_messages[0]["content"][0] + assert "image" in formatted_messages[0]["content"][0]["guardContent"] + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_multiple_tool_results_same_message(model): + """Test guardContent with multiple parallel tool calls (multiple toolResults in one message).""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "Question requiring multiple tools"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "t1", "name": "tool1", "input": {}}}, + {"toolUse": {"toolUseId": "t2", "name": "tool2", "input": {}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "t1", "content": [{"text": "Result 1"}], "status": "success"}}, + {"toolResult": {"toolUseId": "t2", "content": [{"text": "Result 2"}], "status": "success"}}, + ], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # Should wrap the question + assert "guardContent" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["guardContent"]["text"]["text"] == "Question requiring multiple tools" + + +def test_cache_strategy_anthropic_for_claude(bedrock_client): + """Test that _cache_strategy returns 'anthropic' for Claude models.""" + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") + assert model._cache_strategy == "anthropic" + + model2 = BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0") + assert model2._cache_strategy == "anthropic" + + +def test_cache_strategy_none_for_non_claude(bedrock_client): + """Test that _cache_strategy returns None for unsupported models.""" + model = BedrockModel(model_id="amazon.nova-pro-v1:0") + assert model._cache_strategy is None + + +def test_inject_cache_point_adds_to_last_user(bedrock_client): + """Test that _inject_cache_point adds cache point to last user message.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]}, + {"role": "user", "content": [{"text": "How are you?"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + assert len(cleaned_messages[2]["content"]) == 2 + assert "cachePoint" in cleaned_messages[2]["content"][-1] + assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default" + assert len(cleaned_messages[1]["content"]) == 1 + + +def test_inject_cache_point_single_user_message(bedrock_client): + """Test that _inject_cache_point adds cache point to single user message.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + assert len(cleaned_messages) == 1 + assert len(cleaned_messages[0]["content"]) == 2 + assert "cachePoint" in cleaned_messages[0]["content"][-1] + + +def test_inject_cache_point_empty_messages(bedrock_client): + """Test that _inject_cache_point handles empty messages list.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [] + model._inject_cache_point(cleaned_messages) + + assert cleaned_messages == [] + + +def test_inject_cache_point_with_tool_result_last_user(bedrock_client): + """Test that cache point is added to last user message even when it contains toolResult.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Use the tool"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}]}, + ] + + model._inject_cache_point(cleaned_messages) + + assert len(cleaned_messages[2]["content"]) == 2 + assert "cachePoint" in cleaned_messages[2]["content"][-1] + assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default" + assert len(cleaned_messages[0]["content"]) == 1 + + +def test_inject_cache_point_skipped_for_non_claude(bedrock_client): + """Test that cache point injection is skipped for non-Claude models.""" + model = BedrockModel(model_id="amazon.nova-pro-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + ] + + formatted = model._format_bedrock_messages(messages) + + assert len(formatted[0]["content"]) == 1 + assert "cachePoint" not in formatted[0]["content"][0] + assert len(formatted[1]["content"]) == 1 + assert "cachePoint" not in formatted[1]["content"][0] + + +def test_format_bedrock_messages_does_not_mutate_original(bedrock_client): + """Test that _format_bedrock_messages does not mutate original messages.""" + + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + original_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]}, + {"role": "user", "content": [{"text": "How are you?"}]}, + ] + + messages_before = copy.deepcopy(original_messages) + formatted = model._format_bedrock_messages(original_messages) + + assert original_messages == messages_before + assert "cachePoint" not in original_messages[2]["content"][-1] + assert "cachePoint" in formatted[2]["content"][-1] + + +def test_inject_cache_point_strips_existing_cache_points(bedrock_client): + """Test that _inject_cache_point strips existing cache points and adds new one at correct position.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + # Messages with existing cache points in various positions + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}, {"cachePoint": {"type": "default"}}]}, + {"role": "assistant", "content": [{"text": "First response"}, {"cachePoint": {"type": "default"}}]}, + {"role": "user", "content": [{"text": "Follow up"}]}, + {"role": "assistant", "content": [{"text": "Second response"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + # All old cache points should be stripped + assert len(cleaned_messages[0]["content"]) == 1 # first user: only text + assert len(cleaned_messages[1]["content"]) == 1 # first assistant: only text + assert len(cleaned_messages[3]["content"]) == 1 # last assistant: only text + + # New cache point should be at end of last user message + assert len(cleaned_messages[2]["content"]) == 2 + assert "cachePoint" in cleaned_messages[2]["content"][-1] + + +def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client): + """Test that anthropic strategy injects cache point without model support check.""" + model = BedrockModel( + model_id="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/a1b2c3d4e5f6", + cache_config=CacheConfig(strategy="anthropic"), + ) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + ] + + formatted = model._format_bedrock_messages(messages) + + assert len(formatted[0]["content"]) == 2 + assert "cachePoint" in formatted[0]["content"][-1] + assert formatted[0]["content"][-1]["cachePoint"]["type"] == "default" + assert len(formatted[1]["content"]) == 1 + + +def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedrock_client): + """Test that auto strategy resolves to anthropic strategy for Claude models.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + ] + + formatted = model._format_bedrock_messages(messages) + + assert len(formatted[0]["content"]) == 2 + assert "cachePoint" in formatted[0]["content"][-1] + assert len(formatted[1]["content"]) == 1 + + +def test_find_last_user_text_message_index_no_user_messages(bedrock_client): + """Test _find_last_user_text_message_index returns None when no user text messages exist.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "assistant", "content": [{"text": "hello"}]}, + ] + + assert model._find_last_user_text_message_index(messages) is None + + +def test_find_last_user_text_message_index_only_tool_results(bedrock_client): + """Test _find_last_user_text_message_index returns None when user messages only have toolResult.""" + model = BedrockModel(model_id="test-model") + + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "result"}]}}], + }, + ] + + assert model._find_last_user_text_message_index(messages) is None + + +def test_find_last_user_text_message_index_returns_last_text_message(bedrock_client): + """Test _find_last_user_text_message_index returns the index of the last user message with text.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "First question"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Second question"}]}, + ] + + assert model._find_last_user_text_message_index(messages) == 2 + + +def test_find_last_user_text_message_index_skips_tool_result_messages(bedrock_client): + """Test _find_last_user_text_message_index skips toolResult-only user messages.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "Question"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "tool", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}], + }, + ] + + assert model._find_last_user_text_message_index(messages) == 0 + + +def test_find_last_user_text_message_index_finds_image_message(bedrock_client): + """Test _find_last_user_text_message_index finds user messages with image content.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"fake"}}}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "vision", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}], + }, + ] + + assert model._find_last_user_text_message_index(messages) == 0 + + +def test_find_last_user_text_message_index_empty_messages(bedrock_client): + """Test _find_last_user_text_message_index returns None for empty message list.""" + model = BedrockModel(model_id="test-model") + + assert model._find_last_user_text_message_index([]) is None + + +def test_guardrail_latest_message_disabled_does_not_wrap(model): + """Test that guardContent wrapping is skipped when guardrail_latest_message is not set.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model._format_request(messages) + formatted = request["messages"][0]["content"][0] + + assert "text" in formatted + assert "guardContent" not in formatted + + +@pytest.mark.asyncio +async def test_non_streaming_citations_with_missing_optional_fields(bedrock_client, model, alist): + """Test that _convert_non_streaming_to_streaming handles citations missing optional fields. + + Nova grounding returns citations with only url/domain but no title field. The conversion + should not crash with KeyError when optional fields like title, location, or sourceContent + are missing from the citation response. + """ + # Simulate a non-streaming response with citations missing the 'title' field + # This is what Nova grounding returns: url+domain in location, no title + non_streaming_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "citationsContent": { + "content": [{"text": "Top shoe brands include Nike and Adidas."}], + "citations": [ + { + "location": { + "web": { + "url": "https://example.com/shoes", + "domain": "example.com", + } + }, + }, + ], + } + } + ], + } + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 10, "outputTokens": 20}, + } + + events = list(model._convert_non_streaming_to_streaming(non_streaming_response)) + + # Should have: messageStart, contentBlockDelta (text + citation), contentBlockStop, messageStop, metadata + citation_deltas = [ + e for e in events if "contentBlockDelta" in e and "citation" in e.get("contentBlockDelta", {}).get("delta", {}) + ] + assert len(citation_deltas) == 1 + + citation = citation_deltas[0]["contentBlockDelta"]["delta"]["citation"] + # title should NOT be present since the source didn't have it + assert "title" not in citation + # location should be present + assert "location" in citation + # sourceContent should NOT be present since the source didn't have it + assert "sourceContent" not in citation + + +@pytest.mark.asyncio +async def test_non_streaming_citations_with_all_fields_present(bedrock_client, model, alist): + """Test that _convert_non_streaming_to_streaming correctly includes all fields when present.""" + non_streaming_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "citationsContent": { + "content": [{"text": "Nike is a top shoe brand."}], + "citations": [ + { + "title": "Top Shoe Brands", + "location": { + "web": { + "url": "https://example.com/shoes", + "domain": "example.com", + } + }, + "sourceContent": [{"text": "Nike is a leading brand"}], + }, + ], + } + } + ], + } + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 10, "outputTokens": 20}, + } + + events = list(model._convert_non_streaming_to_streaming(non_streaming_response)) + + citation_deltas = [ + e for e in events if "contentBlockDelta" in e and "citation" in e.get("contentBlockDelta", {}).get("delta", {}) + ] + assert len(citation_deltas) == 1 + + citation = citation_deltas[0]["contentBlockDelta"]["delta"]["citation"] + assert citation["title"] == "Top Shoe Brands" + assert citation["location"] == {"web": {"url": "https://example.com/shoes", "domain": "example.com"}} + assert citation["sourceContent"] == [{"text": "Nike is a leading brand"}] + + +@pytest.mark.asyncio +async def test_non_streaming_citations_with_only_location(bedrock_client, model, alist): + """Test citations with only location field (no title, no sourceContent).""" + non_streaming_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": { + "web": { + "url": "https://example.com", + "domain": "example.com", + } + }, + }, + ], + } + } + ], + } + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 5, "outputTokens": 10}, + } + + events = list(model._convert_non_streaming_to_streaming(non_streaming_response)) + + citation_deltas = [ + e for e in events if "contentBlockDelta" in e and "citation" in e.get("contentBlockDelta", {}).get("delta", {}) + ] + assert len(citation_deltas) == 1 + + citation = citation_deltas[0]["contentBlockDelta"]["delta"]["citation"] + assert citation["location"] == {"web": {"url": "https://example.com", "domain": "example.com"}} + assert "title" not in citation + assert "sourceContent" not in citation + + +class TestCountTokens: + """Tests for BedrockModel.count_tokens native token counting.""" + + @pytest.fixture(autouse=True) + def clean_cache(self): + _clear_skip_count_tokens_cache() + yield + _clear_skip_count_tokens_cache() + + @pytest.fixture + def model_with_client(self, bedrock_client, model_id): + _ = bedrock_client + return BedrockModel(model_id=model_id, use_native_token_count=True) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {"inputTokens": 42} + + result = await model_with_client.count_tokens(messages=messages) + + assert result == 42 + bedrock_client.count_tokens.assert_called_once() + call_kwargs = bedrock_client.count_tokens.call_args[1] + assert "input" in call_kwargs + assert "converse" in call_kwargs["input"] + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {"inputTokens": 55} + + result = await model_with_client.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result == 55 + call_kwargs = bedrock_client.count_tokens.call_args[1] + assert call_kwargs["input"]["converse"]["system"] == [{"text": "Be helpful."}] + assert "toolConfig" not in call_kwargs["input"]["converse"] + + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model_with_client, bedrock_client, messages, tool_specs): + bedrock_client.count_tokens.return_value = {"inputTokens": 100} + + result = await model_with_client.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result == 100 + call_kwargs = bedrock_client.count_tokens.call_args[1] + assert "toolConfig" in call_kwargs["input"]["converse"] + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt_content(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {"inputTokens": 60} + + result = await model_with_client.count_tokens( + messages=messages, + system_prompt_content=[{"text": "Be helpful."}, {"text": "Be concise."}], + ) + + assert result == 60 + call_kwargs = bedrock_client.count_tokens.call_args[1] + assert call_kwargs["input"]["converse"]["system"] == [{"text": "Be helpful."}, {"text": "Be concise."}] + + @pytest.mark.asyncio + async def test_native_count_tokens_strips_inference_config(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {"inputTokens": 10} + model_with_client.update_config(max_tokens=100) + + await model_with_client.count_tokens(messages=messages) + + call_kwargs = bedrock_client.count_tokens.call_args[1] + converse = call_kwargs["input"]["converse"] + assert "inferenceConfig" not in converse + assert "additionalModelRequestFields" not in converse + assert "guardrailConfig" not in converse + + @pytest.mark.asyncio + async def test_fallback_on_api_error(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.side_effect = ClientError( + {"Error": {"Code": "ValidationException", "Message": "Unsupported"}}, + "CountTokens", + ) + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_generic_exception(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.side_effect = RuntimeError("Connection failed") + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_none_input_tokens(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {} + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_debug(self, model_with_client, bedrock_client, messages, caplog): + bedrock_client.count_tokens.side_effect = RuntimeError("API down") + + with caplog.at_level(logging.DEBUG, logger="strands.models.bedrock"): + await model_with_client.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages): + model = BedrockModel(model_id="unsupported-cache-test-model", use_native_token_count=True) + bedrock_client.count_tokens.side_effect = ClientError( + {"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}}, + "CountTokens", + ) + + # First call: hits API, gets error, caches + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 1 + + # Second call: skips API entirely + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 1 + + @pytest.mark.asyncio + async def test_caches_model_id_when_access_denied(self, bedrock_client, messages): + model = BedrockModel(model_id="access-denied-cache-test-model", use_native_token_count=True) + bedrock_client.count_tokens.side_effect = ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "User: arn:aws:sts::123456789012:assumed-role/role is not authorized" + " to perform: bedrock:CountTokens", + } + }, + "CountTokens", + ) + + # First call: hits API, gets error, caches + await model.count_tokens(messages=messages) + bedrock_client.count_tokens.assert_called_once() + + # Reset mock to clearly verify second call doesn't hit the API + bedrock_client.count_tokens.reset_mock() + + # Second call: skips API entirely due to caching + result = await model.count_tokens(messages=messages) + bedrock_client.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_access_denied_logs_warning_with_full_error( + self, model_with_client, bedrock_client, messages, caplog + ): + error_message = ( + "User: arn:aws:sts::123456789012:assumed-role/role is not authorized" + " to perform: bedrock:CountTokens" + ) + bedrock_client.count_tokens.side_effect = ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": error_message}}, + "CountTokens", + ) + + with caplog.at_level(logging.WARNING, logger="strands.models.bedrock"): + await model_with_client.count_tokens(messages=messages) + + warning_records = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warning_records) == 1 + assert "bedrock:CountTokens permission denied" in warning_records[0].message + assert error_message in warning_records[0].message + + @pytest.mark.asyncio + async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages): + model = BedrockModel(model_id="transient-error-test-model", use_native_token_count=True) + bedrock_client.count_tokens.side_effect = RuntimeError("Transient network error") + + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 1 + + # Second call should still attempt the API + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 2 + + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_client, model_id, messages): + _ = bedrock_client + model = BedrockModel(model_id=model_id, use_native_token_count=False) + + result = await model.count_tokens(messages=messages) + + bedrock_client.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, bedrock_client, model_id, messages): + _ = bedrock_client + model = BedrockModel(model_id=model_id) + + result = await model.count_tokens(messages=messages) + + bedrock_client.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + +def test_inject_cache_point_with_ttl(bedrock_client): + """Test that _inject_cache_point includes TTL when cache_config has ttl set.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto", ttl="5m"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + cache_point = cleaned_messages[0]["content"][-1]["cachePoint"] + assert cache_point["type"] == "default" + assert cache_point["ttl"] == "5m" + + +def test_inject_cache_point_without_ttl(bedrock_client): + """Test that _inject_cache_point omits TTL when cache_config has no ttl.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + cache_point = cleaned_messages[0]["content"][-1]["cachePoint"] + assert cache_point["type"] == "default" + assert "ttl" not in cache_point + + +def test_format_request_cache_tools_config_with_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig propagates type and ttl into toolConfig cachePoint.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type, ttl="5m")) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type, "ttl": "5m"}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_cache_tools_config_without_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig without ttl produces a cachePoint with only type.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type)) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_cache_tools_string_backward_compat(model, messages, model_id, tool_spec, cache_type): + """Test that passing cache_tools as a string still produces a cachePoint with only type.""" + model.update_config(cache_tools=cache_type) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point diff --git a/tests/strands/models/test_bedrock_thinking.py b/tests/strands/models/test_bedrock_thinking.py new file mode 100644 index 000000000..10b53cb03 --- /dev/null +++ b/tests/strands/models/test_bedrock_thinking.py @@ -0,0 +1,84 @@ +"""Tests for thinking mode behavior in BedrockModel.""" + +import pytest + +from strands.models.bedrock import BedrockModel + + +@pytest.fixture +def model_with_thinking(): + """Create a BedrockModel with thinking enabled.""" + return BedrockModel( + model_id="anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={"thinking": {"type": "enabled", "budget_tokens": 5000}}, + ) + + +@pytest.fixture +def model_without_thinking(): + """Create a BedrockModel without thinking.""" + return BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0") + + +@pytest.fixture +def model_with_thinking_and_other_fields(): + """Create a BedrockModel with thinking and other additional fields.""" + return BedrockModel( + model_id="anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={ + "thinking": {"type": "enabled", "budget_tokens": 5000}, + "some_other_field": "value", + }, + ) + + +def test_thinking_removed_when_forcing_tool_any(model_with_thinking): + """Thinking should be removed when tool_choice forces tool use with 'any'.""" + tool_choice = {"any": {}} + result = model_with_thinking._get_additional_request_fields(tool_choice) + assert result == {} # thinking removed, no other fields + + +def test_thinking_removed_when_forcing_specific_tool(model_with_thinking): + """Thinking should be removed when tool_choice forces a specific tool.""" + tool_choice = {"tool": {"name": "structured_output_tool"}} + result = model_with_thinking._get_additional_request_fields(tool_choice) + assert result == {} # thinking removed, no other fields + + +def test_thinking_preserved_with_auto_tool_choice(model_with_thinking): + """Thinking should be preserved when tool_choice is 'auto'.""" + tool_choice = {"auto": {}} + result = model_with_thinking._get_additional_request_fields(tool_choice) + assert result == {"additionalModelRequestFields": {"thinking": {"type": "enabled", "budget_tokens": 5000}}} + + +def test_thinking_preserved_with_none_tool_choice(model_with_thinking): + """Thinking should be preserved when tool_choice is None.""" + result = model_with_thinking._get_additional_request_fields(None) + assert result == {"additionalModelRequestFields": {"thinking": {"type": "enabled", "budget_tokens": 5000}}} + + +def test_other_fields_preserved_when_thinking_removed(model_with_thinking_and_other_fields): + """Other additional fields should be preserved when thinking is removed.""" + tool_choice = {"any": {}} + result = model_with_thinking_and_other_fields._get_additional_request_fields(tool_choice) + assert result == {"additionalModelRequestFields": {"some_other_field": "value"}} + + +def test_no_fields_when_model_has_no_additional_fields(model_without_thinking): + """Should return empty dict when model has no additional_request_fields.""" + tool_choice = {"any": {}} + result = model_without_thinking._get_additional_request_fields(tool_choice) + assert result == {} + + +def test_fields_preserved_when_no_thinking_and_forcing_tool(): + """Additional fields without thinking should be preserved when forcing tool.""" + model = BedrockModel( + model_id="anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={"some_field": "value"}, + ) + tool_choice = {"any": {}} + result = model._get_additional_request_fields(tool_choice) + assert result == {"additionalModelRequestFields": {"some_field": "value"}} diff --git a/tests/strands/models/test_defaults.py b/tests/strands/models/test_defaults.py new file mode 100644 index 000000000..94c602fc1 --- /dev/null +++ b/tests/strands/models/test_defaults.py @@ -0,0 +1,76 @@ +"""Tests for model metadata lookup tables.""" + +from strands.models._defaults import get_context_window_limit, resolve_config_metadata + + +class TestGetContextWindowLimit: + """Tests for get_context_window_limit.""" + + def test_known_anthropic_direct_api(self): + assert get_context_window_limit("claude-sonnet-4-6") == 1_000_000 + assert get_context_window_limit("claude-opus-4-6") == 1_000_000 + assert get_context_window_limit("claude-opus-4-5") == 200_000 + assert get_context_window_limit("claude-haiku-4-5") == 200_000 + + def test_known_bedrock_anthropic(self): + assert get_context_window_limit("anthropic.claude-sonnet-4-6") == 1_000_000 + assert get_context_window_limit("anthropic.claude-haiku-4-5-20251001-v1:0") == 200_000 + + def test_known_bedrock_nova(self): + assert get_context_window_limit("amazon.nova-pro-v1:0") == 300_000 + assert get_context_window_limit("amazon.nova-micro-v1:0") == 128_000 + + def test_known_openai(self): + assert get_context_window_limit("gpt-5.4") == 1_050_000 + assert get_context_window_limit("gpt-4o") == 128_000 + assert get_context_window_limit("o3") == 200_000 + assert get_context_window_limit("o4-mini") == 200_000 + + def test_known_gemini(self): + assert get_context_window_limit("gemini-2.5-flash") == 1_048_576 + assert get_context_window_limit("gemini-2.5-pro") == 1_048_576 + + def test_strips_bedrock_cross_region_prefix(self): + assert get_context_window_limit("us.anthropic.claude-sonnet-4-6") == 1_000_000 + assert get_context_window_limit("global.anthropic.claude-sonnet-4-6") == 1_000_000 + assert get_context_window_limit("eu.anthropic.claude-sonnet-4-6") == 1_000_000 + assert get_context_window_limit("ap.anthropic.claude-sonnet-4-6") == 1_000_000 + + def test_strips_any_prefix_as_fallback(self): + # Any prefix before the first dot is stripped if direct lookup fails + assert get_context_window_limit("custom.anthropic.claude-sonnet-4-6") == 1_000_000 + + def test_unknown_model_returns_none(self): + assert get_context_window_limit("unknown-model-xyz") is None + assert get_context_window_limit("foo.unknown-model-xyz") is None + + +class TestResolveConfigMetadata: + """Tests for resolve_config_metadata.""" + + def test_resolves_context_window_limit(self): + config: dict = {"model_id": "claude-sonnet-4-6"} + result = resolve_config_metadata(config, "claude-sonnet-4-6") + assert result["context_window_limit"] == 1_000_000 + + def test_preserves_explicit_context_window_limit(self): + config: dict = {"model_id": "claude-sonnet-4-6", "context_window_limit": 100_000} + result = resolve_config_metadata(config, "claude-sonnet-4-6") + assert result["context_window_limit"] == 100_000 + + def test_returns_original_config_when_explicit(self): + config: dict = {"model_id": "claude-sonnet-4-6", "context_window_limit": 100_000} + result = resolve_config_metadata(config, "claude-sonnet-4-6") + assert result is config + + def test_returns_original_config_when_unknown_model(self): + config: dict = {"model_id": "unknown-model"} + result = resolve_config_metadata(config, "unknown-model") + assert result is config + assert "context_window_limit" not in result + + def test_returns_new_dict_when_resolved(self): + config: dict = {"model_id": "claude-sonnet-4-6"} + result = resolve_config_metadata(config, "claude-sonnet-4-6") + assert result is not config + assert "context_window_limit" not in config diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index a8f5351cc..a8ff38b99 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -70,6 +70,39 @@ def test__init__model_configs(gemini_client, model_id): assert tru_temperature == exp_temperature +def test__init__context_window_limit(gemini_client): + _ = gemini_client + + model = GeminiModel(model_id="gemini-2.5-flash", context_window_limit=1_048_576) + + assert model.get_config().get("context_window_limit") == 1_048_576 + assert model.context_window_limit == 1_048_576 + + +def test__init__auto_populates_context_window_limit(gemini_client): + _ = gemini_client + + model = GeminiModel(model_id="gemini-2.5-flash") + + assert model.get_config().get("context_window_limit") == 1_048_576 + + +def test__init__explicit_context_window_limit_not_overridden(gemini_client): + _ = gemini_client + + model = GeminiModel(model_id="gemini-2.5-flash", context_window_limit=500_000) + + assert model.get_config().get("context_window_limit") == 500_000 + + +def test__init__unknown_model_no_context_window_limit(gemini_client): + _ = gemini_client + + model = GeminiModel(model_id="unknown-model") + + assert model.get_config().get("context_window_limit") is None + + def test_update_config(model, model_id): model.update_config(model_id=model_id) @@ -203,7 +236,7 @@ async def test_stream_request_with_reasoning(gemini_client, model, model_id): { "reasoningContent": { "reasoningText": { - "signature": "abc", + "signature": "YWJj", # base64 of "abc" "text": "reasoning_text", }, }, @@ -260,6 +293,51 @@ async def test_stream_request_with_tool_spec(gemini_client, model, model_id, too @pytest.mark.asyncio async def test_stream_request_with_tool_use(gemini_client, model, model_id): + """Test toolUse with reasoningSignature is sent as function_call with thought_signature.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "c1", + "name": "calculator", + "input": {"expression": "2+2"}, + "reasoningSignature": "YWJj", # base64 of "abc" + }, + }, + ], + }, + ] + await anext(model.stream(messages)) + + exp_request = { + "config": { + "tools": [{"function_declarations": []}], + }, + "contents": [ + { + "parts": [ + { + "function_call": { + "args": {"expression": "2+2"}, + "id": "c1", + "name": "calculator", + }, + "thought_signature": "YWJj", + }, + ], + "role": "model", + }, + ], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + +@pytest.mark.asyncio +async def test_stream_request_with_tool_use_no_reasoning_signature(gemini_client, model, model_id): + """Test toolUse without reasoningSignature is sent as function_call without thought_signature.""" messages = [ { "role": "assistant", @@ -360,6 +438,71 @@ async def test_stream_request_with_tool_results(gemini_client, model, model_id): gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) +@pytest.mark.asyncio +async def test_stream_request_with_tool_results_preserving_name(gemini_client, model, model_id): + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_1", + "input": {}, + }, + }, + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "done"}], + }, + }, + ], + }, + ] + await anext(model.stream(messages)) + + exp_request = { + "config": { + "tools": [{"function_declarations": []}], + }, + "contents": [ + { + "parts": [ + { + "function_call": { + "args": {}, + "id": "t1", + "name": "tool_1", + }, + }, + ], + "role": "model", + }, + { + "parts": [ + { + "function_response": { + "id": "t1", + "name": "tool_1", + "response": {"output": [{"text": "done"}]}, + }, + }, + ], + "role": "user", + }, + ], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + @pytest.mark.asyncio async def test_stream_request_with_empty_content(gemini_client, model, model_id): messages = [ @@ -458,10 +601,57 @@ async def test_stream_response_tool_use(gemini_client, model, messages, agenerat tru_chunks = await alist(model.stream(messages)) exp_chunks = [ {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}}, + {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, + ] + assert tru_chunks == exp_chunks + + +@pytest.mark.asyncio +async def test_stream_response_tool_use_with_thought_signature(gemini_client, model, messages, agenerator, alist): + """Test that tool use responses with thought_signature include reasoningSignature.""" + gemini_client.aio.models.generate_content_stream.return_value = agenerator( + [ + genai.types.GenerateContentResponse( + candidates=[ + genai.types.Candidate( + content=genai.types.Content( + parts=[ + genai.types.Part( + function_call=genai.types.FunctionCall( + args={"expression": "2+2"}, + id="c1", + name="calculator", + ), + thought_signature=b"abc", + ), + ], + ), + finish_reason="STOP", + ), + ], + usage_metadata=genai.types.GenerateContentResponseUsageMetadata( + prompt_token_count=1, + total_token_count=3, + ), + ), + ] + ) + + tru_chunks = await alist(model.stream(messages)) + exp_chunks = [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockStart": { + "start": { + "toolUse": {"name": "calculator", "toolUseId": "c1", "reasoningSignature": "YWJj"}, + }, + }, + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use"}}, {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, @@ -500,7 +690,7 @@ async def test_stream_response_reasoning(gemini_client, model, messages, agenera exp_chunks = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "abc", "text": "test reason"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "YWJj", "text": "test reason"}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "end_turn"}}, {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, @@ -508,6 +698,72 @@ async def test_stream_response_reasoning(gemini_client, model, messages, agenera assert tru_chunks == exp_chunks +@pytest.mark.asyncio +async def test_stream_response_reasoning_and_text(gemini_client, model, messages, agenerator, alist): + """Test that both reasoning and text content are captured in separate blocks.""" + gemini_client.aio.models.generate_content_stream.return_value = agenerator( + [ + genai.types.GenerateContentResponse( + candidates=[ + genai.types.Candidate( + content=genai.types.Content( + parts=[ + genai.types.Part( + text="thinking about math", + thought=True, + thought_signature=b"sig1", + ), + ], + ), + finish_reason="STOP", + ), + ], + usage_metadata=genai.types.GenerateContentResponseUsageMetadata( + prompt_token_count=1, + total_token_count=3, + ), + ), + genai.types.GenerateContentResponse( + candidates=[ + genai.types.Candidate( + content=genai.types.Content( + parts=[ + genai.types.Part( + text="2 + 2 = 4", + thought=False, + ), + ], + ), + finish_reason="STOP", + ), + ], + usage_metadata=genai.types.GenerateContentResponseUsageMetadata( + prompt_token_count=1, + total_token_count=5, + ), + ), + ] + ) + + tru_chunks = await alist(model.stream(messages)) + exp_chunks = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"signature": "c2lnMQ==", "text": "thinking about math"}} + } + }, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "2 + 2 = 4"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 4, "totalTokens": 5}, "metrics": {"latencyMs": 0}}}, + ] + assert tru_chunks == exp_chunks + + @pytest.mark.asyncio async def test_stream_response_max_tokens(gemini_client, model, messages, agenerator, alist): gemini_client.aio.models.generate_content_stream.return_value = agenerator( @@ -558,14 +814,29 @@ async def test_stream_response_none_candidates(gemini_client, model, messages, a tru_chunks = await alist(model.stream(messages)) exp_chunks = [ {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockStop": {}}, {"messageStop": {"stopReason": "end_turn"}}, {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, ] assert tru_chunks == exp_chunks +@pytest.mark.asyncio +async def test_stream_response_empty_stream(gemini_client, model, messages, agenerator, alist): + """Test that empty stream doesn't raise UnboundLocalError. + + When the stream yields no events, the candidate variable must be initialized + to None to avoid UnboundLocalError when referenced in message_stop chunk. + """ + gemini_client.aio.models.generate_content_stream.return_value = agenerator([]) + + tru_chunks = await alist(model.stream(messages)) + exp_chunks = [ + {"messageStart": {"role": "assistant"}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + assert tru_chunks == exp_chunks + + @pytest.mark.asyncio async def test_stream_response_throttled_exception(gemini_client, model, messages): gemini_client.aio.models.generate_content_stream.side_effect = genai.errors.ClientError( @@ -624,6 +895,89 @@ async def test_structured_output(gemini_client, model, messages, model_id, weath gemini_client.aio.models.generate_content.assert_called_with(**exp_request) +def test_gemini_tools_validation_rejects_function_declarations(model_id): + tool_with_function_declarations = genai.types.Tool( + function_declarations=[ + genai.types.FunctionDeclaration( + name="test_function", + description="A test function", + ) + ] + ) + + with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"): + GeminiModel(model_id=model_id, gemini_tools=[tool_with_function_declarations]) + + +def test_gemini_tools_validation_allows_non_function_tools(model_id): + tool_with_google_search = genai.types.Tool(google_search=genai.types.GoogleSearch()) + + model = GeminiModel(model_id=model_id, gemini_tools=[tool_with_google_search]) + assert "gemini_tools" in model.config + + +def test_gemini_tools_validation_on_update_config(model): + tool_with_function_declarations = genai.types.Tool( + function_declarations=[ + genai.types.FunctionDeclaration( + name="test_function", + description="A test function", + ) + ] + ) + + with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"): + model.update_config(gemini_tools=[tool_with_function_declarations]) + + +@pytest.mark.asyncio +async def test_stream_request_with_gemini_tools(gemini_client, messages, model_id): + google_search_tool = genai.types.Tool(google_search=genai.types.GoogleSearch()) + model = GeminiModel(model_id=model_id, gemini_tools=[google_search_tool]) + + await anext(model.stream(messages)) + + exp_request = { + "config": { + "tools": [ + {"function_declarations": []}, + {"google_search": {}}, + ] + }, + "contents": [{"parts": [{"text": "test"}], "role": "user"}], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + +@pytest.mark.asyncio +async def test_stream_request_with_gemini_tools_and_function_tools(gemini_client, messages, tool_spec, model_id): + code_execution_tool = genai.types.Tool(code_execution=genai.types.ToolCodeExecution()) + model = GeminiModel(model_id=model_id, gemini_tools=[code_execution_tool]) + + await anext(model.stream(messages, tool_specs=[tool_spec])) + + exp_request = { + "config": { + "tools": [ + { + "function_declarations": [ + { + "description": tool_spec["description"], + "name": tool_spec["name"], + "parameters_json_schema": tool_spec["inputSchema"]["json"], + } + ] + }, + {"code_execution": {}}, + ] + }, + "contents": [{"parts": [{"text": "test"}], "role": "user"}], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + @pytest.mark.asyncio async def test_stream_handles_non_json_error(gemini_client, model, messages, caplog, alist): error_message = "Invalid API key" @@ -637,3 +991,262 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, cap assert "Gemini API returned non-JSON error" in caplog.text assert f"error_message=<{error_message}>" in caplog.text + + +@pytest.mark.asyncio +async def test_stream_with_injected_client(model_id, agenerator, alist): + """Test that stream works with an injected client and doesn't close it.""" + # Create a mock injected client + mock_injected_client = unittest.mock.Mock() + mock_injected_client.aio = unittest.mock.AsyncMock() + + mock_injected_client.aio.models.generate_content_stream.return_value = agenerator( + [ + genai.types.GenerateContentResponse( + candidates=[ + genai.types.Candidate( + content=genai.types.Content( + parts=[genai.types.Part(text="Hello")], + ), + finish_reason="STOP", + ), + ], + usage_metadata=genai.types.GenerateContentResponseUsageMetadata( + prompt_token_count=1, + total_token_count=3, + ), + ), + ] + ) + + # Create model with injected client + model = GeminiModel(client=mock_injected_client, model_id=model_id) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + # Verify events were generated + assert len(tru_events) > 0 + + # Verify the injected client was used + mock_injected_client.aio.models.generate_content_stream.assert_called_once() + + +@pytest.mark.asyncio +async def test_structured_output_with_injected_client(model_id, weather_output, alist): + """Test that structured_output works with an injected client and doesn't close it.""" + # Create a mock injected client + mock_injected_client = unittest.mock.Mock() + mock_injected_client.aio = unittest.mock.AsyncMock() + + mock_injected_client.aio.models.generate_content.return_value = unittest.mock.Mock( + parsed=weather_output.model_dump() + ) + + # Create model with injected client + model = GeminiModel(client=mock_injected_client, model_id=model_id) + + messages = [{"role": "user", "content": [{"text": "Generate weather"}]}] + stream = model.structured_output(type(weather_output), messages) + events = await alist(stream) + + # Verify output was generated + assert len(events) == 1 + assert events[0] == {"output": weather_output} + + # Verify the injected client was used + mock_injected_client.aio.models.generate_content.assert_called_once() + + +def test_init_with_both_client_and_client_args_raises_error(): + """Test that providing both client and client_args raises ValueError.""" + mock_client = unittest.mock.Mock() + + with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): + GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model") + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.gemini") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages, None, None, None) + + # Image with S3 source should be filtered, text should remain + formatted_content = request["contents"][0]["parts"] + assert len(formatted_content) == 1 + assert "text" in formatted_content[0] + assert "Location sources are not supported by Gemini" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.gemini") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages, None, None, None) + + # Document with S3 source should be filtered, text should remain + formatted_content = request["contents"][0]["parts"] + assert len(formatted_content) == 1 + assert "text" in formatted_content[0] + assert "Location sources are not supported by Gemini" in caplog.text + + +class TestCountTokens: + """Tests for GeminiModel.count_tokens native token counting.""" + + @pytest.fixture + def gemini_client(self): + with unittest.mock.patch.object(strands.models.gemini.genai, "Client") as mock_client_cls: + mock_client = mock_client_cls.return_value + mock_client.aio = unittest.mock.AsyncMock() + yield mock_client + + @pytest.fixture + def model(self, gemini_client): + _ = gemini_client + return GeminiModel(model_id="m1", use_native_token_count=True) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model, gemini_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.total_tokens = 42 + gemini_client.aio.models.count_tokens.return_value = mock_response + + result = await model.count_tokens(messages=messages) + + assert result == 42 + gemini_client.aio.models.count_tokens.assert_called_once() + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model, gemini_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.total_tokens = 55 + gemini_client.aio.models.count_tokens.return_value = mock_response + + result = await model.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result > 55 # native (55) + heuristic estimate for system_prompt + + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model, gemini_client, messages, tool_specs): + mock_response = unittest.mock.AsyncMock() + mock_response.total_tokens = 100 + gemini_client.aio.models.count_tokens.return_value = mock_response + + result = await model.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result > 100 # native (100) + heuristic estimate for tool_specs + + @pytest.mark.asyncio + async def test_fallback_on_none_total_tokens(self, model, gemini_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.total_tokens = None + gemini_client.aio.models.count_tokens.return_value = mock_response + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_api_error(self, model, gemini_client, messages): + gemini_client.aio.models.count_tokens.side_effect = genai.errors.ClientError("Unsupported", response_json={}) + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_generic_exception(self, model, gemini_client, messages): + gemini_client.aio.models.count_tokens.side_effect = RuntimeError("Connection failed") + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_debug(self, model, gemini_client, messages, caplog): + gemini_client.aio.models.count_tokens.side_effect = RuntimeError("API down") + + with caplog.at_level(logging.DEBUG, logger="strands.models.gemini"): + await model.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false(self, gemini_client, messages): + _ = gemini_client + model = GeminiModel(model_id="m1", use_native_token_count=False) + + result = await model.count_tokens(messages=messages) + + gemini_client.aio.models.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, gemini_client, messages): + _ = gemini_client + model = GeminiModel(model_id="m1") + + result = await model.count_tokens(messages=messages) + + gemini_client.aio.models.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 832b5c836..96cf561cd 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -285,7 +285,7 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) - mock_event_3 = unittest.mock.Mock() + mock_event_3 = unittest.mock.Mock(usage=None) mock_event_4 = unittest.mock.Mock(usage=None) litellm_acompletion.side_effect = unittest.mock.AsyncMock( @@ -408,16 +408,6 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model pass -@pytest.mark.asyncio -async def test_stream_raises_error_when_stream_is_false(model): - """Test that stream raises ValueError when stream parameter is explicitly False.""" - messages = [{"role": "user", "content": [{"text": "test"}]}] - - with pytest.raises(ValueError, match="stream parameter cannot be explicitly set to False"): - async for _ in model.stream(messages, stream=False): - pass - - def test_format_request_messages_with_system_prompt_content(): """Test format_request_messages with system_prompt_content parameter.""" messages = [{"role": "user", "content": [{"text": "Hello"}]}] @@ -478,3 +468,554 @@ def test_format_request_messages_cache_point_support(): ] assert result == expected + + +@pytest.mark.asyncio +async def test_stream_non_streaming(litellm_acompletion, api_key, model_id, alist): + """Test LiteLLM model with streaming disabled (stream=False). + + This test verifies that the LiteLLM model works correctly when streaming is disabled, + which was the issue reported in GitHub issue #477. + """ + + mock_function = unittest.mock.Mock() + mock_function.name = "calculator" + mock_function.arguments = '{"expression": "123981723 + 234982734"}' + + mock_tool_call = unittest.mock.Mock(index=0, function=mock_function, id="tool_call_id_123") + + mock_message = unittest.mock.Mock() + mock_message.content = "I'll calculate that for you" + mock_message.reasoning_content = "Let me think about this calculation" + mock_message.tool_calls = [mock_tool_call] + + mock_choice = unittest.mock.Mock() + mock_choice.message = mock_message + mock_choice.finish_reason = "tool_calls" + + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + + # Create a more explicit usage mock that doesn't have cache-related attributes + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 20 + mock_usage.total_tokens = 30 + mock_usage.prompt_tokens_details = None + mock_usage.cache_creation_input_tokens = None + mock_response.usage = mock_usage + + litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response) + + model = LiteLLMModel( + client_args={"api_key": api_key}, + model_id=model_id, + params={"stream": False}, # This is the key setting that was causing the #477 isuue + ) + + messages = [{"role": "user", "content": [{"type": "text", "text": "What is 123981723 + 234982734?"}]}] + response = model.stream(messages) + + tru_events = await alist(response) + + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Let me think about this calculation"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}}, + {"contentBlockStop": {}}, + { + "contentBlockStart": { + "start": {"toolUse": {"name": "calculator", "toolUseId": mock_message.tool_calls[0].id}} + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "123981723 + 234982734"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": { + "inputTokens": 10, + "outputTokens": 20, + "totalTokens": 30, + }, + "metrics": {"latencyMs": 0}, + } + }, + ] + + assert len(tru_events) == len(exp_events) + + for i, (tru, exp) in enumerate(zip(tru_events, exp_events, strict=False)): + assert tru == exp, f"Event {i} mismatch: {tru} != {exp}" + + expected_request = { + "api_key": api_key, + "model": model_id, + "messages": [{"role": "user", "content": [{"text": "What is 123981723 + 234982734?", "type": "text"}]}], + "stream": False, # Verify that stream=False was passed to litellm + "stream_options": {"include_usage": True}, + "tools": [], + } + litellm_acompletion.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_stream_path_validation(litellm_acompletion, api_key, model_id, model, agenerator, alist): + """Test that we're taking the correct streaming path and validate stream parameter.""" + mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None) + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(usage=None) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=agenerator([mock_event_1, mock_event_2])) + + messages = [{"role": "user", "content": []}] + response = model.stream(messages) + + # Consume the response + await alist(response) + + # Validate that litellm.acompletion was called with the expected parameters + call_args = litellm_acompletion.call_args + assert call_args is not None, "litellm.acompletion should have been called" + + # Check if stream parameter is being set + called_kwargs = call_args.kwargs + + # Validate we're going down the streaming path (should have stream=True) + assert called_kwargs.get("stream") is True, f"Expected stream=True, got {called_kwargs.get('stream')}" + + +def test_format_request_message_content_reasoning(): + """Test formatting reasoning content.""" + content = {"reasoningContent": {"reasoningText": {"signature": "test_sig", "text": "test_thinking"}}} + + result = LiteLLMModel.format_request_message_content(content) + expected = {"signature": "test_sig", "thinking": "test_thinking", "type": "thinking"} + + assert result == expected + + +def test_format_request_message_content_video(): + """Test formatting video content.""" + content = {"video": {"source": {"bytes": "base64videodata"}}} + + result = LiteLLMModel.format_request_message_content(content) + expected = {"type": "video_url", "video_url": {"detail": "auto", "url": "base64videodata"}} + + assert result == expected + + +def test_apply_proxy_prefix_with_use_litellm_proxy(): + """Test _apply_proxy_prefix when use_litellm_proxy is True.""" + model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="openai/gpt-4") + + assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4" + + +def test_apply_proxy_prefix_already_has_prefix(): + """Test _apply_proxy_prefix when model_id already has prefix.""" + model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="litellm_proxy/openai/gpt-4") + + # Should not add another prefix + assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4" + + +def test_apply_proxy_prefix_disabled(): + """Test _apply_proxy_prefix when use_litellm_proxy is False.""" + model = LiteLLMModel(client_args={"use_litellm_proxy": False}, model_id="openai/gpt-4") + + assert model.get_config()["model_id"] == "openai/gpt-4" + + +def test_format_chunk_metadata_with_cache_tokens(): + """Test format_chunk for metadata with cache tokens.""" + model = LiteLLMModel(model_id="test") + + # Mock usage data with cache tokens + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + # Mock cache-related attributes + mock_tokens_details = unittest.mock.Mock() + mock_tokens_details.cached_tokens = 25 + mock_usage.prompt_tokens_details = mock_tokens_details + mock_usage.cache_creation_input_tokens = 10 + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert result["metadata"]["usage"]["inputTokens"] == 100 + assert result["metadata"]["usage"]["outputTokens"] == 50 + assert result["metadata"]["usage"]["totalTokens"] == 150 + assert result["metadata"]["usage"]["cacheReadInputTokens"] == 25 + assert result["metadata"]["usage"]["cacheWriteInputTokens"] == 10 + + +def test_format_chunk_metadata_without_cache_tokens(): + """Test format_chunk for metadata without cache tokens.""" + model = LiteLLMModel(model_id="test") + + # Mock usage data without cache tokens + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + mock_usage.prompt_tokens_details = None + mock_usage.cache_creation_input_tokens = None + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert result["metadata"]["usage"]["inputTokens"] == 100 + assert result["metadata"]["usage"]["outputTokens"] == 50 + assert result["metadata"]["usage"]["totalTokens"] == 150 + assert "cacheReadInputTokens" not in result["metadata"]["usage"] + assert "cacheWriteInputTokens" not in result["metadata"]["usage"] + + +def test_stream_switch_content_same_type(): + """Test _stream_switch_content when data_type is the same as prev_data_type.""" + model = LiteLLMModel(model_id="test") + + chunks, data_type = model._stream_switch_content("text", "text") + + assert chunks == [] + assert data_type == "text" + + +def test_stream_switch_content_different_type_with_prev(): + """Test _stream_switch_content when switching from one type to another.""" + model = LiteLLMModel(model_id="test") + + chunks, data_type = model._stream_switch_content("text", "reasoning_content") + + assert len(chunks) == 2 + assert chunks[0]["contentBlockStop"] == {} + assert chunks[1]["contentBlockStart"] == {"start": {}} + assert data_type == "text" + + +def test_stream_switch_content_different_type_no_prev(): + """Test _stream_switch_content when switching to a type with no previous type.""" + model = LiteLLMModel(model_id="test") + + chunks, data_type = model._stream_switch_content("text", None) + + assert len(chunks) == 1 + assert chunks[0]["contentBlockStart"] == {"start": {}} + assert data_type == "text" + + +@pytest.mark.asyncio +async def test_stream_with_events_missing_usage_attribute( + litellm_acompletion, api_key, model_id, model, agenerator, alist +): + """Test streaming handles events that don't have a usage attribute. + + This test verifies the fix for a bug where ModelResponseStream objects + (which don't have a 'usage' attribute) would cause an AttributeError + when the code tried to access event.usage directly instead of using getattr. + + The bug occurred because: + 1. ModelResponse (non-streaming) has a 'usage' attribute + 2. ModelResponseStream (streaming chunks) does NOT have a 'usage' attribute + 3. The code assumed all events would have the 'usage' attribute + + Regression test for: 'ModelResponseStream' object has no attribute 'usage' + """ + + # Use spec to ensure mock objects only have specified attributes + # This mimics the real ModelResponseStream which doesn't have 'usage' + class MockStreamChunk: + """Mock that mimics ModelResponseStream - no usage attribute.""" + + def __init__(self, choices=None): + self.choices = choices or [] + + mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None) + mock_event_1 = MockStreamChunk(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = MockStreamChunk(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + # After finish_reason is received, remaining events in the stream also don't have 'usage' + mock_event_3 = MockStreamChunk(choices=[]) + mock_event_4 = MockStreamChunk(choices=[]) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + ) + + messages = [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}] + response = model.stream(messages) + + # This should NOT raise AttributeError: 'MockStreamChunk' object has no attribute 'usage' + tru_events = await alist(response) + + # Verify we got the expected events (no metadata since no usage was available) + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + assert {"messageStop": {"stopReason": "end_turn"}} in tru_events + # No metadata event since mock events don't have usage + assert not any("metadata" in event for event in tru_events) + + +@pytest.mark.asyncio +async def test_stream_with_usage_in_final_event(litellm_acompletion, api_key, model_id, model, agenerator, alist): + """Test streaming correctly extracts usage when it IS present in final events. + + This test ensures that when usage data IS available (e.g., with stream_options.include_usage=True), + it is correctly extracted and included in the metadata event. + """ + + class MockStreamChunkWithoutUsage: + """Mock streaming chunk without usage.""" + + def __init__(self, choices=None): + self.choices = choices or [] + + class MockStreamChunkWithUsage: + """Mock streaming chunk with usage (final event).""" + + def __init__(self, usage): + self.choices = [] + self.usage = usage + + mock_delta = unittest.mock.Mock(content="Hi", tool_calls=None, reasoning_content=None) + mock_event_1 = MockStreamChunkWithoutUsage(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = MockStreamChunkWithoutUsage(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + + # Final event with usage data + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 5 + mock_usage.total_tokens = 15 + mock_usage.prompt_tokens_details = None + mock_usage.cache_creation_input_tokens = None + mock_event_3 = MockStreamChunkWithUsage(usage=mock_usage) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + messages = [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}] + response = model.stream(messages) + + tru_events = await alist(response) + + # Verify metadata event is present with correct usage + metadata_events = [e for e in tru_events if "metadata" in e] + assert len(metadata_events) == 1 + assert metadata_events[0]["metadata"]["usage"]["inputTokens"] == 10 + assert metadata_events[0]["metadata"]["usage"]["outputTokens"] == 5 + assert metadata_events[0]["metadata"]["usage"]["totalTokens"] == 15 + + +def test_format_request_messages_with_tool_calls_no_content(): + """Test that assistant messages with only tool calls are included and have no content field.""" + messages = [ + {"role": "user", "content": [{"text": "Use the calculator"}]}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + }, + ] + + tru_result = LiteLLMModel.format_request_messages(messages) + + exp_result = [ + {"role": "user", "content": [{"text": "Use the calculator", "type": "text"}]}, + { + "role": "assistant", + "tool_calls": [ + { + "function": {"arguments": '{"expression": "2+2"}', "name": "calculator"}, + "id": "c1", + "type": "function", + } + ], + }, + ] + assert tru_result == exp_result + + +# --- Thought Signature Tests --- + + +def test_format_chunk_tool_start_extracts_thought_signature_from_id(): + """Test that format_chunk extracts thought_signature from LiteLLM-encoded tool call ID.""" + model = LiteLLMModel(model_id="test") + + mock_data = unittest.mock.Mock() + mock_data.id = "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "get_weather" + mock_data.provider_specific_fields = None + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + result = model.format_chunk(event) + + tool_use = result["contentBlockStart"]["start"]["toolUse"] + assert tool_use["reasoningSignature"] == "dGhpcy1pcy1hLXNpZw==" + # toolUseId keeps the full encoded string so tool result IDs match + assert tool_use["toolUseId"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + + +def test_format_chunk_tool_start_extracts_thought_signature_from_provider_specific_fields(): + """Test that format_chunk extracts thought_signature from provider_specific_fields.""" + model = LiteLLMModel(model_id="test") + + mock_data = unittest.mock.Mock() + mock_data.id = "call_abc123" # No __thought__ in ID + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "get_weather" + mock_data.function.provider_specific_fields = None + mock_data.provider_specific_fields = {"thought_signature": "cHNmLXNpZw=="} + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + result = model.format_chunk(event) + + tool_use = result["contentBlockStart"]["start"]["toolUse"] + assert tool_use["reasoningSignature"] == "cHNmLXNpZw==" + assert tool_use["toolUseId"] == "call_abc123" + + +def test_format_chunk_tool_start_no_thought_signature(): + """Test that format_chunk works normally when no thought_signature is present.""" + model = LiteLLMModel(model_id="test") + + mock_data = unittest.mock.Mock() + mock_data.id = "call_plain123" + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "get_weather" + mock_data.provider_specific_fields = None + mock_data.function.provider_specific_fields = None + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + result = model.format_chunk(event) + + tool_use = result["contentBlockStart"]["start"]["toolUse"] + assert tool_use["toolUseId"] == "call_plain123" + assert "reasoningSignature" not in tool_use + + +def test_format_request_message_tool_call_encodes_thought_signature(): + """Test that format_request_message_tool_call encodes reasoningSignature into the tool call ID.""" + tool_use = { + "toolUseId": "call_abc123", + "name": "get_weather", + "input": {"city": "Seattle"}, + "reasoningSignature": "dGhpcy1pcy1hLXNpZw==", + } + + result = LiteLLMModel.format_request_message_tool_call(tool_use) + + assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + assert result["function"]["name"] == "get_weather" + assert result["function"]["arguments"] == '{"city": "Seattle"}' + + +def test_format_request_message_tool_call_skips_encoding_when_already_present(): + """Test that format_request_message_tool_call does not double-encode the signature.""" + tool_use = { + "toolUseId": "call_abc123__thought__dGhpcy1pcy1hLXNpZw==", + "name": "get_weather", + "input": {"city": "Seattle"}, + "reasoningSignature": "dGhpcy1pcy1hLXNpZw==", + } + + result = LiteLLMModel.format_request_message_tool_call(tool_use) + + # Should NOT double-encode + assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + + +def test_format_request_message_tool_call_no_reasoning_signature(): + """Test that format_request_message_tool_call works normally without reasoningSignature.""" + tool_use = { + "toolUseId": "call_plain123", + "name": "get_weather", + "input": {"city": "Seattle"}, + } + + result = LiteLLMModel.format_request_message_tool_call(tool_use) + + assert result["id"] == "call_plain123" + assert "__thought__" not in result["id"] + + +def test_format_system_messages_preserves_cache_point_ttl(): + """CachePoint with ttl="1h" should produce cache_control with ttl field.""" + result = LiteLLMModel._format_system_messages( + system_prompt_content=[ + {"text": "You are a helpful assistant."}, + {"cachePoint": {"type": "default", "ttl": "1h"}}, + ] + ) + assert result[0]["content"][0]["cache_control"] == {"type": "ephemeral", "ttl": "1h"} + + +def test_format_system_messages_cache_point_without_ttl(): + """CachePoint without ttl should produce cache_control with no ttl key (backward compat).""" + result = LiteLLMModel._format_system_messages( + system_prompt_content=[ + {"text": "You are a helpful assistant."}, + {"cachePoint": {"type": "default"}}, + ] + ) + assert result[0]["content"][0]["cache_control"] == {"type": "ephemeral"} + assert "ttl" not in result[0]["content"][0]["cache_control"] + + +def test_format_system_messages_cache_point_with_no_preceding_content(): + """CachePoint with no preceding text block should be silently ignored.""" + result = LiteLLMModel._format_system_messages( + system_prompt_content=[ + {"cachePoint": {"type": "default", "ttl": "1h"}}, + ] + ) + assert result == [] + + +def test_thought_signature_round_trip(): + """Test that thought signature is preserved through a full response -> internal -> request cycle.""" + model = LiteLLMModel(model_id="test") + signature = "R2VtaW5pVGhvdWdodFNpZw==" + tool_call_id = f"call_xyz789__thought__{signature}" + + # 1. Response path: format_chunk extracts the signature + mock_data = unittest.mock.Mock() + mock_data.id = tool_call_id + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "current_time" + mock_data.provider_specific_fields = None + mock_data.function.provider_specific_fields = None + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + chunk = model.format_chunk(event) + tool_use_data = chunk["contentBlockStart"]["start"]["toolUse"] + assert tool_use_data["reasoningSignature"] == signature + + # 2. Simulate internal storage (streaming layer stores reasoningSignature) + internal_tool_use = { + "toolUseId": tool_use_data["toolUseId"], + "name": tool_use_data["name"], + "input": {"timezone": "UTC"}, + "reasoningSignature": tool_use_data["reasoningSignature"], + } + + # 3. Request path: format_request_message_tool_call re-encodes the signature + tool_call = LiteLLMModel.format_request_message_tool_call(internal_tool_use) + assert "__thought__" in tool_call["id"] + assert signature in tool_call["id"] diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index a6bbf5673..2bf12d055 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +import logging import unittest.mock import pytest @@ -414,3 +415,69 @@ async def test_tool_choice_none_no_warning(model, messages, captured_warnings, a await alist(response) assert len(captured_warnings) == 0 + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.llamaapi") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by LlamaAPI" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.llamaapi") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by LlamaAPI" in caplog.text diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index e5b2614c0..6868e490b 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -2,7 +2,8 @@ import base64 import json -from unittest.mock import AsyncMock, patch +import logging +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -247,7 +248,7 @@ async def mock_aiter_lines(): mock_response = AsyncMock() mock_response.aiter_lines = mock_aiter_lines - mock_response.raise_for_status = AsyncMock() + mock_response.raise_for_status = MagicMock() with patch.object(model.client, "post", return_value=mock_response): messages = [{"role": "user", "content": [{"text": "Hi"}]}] @@ -637,3 +638,190 @@ def test_format_messages_with_mixed_content() -> None: assert result[0]["content"][2]["type"] == "image_url" assert "image_url" in result[0]["content"][2] assert result[0]["content"][2]["image_url"]["url"].startswith("data:image/jpeg;base64,") + + +def test_format_request_filters_s3_source_image(caplog) -> None: + """Test that images with Location sources are filtered out with warning.""" + model = LlamaCppModel() + caplog.set_level(logging.WARNING, logger="strands.models.llamacpp") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by llama.cpp" in caplog.text + + +def test_format_request_filters_location_source_document(caplog) -> None: + """Test that documents with Location sources are filtered out with warning.""" + model = LlamaCppModel() + caplog.set_level(logging.WARNING, logger="strands.models.llamacpp") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by llama.cpp" in caplog.text + + +class TestCountTokens: + """Tests for LlamaCppModel.count_tokens native token counting.""" + + @pytest.fixture + def model(self): + return LlamaCppModel(base_url="http://localhost:8080", use_native_token_count=True) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model, messages): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"tokens": [1, 2, 3, 4, 5]} + mock_response.raise_for_status = MagicMock() + model.client.post = AsyncMock(return_value=mock_response) + + result = await model.count_tokens(messages=messages) + + assert result == 5 + model.client.post.assert_called_once() + call_args = model.client.post.call_args + assert call_args[0][0] == "/tokenize" + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model, messages): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"tokens": list(range(10))} + mock_response.raise_for_status = MagicMock() + model.client.post = AsyncMock(return_value=mock_response) + + result = await model.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result == 10 + call_kwargs = model.client.post.call_args[1] + payload = call_kwargs["json"] + assert payload["messages"][0]["role"] == "system" + assert payload["messages"][0]["content"] == "Be helpful." + + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model, messages, tool_specs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"tokens": list(range(20))} + mock_response.raise_for_status = MagicMock() + model.client.post = AsyncMock(return_value=mock_response) + + result = await model.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result == 20 + call_kwargs = model.client.post.call_args[1] + payload = call_kwargs["json"] + assert "tools" in payload + + @pytest.mark.asyncio + async def test_fallback_on_http_error(self, model, messages): + model.client.post = AsyncMock( + side_effect=httpx.HTTPStatusError("Server error", request=MagicMock(), response=MagicMock(status_code=500)) + ) + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_connection_error(self, model, messages): + model.client.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused")) + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_debug(self, model, messages, caplog): + model.client.post = AsyncMock(side_effect=RuntimeError("Server down")) + + with caplog.at_level(logging.DEBUG, logger="strands.models.llamacpp"): + await model.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false(self, messages): + model = LlamaCppModel(base_url="http://localhost:8080", use_native_token_count=False) + model.client.post = AsyncMock() + + result = await model.count_tokens(messages=messages) + + model.client.post.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, messages): + model = LlamaCppModel(base_url="http://localhost:8080") + model.client.post = AsyncMock() + + result = await model.count_tokens(messages=messages) + + model.client.post.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 7808336f2..dd2728785 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -1,3 +1,4 @@ +import logging import unittest.mock import pydantic @@ -79,6 +80,30 @@ def test__init__model_configs(mistral_client, model_id, max_tokens): assert actual_temperature == exp_temperature +def test__init__auto_populates_context_window_limit(mistral_client): + _ = mistral_client + + model = MistralModel(model_id="mistral-large-latest", max_tokens=1) + + assert model.get_config().get("context_window_limit") == 262_144 + + +def test__init__explicit_context_window_limit_not_overridden(mistral_client): + _ = mistral_client + + model = MistralModel(model_id="mistral-large-latest", max_tokens=1, context_window_limit=100_000) + + assert model.get_config().get("context_window_limit") == 100_000 + + +def test__init__unknown_model_no_context_window_limit(mistral_client): + _ = mistral_client + + model = MistralModel(model_id="unknown-model", max_tokens=1) + + assert model.get_config().get("context_window_limit") is None + + def test_update_config(model, model_id): model.update_config(model_id=model_id) @@ -450,9 +475,9 @@ async def test_stream(mistral_client, model, agenerator, alist, captured_warning delta=unittest.mock.Mock(content="test stream", tool_calls=None), finish_reason="end_turn", ) - ] + ], + usage=mock_usage, ), - usage=mock_usage, ) mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) @@ -475,6 +500,30 @@ async def test_stream(mistral_client, model, agenerator, alist, captured_warning assert len(captured_warnings) == 0 +@pytest.mark.asyncio +async def test_stream_no_usage(mistral_client, model, agenerator, alist): + mock_event = unittest.mock.Mock( + data=unittest.mock.Mock( + choices=[ + unittest.mock.Mock( + delta=unittest.mock.Mock(content="test stream", tool_calls=None), + finish_reason="end_turn", + ) + ], + usage=None, + ), + ) + + mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages, None, None) + + # Should complete without error and not yield a metadata chunk + chunks = await alist(response) + assert not any("metadata" in c for c in chunks if isinstance(c, dict)) + + @pytest.mark.asyncio async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator, alist, captured_warnings): tool_choice = {"auto": {}} @@ -491,9 +540,9 @@ async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator delta=unittest.mock.Mock(content="test stream", tool_calls=None), finish_reason="end_turn", ) - ] + ], + usage=mock_usage, ), - usage=mock_usage, ) mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) @@ -592,3 +641,65 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.mistral") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + formatted_messages = model._format_request_messages(messages) + + # Image with S3 source should be filtered, text should remain + user_content = formatted_messages[0]["content"] + assert user_content == "look at this image" + assert "Location sources are not supported by Mistral" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.mistral") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + formatted_messages = model._format_request_messages(messages) + + # Document with S3 source should be filtered, text should remain + user_content = formatted_messages[0]["content"] + assert user_content == "analyze this document" + assert "Location sources are not supported by Mistral" in caplog.text diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index b8249f504..34f4ef328 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -1,7 +1,11 @@ +from unittest.mock import MagicMock + import pytest from pydantic import BaseModel +from strands.hooks.events import AfterInvocationEvent from strands.models import Model as SAModel +from strands.models.model import _ModelPlugin class Person(BaseModel): @@ -67,6 +71,11 @@ def tool_specs(): ] +@pytest.fixture +def model_plugin(): + return _ModelPlugin() + + @pytest.fixture def system_prompt(): return "s1" @@ -173,3 +182,385 @@ async def stream(self, messages, tool_specs=None, system_prompt=None, *, tool_ch response = model.stream(messages, tool_specs, system_prompt) events = await alist(response) assert events[1]["contentBlockDelta"]["delta"]["text"] == "No tool choice" + + +def test_context_window_limit_from_dict_config(): + class DictConfigModel(SAModel): + def update_config(self, **model_config): + pass + + def get_config(self): + return {"context_window_limit": 200_000} + + async def structured_output(self, output_model, prompt=None, system_prompt=None, **kwargs): + yield {} + + async def stream(self, messages, tool_specs=None, system_prompt=None): + yield {} + + assert DictConfigModel().context_window_limit == 200_000 + + +def test_context_window_limit_none_when_not_configured(model): + assert model.context_window_limit is None + + +def test_stateful_false(model): + """Model.stateful defaults to False.""" + assert not model.stateful + + +def test_model_plugin_clears_messages_when_stateful(model_plugin): + """Messages are cleared when model is stateful.""" + agent = MagicMock() + agent.model.stateful = True + agent._model_state = {"response_id": "resp_123"} + agent.messages = [{"role": "user", "content": [{"text": "hello"}]}] + + event = AfterInvocationEvent(agent=agent, invocation_state={}) + model_plugin._on_after_invocation(event) + + assert agent.messages == [] + + +def test_model_plugin_preserves_messages_when_not_stateful(model_plugin): + """Messages are preserved when model is not stateful.""" + agent = MagicMock() + agent.model.stateful = False + agent._model_state = {} + agent.messages = [{"role": "user", "content": [{"text": "hello"}]}] + + event = AfterInvocationEvent(agent=agent, invocation_state={}) + model_plugin._on_after_invocation(event) + + assert len(agent.messages) == 1 + + +@pytest.mark.asyncio +async def test_count_tokens_empty_messages(model): + assert await model.count_tokens(messages=[]) == 0 + + +@pytest.mark.asyncio +async def test_count_tokens_system_prompt_only(model): + result = await model.count_tokens(messages=[], system_prompt="You are a helpful assistant.") + assert result == 7 # ceil(28/4) + + +@pytest.mark.asyncio +async def test_count_tokens_text_messages(model, messages): + result = await model.count_tokens(messages=messages) + assert result == 2 # ceil(5/4) + + +@pytest.mark.asyncio +async def test_count_tokens_with_tool_specs(model, messages, tool_specs): + without_tools = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs) + assert without_tools == 2 # ceil(5/4) + assert with_tools == 84 # ceil(5/4) + ceil(164/2) + + +@pytest.mark.asyncio +async def test_count_tokens_with_system_prompt(model, messages, system_prompt): + without_prompt = await model.count_tokens(messages=messages) + with_prompt = await model.count_tokens(messages=messages, system_prompt=system_prompt) + assert without_prompt == 2 # ceil(5/4) + assert with_prompt == 3 # ceil(5/4) + ceil(2/4) + + +@pytest.mark.asyncio +async def test_count_tokens_combined(model, messages, tool_specs, system_prompt): + result = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt=system_prompt) + assert result == 85 # ceil(5/4) + ceil(164/2) + ceil(2/4) + + +@pytest.mark.asyncio +async def test_count_tokens_tool_use_block(model): + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "my_tool", + "input": {"query": "test"}, + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + # name "my_tool" ceil(7/4)=2 + json.dumps(input) ceil(17/2)=9 = 11 + assert result == 11 + + +@pytest.mark.asyncio +async def test_count_tokens_tool_result_block(model): + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "tool output here"}], + "status": "success", + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 4 # ceil(16/4) + + +@pytest.mark.asyncio +async def test_count_tokens_reasoning_block(model): + messages = [ + { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": { + "text": "Let me think about this step by step.", + } + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 10 # ceil(37/4) + + +@pytest.mark.asyncio +async def test_count_tokens_skips_binary_content(model): + messages = [ + { + "role": "user", + "content": [{"image": {"format": "png", "source": {"bytes": b"fake image data"}}}], + } + ] + assert await model.count_tokens(messages=messages) == 0 + + +@pytest.mark.asyncio +async def test_count_tokens_tool_result_with_bytes_only(model): + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"image": {"format": "png", "source": {"bytes": b"image data"}}}], + "status": "success", + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 0 + + +@pytest.mark.asyncio +async def test_count_tokens_tool_result_with_text_and_bytes(model): + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [ + {"text": "Here is the screenshot"}, + {"image": {"format": "png", "source": {"bytes": b"image data"}}}, + ], + "status": "success", + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result > 0 + + +@pytest.mark.asyncio +async def test_count_tokens_guard_content_block(model): + messages = [ + { + "role": "assistant", + "content": [{"guardContent": {"text": {"text": "This content was filtered by guardrails."}}}], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 10 # ceil(40/4) + + +@pytest.mark.asyncio +async def test_count_tokens_tool_use_with_bytes(model): + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "my_tool", + "input": {"data": b"binary data"}, + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + # Should still count the tool name even though input has non-serializable bytes + assert result == 2 # ceil(7/4) name only + + +@pytest.mark.asyncio +async def test_count_tokens_non_serializable_tool_spec(model, messages): + tool_specs = [ + { + "name": "test", + "description": "a tool", + "inputSchema": {"json": {"default": b"bytes"}}, + } + ] + result = await model.count_tokens(messages=messages, tool_specs=tool_specs) + # Should still count the message tokens even though tool spec fails + assert result == 2 # ceil(5/4) only, tool spec skipped + + +@pytest.mark.asyncio +async def test_count_tokens_citations_block(model): + messages = [ + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "content": [{"text": "According to the document, the answer is 42."}], + "citations": [], + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 11 # ceil(44/4) + + +@pytest.mark.asyncio +async def test_count_tokens_system_prompt_content(model): + result = await model.count_tokens( + messages=[], + system_prompt_content=[{"text": "You are a helpful assistant."}], + ) + assert result == 7 # ceil(28/4) + + +@pytest.mark.asyncio +async def test_count_tokens_system_prompt_content_with_cache_point(model): + result = await model.count_tokens( + messages=[], + system_prompt_content=[ + {"text": "You are a helpful assistant."}, + {"cachePoint": {"type": "default"}}, + ], + ) + assert result == 7 # ceil(28/4), cachePoint adds 0 + + +@pytest.mark.asyncio +async def test_count_tokens_system_prompt_content_takes_priority(model): + content_only = await model.count_tokens( + messages=[], + system_prompt_content=[{"text": "Short."}], + ) + # When both are provided, system_prompt_content wins — system_prompt is ignored + both = await model.count_tokens( + messages=[], + system_prompt="This is a much longer system prompt that should have more tokens.", + system_prompt_content=[{"text": "Short."}], + ) + assert content_only == 2 # ceil(6/4) + assert content_only == both + + +@pytest.mark.asyncio +async def test_count_tokens_all_inputs(model): + messages = [ + {"role": "user", "content": [{"text": "hello world"}]}, + {"role": "assistant", "content": [{"text": "hi there"}]}, + ] + result = await model.count_tokens( + messages=messages, + tool_specs=[{"name": "test", "description": "a test tool", "inputSchema": {"json": {}}}], + system_prompt="Be helpful.", + system_prompt_content=[{"text": "Additional system context."}], + ) + # system_prompt_content (7) + "hello world" (3) + "hi there" (2) + tool_spec (38) = 50 + assert result == 50 + + +class TestHeuristicEstimation: + """Tests for _estimate_tokens_with_heuristic.""" + + def test_all_content_types(self): + """One call covering text, toolUse, toolResult, reasoning, guard, citations, system prompt, tool specs.""" + from strands.models.model import _estimate_tokens_with_heuristic + + messages = [ + {"role": "user", "content": [{"text": "hello world!"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"q": "test"}}}, + {"reasoningContent": {"reasoningText": {"text": "Let me think."}}}, + {"guardContent": {"text": {"text": "Filtered."}}}, + {"citationsContent": {"content": [{"text": "Citation."}]}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "tool output here"}]}}, + ], + }, + ] + result = _estimate_tokens_with_heuristic( + messages=messages, + tool_specs=[{"name": "test", "description": "a tool"}], + system_prompt="ignored", + system_prompt_content=[{"text": "Be helpful."}], + ) + assert result > 0 + + def test_non_serializable_inputs(self): + """Heuristic gracefully handles non-serializable tool input and tool specs.""" + from strands.models.model import _estimate_tokens_with_heuristic + + result = _estimate_tokens_with_heuristic( + messages=[ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"data": b"bytes"}}}, + ], + }, + ], + tool_specs=[{"name": "t", "inputSchema": {"json": {"default": b"bytes"}}}], + ) + assert result == 2 # only tool name counted: ceil(len("my_tool") / 4) + + @pytest.mark.asyncio + async def test_model_uses_heuristic(self, model): + """Model.count_tokens uses heuristic estimation.""" + result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}]) + assert result == 3 # ceil(12 / 4) diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index 14db63a24..360683d08 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -1,4 +1,6 @@ import json +import logging +import re import unittest.mock import pydantic @@ -126,7 +128,12 @@ def test_format_request_with_image(model, model_id): def test_format_request_with_tool_use(model, model_id): messages = [ - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "calculator", "input": '{"expression": "2+2"}'}}]} + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "tool-use-id-123", "name": "calculator", "input": '{"expression": "2+2"}'}} + ], + } ] tru_request = model.format_request(messages) @@ -320,9 +327,11 @@ def test_format_chunk_content_start_tool(model): event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_function} tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}} + tool_use = tru_chunk["contentBlockStart"]["start"]["toolUse"] - assert tru_chunk == exp_chunk + assert tool_use["name"] == "calculator" + assert tool_use["toolUseId"] != "calculator" + assert len(tool_use["toolUseId"]) > 0 def test_format_chunk_content_delta_text(model): @@ -393,12 +402,12 @@ def test_format_chunk_metadata(model): exp_chunk = { "metadata": { "usage": { - "inputTokens": 100, - "outputTokens": 50, + "inputTokens": 50, + "outputTokens": 100, "totalTokens": 150, }, "metrics": { - "latencyMs": 1.0, + "latencyMs": 1, }, }, } @@ -437,8 +446,8 @@ async def test_stream(ollama_client, model, agenerator, alist, captured_warnings {"messageStop": {"stopReason": "end_turn"}}, { "metadata": { - "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, - "metrics": {"latencyMs": 1.0}, + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + "metrics": {"latencyMs": 1}, } }, ] @@ -498,24 +507,27 @@ async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist): response = model.stream(messages) tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, - {"contentBlockStop": {}}, - {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use"}}, - { - "metadata": { - "usage": {"inputTokens": 15, "outputTokens": 8, "totalTokens": 23}, - "metrics": {"latencyMs": 2.0}, - } - }, - ] - assert tru_events == exp_events + # Verify the tool use event has a unique ID (not equal to the tool name) + tool_start_event = tru_events[2] + tool_use = tool_start_event["contentBlockStart"]["start"]["toolUse"] + assert tool_use["name"] == "calculator" + assert tool_use["toolUseId"] != "calculator" + + # Verify all other events + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + assert tru_events[1] == {"contentBlockStart": {"start": {}}} + assert tru_events[3] == {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}} + assert tru_events[4] == {"contentBlockStop": {}} + assert tru_events[5] == {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}} + assert tru_events[6] == {"contentBlockStop": {}} + assert tru_events[7] == {"messageStop": {"stopReason": "tool_use"}} + assert tru_events[8] == { + "metadata": { + "usage": {"inputTokens": 8, "outputTokens": 15, "totalTokens": 23}, + "metrics": {"latencyMs": 2}, + } + } expected_request = { "model": "m1", "messages": [{"role": "user", "content": "Calculate 2+2"}], @@ -559,3 +571,122 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.ollama") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_message = formatted_messages[0] + assert user_message["content"] == "look at this image" + assert "images" not in user_message or user_message.get("images") == [] + assert "Location sources are not supported by Ollama" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.ollama") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_message = formatted_messages[0] + assert user_message["content"] == "analyze this document" + assert "Location sources are not supported by Ollama" in caplog.text + + +def test_tool_use_id_is_unique_and_not_tool_name(model): + """Test that toolUseId is a unique UUID, not the tool name.""" + mock_function = unittest.mock.Mock() + mock_function.function.name = "calculator" + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_function} + + chunk1 = model.format_chunk(event) + chunk2 = model.format_chunk(event) + + tool_use1 = chunk1["contentBlockStart"]["start"]["toolUse"] + tool_use2 = chunk2["contentBlockStart"]["start"]["toolUse"] + + # toolUseId should not equal the tool name + assert tool_use1["toolUseId"] != "calculator" + assert tool_use2["toolUseId"] != "calculator" + + # toolUseId should be unique across calls + assert tool_use1["toolUseId"] != tool_use2["toolUseId"] + + # toolUseId should follow the tooluse_<24-hex> convention used by other providers + assert re.fullmatch(r"tooluse_[0-9a-f]{24}", tool_use1["toolUseId"]) + assert re.fullmatch(r"tooluse_[0-9a-f]{24}", tool_use2["toolUseId"]) + + # name should still be correct + assert tool_use1["name"] == "calculator" + assert tool_use2["name"] == "calculator" + + +def test_format_request_uses_tool_name_not_tool_use_id(model, model_id): + """Test that format_request uses the 'name' field, not 'toolUseId', for the function name.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "unique-id-abc-123", + "name": "calculator", + "input": '{"expression": "1+1"}', + } + } + ], + } + ] + + request = model.format_request(messages) + tool_call = request["messages"][0]["tool_calls"][0] + + # The function name in the request must come from "name", not "toolUseId" + assert tool_call["function"]["name"] == "calculator" + assert tool_call["function"]["name"] != "unique-id-abc-123" diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 0de0c4ebc..613acd163 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -1,3 +1,5 @@ +import logging +import os import unittest.mock import openai @@ -13,7 +15,10 @@ def openai_client(): with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls: mock_client = unittest.mock.AsyncMock() - mock_client_cls.return_value.__aenter__.return_value = mock_client + # Make the mock client work as an async context manager + mock_client.__aenter__ = unittest.mock.AsyncMock(return_value=mock_client) + mock_client.__aexit__ = unittest.mock.AsyncMock(return_value=None) + mock_client_cls.return_value = mock_client yield mock_client @@ -85,6 +90,39 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id +def test__init__context_window_limit(openai_client): + _ = openai_client + + model = OpenAIModel(model_id="gpt-4o", context_window_limit=128_000) + + assert model.get_config().get("context_window_limit") == 128_000 + assert model.context_window_limit == 128_000 + + +def test__init__auto_populates_context_window_limit(openai_client): + _ = openai_client + + model = OpenAIModel(model_id="gpt-4o") + + assert model.get_config().get("context_window_limit") == 128_000 + + +def test__init__explicit_context_window_limit_not_overridden(openai_client): + _ = openai_client + + model = OpenAIModel(model_id="gpt-4o", context_window_limit=50_000) + + assert model.get_config().get("context_window_limit") == 50_000 + + +def test__init__unknown_model_no_context_window_limit(openai_client): + _ = openai_client + + model = OpenAIModel(model_id="unknown-model") + + assert model.get_config().get("context_window_limit") is None + + @pytest.mark.parametrize( "content, exp_result", [ @@ -169,13 +207,337 @@ def test_format_request_tool_message(): tru_result = OpenAIModel.format_request_tool_message(tool_result) exp_result = { - "content": [{"text": "4", "type": "text"}, {"text": '["4"]', "type": "text"}], + "content": '4\n["4"]', + "role": "tool", + "tool_call_id": "c1", + } + assert tru_result == exp_result + + +def test_format_request_tool_message_single_text_returns_string(): + """Test that single text content is returned as string for model compatibility.""" + tool_result = { + "content": [{"text": '{"result": "success"}'}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + exp_result = { + "content": '{"result": "success"}', + "role": "tool", + "tool_call_id": "c1", + } + assert tru_result == exp_result + + +def test_format_request_tool_message_multi_text_returns_joined_string(): + """Test that multi-content text results are joined into a single string. + + Regression test for https://github.com/strands-agents/sdk-python/issues/1696. + OpenAI-compatible endpoints (e.g., Kimi K2.5, vLLM, Ollama) only correctly + parse string content for tool messages; array format causes hallucinated results. + """ + tool_result = { + "content": [ + {"text": "Temperature: 72°F"}, + {"json": {"humidity": 45, "unit": "%"}}, + {"text": "Wind: 5 mph"}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + exp_result = { + "content": 'Temperature: 72°F\n{"humidity": 45, "unit": "%"}\nWind: 5 mph', "role": "tool", "tool_call_id": "c1", } assert tru_result == exp_result +def test_format_request_tool_message_mixed_text_image_preserves_order(): + """Test that text and image content blocks preserve their original order.""" + tool_result = { + "content": [ + {"text": "Before image"}, + {"image": {"format": "png", "source": {"bytes": b"PNG"}}}, + {"text": "After image"}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + content = tru_result["content"] + # Array format since images are present + assert isinstance(content, list) + assert len(content) == 3 + # Order preserved: text, image, text + assert content[0] == {"type": "text", "text": "Before image"} + assert content[1]["type"] == "image_url" + assert content[2] == {"type": "text", "text": "After image"} + + +def test_format_request_tool_message_merges_adjacent_text(): + """Test that adjacent text blocks are merged while non-text order is preserved.""" + tool_result = { + "content": [ + {"text": "Line 1"}, + {"text": "Line 2"}, + {"image": {"format": "png", "source": {"bytes": b"PNG"}}}, + {"text": "Line 3"}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + content = tru_result["content"] + assert isinstance(content, list) + assert len(content) == 3 + # Adjacent text merged, image order preserved + assert content[0] == {"type": "text", "text": "Line 1\nLine 2"} + assert content[1]["type"] == "image_url" + assert content[2] == {"type": "text", "text": "Line 3"} + + +def test_format_request_tool_message_image_only(): + """Test tool message with only non-text content.""" + tool_result = { + "content": [ + {"image": {"format": "png", "source": {"bytes": b"PNG"}}}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + content = tru_result["content"] + assert isinstance(content, list) + assert len(content) == 1 + assert content[0]["type"] == "image_url" + + +def test_format_request_tool_message_document_mixed(): + """Test tool message with document content mixed with text.""" + tool_result = { + "content": [ + {"text": "Summary"}, + {"document": {"format": "pdf", "name": "report.pdf", "source": {"bytes": b"PDF"}}}, + {"text": "Footer"}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + content = tru_result["content"] + assert isinstance(content, list) + assert len(content) == 3 + assert content[0] == {"type": "text", "text": "Summary"} + assert content[1]["type"] == "file" + assert content[2] == {"type": "text", "text": "Footer"} + + +def test_format_request_tool_message_empty_content(): + """Test tool message with empty content list returns empty string.""" + tool_result = { + "content": [], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + assert tru_result["content"] == "" + assert tru_result["role"] == "tool" + assert tru_result["tool_call_id"] == "c1" + + +def test_split_tool_message_images_with_image(): + """Test that images are extracted from tool messages.""" + tool_message = { + "role": "tool", + "tool_call_id": "c1", + "content": [ + {"type": "text", "text": "Result"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,iVBORw0KGgo=", "detail": "auto", "format": "image/png"}, + }, + ], + } + + tool_clean, user_with_image = OpenAIModel._split_tool_message_images(tool_message) + + # Tool message should now have the original text plus the appended informational text + assert tool_clean["role"] == "tool" + assert tool_clean["tool_call_id"] == "c1" + assert len(tool_clean["content"]) == 2 + assert tool_clean["content"][0]["type"] == "text" + assert tool_clean["content"][0]["text"] == "Result" + assert "Tool successfully returned an image" in tool_clean["content"][1]["text"] + + # User message should have the image + assert user_with_image is not None + assert user_with_image["role"] == "user" + assert len(user_with_image["content"]) == 1 + assert user_with_image["content"][0]["type"] == "image_url" + + +def test_split_tool_message_images_without_image(): + """Test that tool messages without images are unchanged.""" + tool_message = {"role": "tool", "tool_call_id": "c1", "content": [{"type": "text", "text": "Result"}]} + + tool_clean, user_with_image = OpenAIModel._split_tool_message_images(tool_message) + + assert tool_clean == tool_message + assert user_with_image is None + + +def test_split_tool_message_images_only_image(): + """Test tool message with only image content.""" + tool_message = { + "role": "tool", + "tool_call_id": "c1", + "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}}], + } + + tool_clean, user_with_image = OpenAIModel._split_tool_message_images(tool_message) + + # Tool message should have default text + assert tool_clean["role"] == "tool" + assert len(tool_clean["content"]) == 1 + assert "successfully" in tool_clean["content"][0]["text"].lower() + + # User message should have the image + assert user_with_image is not None + assert user_with_image["role"] == "user" + assert len(user_with_image["content"]) == 1 + + +def test_split_tool_message_images_non_tool_role(): + """Test that messages with roles other than 'tool' are ignored.""" + user_msg = {"role": "user", "content": [{"type": "text", "text": "hello"}]} + clean, extra = OpenAIModel._split_tool_message_images(user_msg) + assert clean == user_msg + assert extra is None + + +def test_split_tool_message_images_invalid_content_type(): + """Test that messages with non-list content are ignored.""" + invalid_msg = {"role": "tool", "content": "not a list"} + clean, extra = OpenAIModel._split_tool_message_images(invalid_msg) + assert clean == invalid_msg + assert extra is None + + +def test_format_request_messages_with_tool_result_containing_image(): + """Test that tool results with images are properly split into tool and user messages.""" + messages = [ + { + "content": [{"text": "Run the tool"}], + "role": "user", + }, + { + "content": [ + { + "toolUse": { + "input": {}, + "name": "image_tool", + "toolUseId": "t1", + }, + }, + ], + "role": "assistant", + }, + { + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [ + {"text": "Image generated"}, + { + "image": { + "format": "png", + "source": {"bytes": b"fake_image_data"}, + } + }, + ], + } + } + ], + "role": "user", + }, + ] + + formatted = OpenAIModel.format_request_messages(messages) + + # Find the tool message + tool_messages = [msg for msg in formatted if msg.get("role") == "tool"] + assert len(tool_messages) == 1 + + # Tool message should only have text content + tool_msg = tool_messages[0] + assert all(c.get("type") != "image_url" for c in tool_msg["content"]) + + # There should be a user message right after the tool message with the image + tool_msg_idx = formatted.index(tool_msg) + assert tool_msg_idx + 1 < len(formatted) + user_msg = formatted[tool_msg_idx + 1] + assert user_msg["role"] == "user" + assert any(c.get("type") == "image_url" for c in user_msg["content"]) + + +def test_format_request_messages_with_multiple_images_in_tool_result(): + """Test tool result with multiple images.""" + messages = [ + { + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [ + {"text": "Two images generated"}, + { + "image": { + "format": "png", + "source": {"bytes": b"image1"}, + } + }, + { + "image": { + "format": "jpg", + "source": {"bytes": b"image2"}, + } + }, + ], + } + } + ], + "role": "user", + }, + ] + + formatted = OpenAIModel.format_request_messages(messages) + + # Find user message with images + user_image_msgs = [ + msg + for msg in formatted + if msg.get("role") == "user" and any(c.get("type") == "image_url" for c in msg.get("content", [])) + ] + assert len(user_image_msgs) == 1 + + # Should have both images + image_contents = [c for c in user_image_msgs[0]["content"] if c.get("type") == "image_url"] + assert len(image_contents) == 2 + + def test_format_request_tool_choice_auto(): tool_choice = {"auto": {}} @@ -254,7 +616,7 @@ def test_format_request_messages(system_prompt): ], }, { - "content": [{"text": "4", "type": "text"}], + "content": "4", "role": "tool", "tool_call_id": "c1", }, @@ -490,7 +852,12 @@ def test_format_request_with_tool_choice_tool(model, messages, tool_specs, syste ( { "chunk_type": "metadata", - "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), + "data": unittest.mock.Mock( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + prompt_tokens_details=None, + ), }, { "metadata": { @@ -519,6 +886,45 @@ def test_format_chunk_unknown_type(model): model.format_chunk(event) +def test_format_chunk_metadata_with_cache_tokens(model): + """Test format_chunk for metadata with cache tokens present.""" + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + mock_tokens_details = unittest.mock.Mock() + mock_tokens_details.cached_tokens = 25 + mock_usage.prompt_tokens_details = mock_tokens_details + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert result["metadata"]["usage"]["inputTokens"] == 100 + assert result["metadata"]["usage"]["outputTokens"] == 50 + assert result["metadata"]["usage"]["totalTokens"] == 150 + assert result["metadata"]["usage"]["cacheReadInputTokens"] == 25 + + +def test_format_chunk_metadata_with_zero_cached_tokens(model): + """Test format_chunk for metadata when cached_tokens is 0.""" + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + mock_tokens_details = unittest.mock.Mock() + mock_tokens_details.cached_tokens = 0 + mock_usage.prompt_tokens_details = mock_tokens_details + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert "cacheReadInputTokens" not in result["metadata"]["usage"] + + @pytest.mark.asyncio async def test_stream(openai_client, model_id, model, agenerator, alist): mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) @@ -848,6 +1254,92 @@ async def test_stream_context_overflow_exception(openai_client, model, messages) assert exc_info.value.__cause__ == mock_error +@pytest.mark.asyncio +@pytest.mark.parametrize( + "error_message", + [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", + ], +) +async def test_stream_alternative_context_overflow_messages(openai_client, model, messages, error_message): + """Test that alternative context overflow messages in APIError are properly converted.""" + # Create a mock OpenAI APIError with alternative context overflow message + mock_error = openai.APIError( + message=error_message, + request=unittest.mock.MagicMock(), + body={"error": {"message": error_message}}, + ) + + # Configure the mock client to raise the APIError + openai_client.chat.completions.create.side_effect = mock_error + + # Test that the stream method converts the error properly + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + # Verify the exception message contains the original error + assert error_message in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "error_message", + [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", + ], +) +async def test_structured_output_alternative_context_overflow_messages( + openai_client, model, messages, test_output_model_cls, error_message +): + """Test that alternative context overflow messages in APIError are properly converted in structured output.""" + # Create a mock OpenAI APIError with alternative context overflow message + mock_error = openai.APIError( + message=error_message, + request=unittest.mock.MagicMock(), + body={"error": {"message": error_message}}, + ) + + # Configure the mock client to raise the APIError + openai_client.beta.chat.completions.parse.side_effect = mock_error + + # Test that the structured_output method converts the error properly + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + # Verify the exception message contains the original error + assert error_message in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_stream_api_error_passthrough(openai_client, model, messages): + """Test that APIError without overflow messages passes through unchanged.""" + # Create a mock OpenAI APIError without overflow message + mock_error = openai.APIError( + message="Some other API error", + request=unittest.mock.MagicMock(), + body={"error": {"message": "Some other API error"}}, + ) + + # Configure the mock client to raise the APIError + openai_client.chat.completions.create.side_effect = mock_error + + # Test that APIError without overflow messages passes through + with pytest.raises(openai.APIError) as exc_info: + async for _ in model.stream(messages): + pass + + # Verify the original exception is raised, not ContextWindowOverflowException + assert exc_info.value == mock_error + + @pytest.mark.asyncio async def test_stream_other_bad_request_errors_passthrough(openai_client, model, messages): """Test that other BadRequestError exceptions are not converted to ContextWindowOverflowException.""" @@ -986,3 +1478,416 @@ def test_format_request_messages_drops_cache_points(): ] assert result == expected + + +@pytest.mark.asyncio +async def test_stream_with_injected_client(model_id, agenerator, alist): + """Test that stream works with an injected client and doesn't close it.""" + # Create a mock injected client + mock_injected_client = unittest.mock.AsyncMock() + mock_injected_client.close = unittest.mock.AsyncMock() + + mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None) + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock() + + mock_injected_client.chat.completions.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + # Create model with injected client + model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1}) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + # Verify events were generated + assert len(tru_events) > 0 + + # Verify the injected client was used + mock_injected_client.chat.completions.create.assert_called_once() + + # Verify the injected client was NOT closed + mock_injected_client.close.assert_not_called() + + +@pytest.mark.asyncio +async def test_structured_output_with_injected_client(model_id, test_output_model_cls, alist): + """Test that structured_output works with an injected client and doesn't close it.""" + # Create a mock injected client + mock_injected_client = unittest.mock.AsyncMock() + mock_injected_client.close = unittest.mock.AsyncMock() + + mock_parsed_instance = test_output_model_cls(name="John", age=30) + mock_choice = unittest.mock.Mock() + mock_choice.message.parsed = mock_parsed_instance + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + + mock_injected_client.beta.chat.completions.parse = unittest.mock.AsyncMock(return_value=mock_response) + + # Create model with injected client + model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1}) + + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + # Verify output was generated + assert len(events) == 1 + assert events[0] == {"output": test_output_model_cls(name="John", age=30)} + + # Verify the injected client was used + mock_injected_client.beta.chat.completions.parse.assert_called_once() + + # Verify the injected client was NOT closed + mock_injected_client.close.assert_not_called() + + +def test_init_with_both_client_and_client_args_raises_error(): + """Test that providing both client and client_args raises ValueError.""" + mock_client = unittest.mock.AsyncMock() + + with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): + OpenAIModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model") + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.openai") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_content = request["messages"][0]["content"] + assert len(formatted_content) == 1 + assert formatted_content[0]["type"] == "text" + assert "Location sources are not supported by OpenAI" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.openai") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_content = request["messages"][0]["content"] + assert len(formatted_content) == 1 + assert formatted_content[0]["type"] == "text" + assert "Location sources are not supported by OpenAI" in caplog.text + + +def test_format_request_messages_with_tool_calls_no_content(): + """Test that assistant messages with only tool calls are included and have no content field.""" + messages = [ + {"role": "user", "content": [{"text": "Use the calculator"}]}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + }, + ] + + tru_result = OpenAIModel.format_request_messages(messages) + + exp_result = [ + {"role": "user", "content": [{"text": "Use the calculator", "type": "text"}]}, + { + "role": "assistant", + "tool_calls": [ + { + "function": {"arguments": '{"expression": "2+2"}', "name": "calculator"}, + "id": "c1", + "type": "function", + } + ], + }, + ] + assert tru_result == exp_result + + +def test_format_request_messages_multiple_tool_calls_with_images(): + """Test that multiple tool calls with image results are formatted correctly. + + OpenAI requires all tool response messages to immediately follow the assistant + message with tool_calls, before any other messages. When tools return images, + the images are moved to user messages, but these must come after ALL tool messages. + """ + messages = [ + {"role": "user", "content": [{"text": "Run the tools"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"input": {}, "name": "tool1", "toolUseId": "call_1"}}, + {"toolUse": {"input": {}, "name": "tool2", "toolUseId": "call_2"}}, + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "call_1", + "content": [{"image": {"format": "png", "source": {"bytes": b"img1"}}}], + "status": "success", + } + }, + { + "toolResult": { + "toolUseId": "call_2", + "content": [{"image": {"format": "png", "source": {"bytes": b"img2"}}}], + "status": "success", + } + }, + ], + }, + ] + + tru_result = OpenAIModel.format_request_messages(messages) + + image_placeholder = ( + "Tool successfully returned an image. The image is being provided in the following user message." + ) + exp_result = [ + {"role": "user", "content": [{"text": "Run the tools", "type": "text"}]}, + { + "role": "assistant", + "tool_calls": [ + {"function": {"arguments": "{}", "name": "tool1"}, "id": "call_1", "type": "function"}, + {"function": {"arguments": "{}", "name": "tool2"}, "id": "call_2", "type": "function"}, + ], + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": [{"type": "text", "text": image_placeholder}], + }, + { + "role": "tool", + "tool_call_id": "call_2", + "content": [{"type": "text", "text": image_placeholder}], + }, + { + "role": "user", + "content": [ + { + "image_url": {"detail": "auto", "format": "image/png", "url": "data:image/png;base64,aW1nMQ=="}, + "type": "image_url", + } + ], + }, + { + "role": "user", + "content": [ + { + "image_url": {"detail": "auto", "format": "image/png", "url": "data:image/png;base64,aW1nMg=="}, + "type": "image_url", + } + ], + }, + ] + assert tru_result == exp_result + + +# ============================================================================= +# Bedrock Mantle (bedrock_mantle_config) integration with OpenAIModel +# ============================================================================= + + +class TestOpenAIModelBedrockMantleConfig: + @pytest.fixture + def mock_provide_token(self): + with unittest.mock.patch("aws_bedrock_token_generator.provide_token") as mock: + mock.return_value = "bedrock-api-key-deadbeef&Version=1" + yield mock + + def test_bedrock_mantle_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token): + """bedrock_mantle_config produces the Mantle base_url and a minted bearer token as api_key.""" + _ = openai_client + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + + # Token is minted lazily per request, so inspect the resolved kwargs. + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + # Optional kwargs aren't forwarded so provide_token's own defaults apply. + mock_provide_token.assert_called_once_with(region="us-east-1") + + def test_bedrock_mantle_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token): + """Optional credentials_provider and expiry are forwarded to provide_token.""" + _ = openai_client + from datetime import timedelta + + provider = unittest.mock.Mock() + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={ + "region": "us-west-2", + "credentials_provider": provider, + "expiry": timedelta(minutes=15), + }, + ) + model._resolve_client_args() + mock_provide_token.assert_called_once_with( + region="us-west-2", + aws_credentials_provider=provider, + expiry=timedelta(minutes=15), + ) + + def test_bedrock_mantle_config_mints_token_per_request(self, openai_client, mock_provide_token): + """Each call to _resolve_client_args mints a fresh token (long-lived processes).""" + _ = openai_client + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + model._resolve_client_args() + model._resolve_client_args() + model._resolve_client_args() + assert mock_provide_token.call_count == 3 + + def test_bedrock_mantle_config_conflicts_with_custom_client(self, openai_client): + """Cannot pass both bedrock_mantle_config and a pre-built client.""" + _ = openai_client + custom_client = unittest.mock.Mock() + with pytest.raises(ValueError, match="bedrock_mantle_config"): + OpenAIModel( + model_id="openai.gpt-oss-120b", + client=custom_client, + bedrock_mantle_config={"region": "us-east-1"}, + ) + + def test_bedrock_mantle_config_merges_with_client_args(self, openai_client, mock_provide_token): + """bedrock_mantle_config composes with client_args; transport options are preserved.""" + _ = openai_client + sentinel_http_client = unittest.mock.Mock() + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + client_args={ + "timeout": 42, + "http_client": sentinel_http_client, + "default_headers": {"X-Trace-Id": "abc"}, + }, + bedrock_mantle_config={"region": "us-east-1"}, + ) + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + assert resolved["timeout"] == 42 + assert resolved["http_client"] is sentinel_http_client + assert resolved["default_headers"] == {"X-Trace-Id": "abc"} + + def test_bedrock_mantle_config_rejects_base_url_in_client_args(self, openai_client): + """client_args must not contain base_url or api_key when bedrock_mantle_config is set.""" + _ = openai_client + with pytest.raises(ValueError, match="client_args must not contain"): + OpenAIModel( + model_id="openai.gpt-oss-120b", + client_args={"base_url": "https://custom.example.com"}, + bedrock_mantle_config={"region": "us-east-1"}, + ) + + def test_bedrock_mantle_config_requires_region(self, openai_client): + """bedrock_mantle_config raises when no region can be resolved from config, session, or env.""" + _ = openai_client + with ( + unittest.mock.patch("boto3.Session") as mock_session_cls, + unittest.mock.patch.dict(os.environ, {}, clear=True), + ): + mock_session_cls.return_value.region_name = None + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + with pytest.raises(ValueError, match="Could not resolve an AWS region"): + model._resolve_client_args() + + def test_bedrock_mantle_config_region_resolved_from_boto3_default(self, openai_client, mock_provide_token): + """When region is omitted, the default boto3 session chain resolves it.""" + _ = openai_client + with unittest.mock.patch("boto3.Session") as mock_session_cls: + mock_session_cls.return_value.region_name = "eu-west-1" + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.eu-west-1.api.aws/v1" + mock_provide_token.assert_called_once_with(region="eu-west-1") + + def test_bedrock_mantle_config_region_resolved_from_boto_session(self, openai_client, mock_provide_token): + """An explicit ``boto_session`` supplies the region when ``region`` is omitted.""" + _ = openai_client + session = unittest.mock.Mock() + session.region_name = "ap-southeast-2" + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={"boto_session": session}, + ) + + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.ap-southeast-2.api.aws/v1" + mock_provide_token.assert_called_once_with(region="ap-southeast-2") + + def test_bedrock_mantle_config_explicit_region_wins_over_boto_session(self, openai_client, mock_provide_token): + """``region`` takes precedence over a session's region.""" + _ = openai_client + session = unittest.mock.Mock() + session.region_name = "ap-southeast-2" + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={"region": "us-east-1", "boto_session": session}, + ) + + model._resolve_client_args() + + mock_provide_token.assert_called_once_with(region="us-east-1") + + def test_bedrock_mantle_config_wraps_token_failures_with_context(self, openai_client, mock_provide_token): + """provide_token failures are wrapped in a RuntimeError with actionable context.""" + _ = openai_client + mock_provide_token.side_effect = RuntimeError("no credentials in chain") + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + with pytest.raises(RuntimeError, match="Bedrock Mantle bearer token.*us-east-1"): + model._resolve_client_args() diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py new file mode 100644 index 000000000..697508339 --- /dev/null +++ b/tests/strands/models/test_openai_responses.py @@ -0,0 +1,1463 @@ +import os +import unittest.mock + +import openai +import pydantic +import pytest + +import strands +from strands.models.openai_responses import _MAX_MEDIA_SIZE_BYTES, OpenAIResponsesModel +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException + + +@pytest.fixture +def openai_client(): + with unittest.mock.patch.object(strands.models.openai_responses.openai, "AsyncOpenAI") as mock_client_cls: + mock_client = unittest.mock.AsyncMock() + mock_client_cls.return_value.__aenter__.return_value = mock_client + yield mock_client + + +@pytest.fixture +def model_id(): + return "gpt-4o" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + return OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + +def test__init__(model_id): + model = OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) + + tru_config = model.get_config() + exp_config = {"model_id": "gpt-4o", "params": {"max_output_tokens": 100}, "context_window_limit": 128_000} + + assert tru_config == exp_config + + +def test__init__auto_populates_context_window_limit(): + model = OpenAIResponsesModel(model_id="gpt-4o") + + assert model.get_config().get("context_window_limit") == 128_000 + + +def test__init__explicit_context_window_limit_not_overridden(): + model = OpenAIResponsesModel(model_id="gpt-4o", context_window_limit=50_000) + + assert model.get_config().get("context_window_limit") == 50_000 + + +def test__init__unknown_model_no_context_window_limit(): + model = OpenAIResponsesModel(model_id="unknown-model") + + assert model.get_config().get("context_window_limit") is None + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +@pytest.mark.parametrize( + "content, exp_result", + [ + # Document + ( + { + "document": { + "format": "pdf", + "name": "test doc", + "source": {"bytes": b"document"}, + }, + }, + { + "type": "input_file", + "file_url": "data:application/pdf;base64,ZG9jdW1lbnQ=", + }, + ), + # Image + ( + { + "image": { + "format": "jpg", + "source": {"bytes": b"image"}, + }, + }, + { + "type": "input_image", + "image_url": "data:image/jpeg;base64,aW1hZ2U=", + }, + ), + # Text + ( + {"text": "hello"}, + {"type": "input_text", "text": "hello"}, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = OpenAIResponsesModel._format_request_message_content(content) + assert tru_result == exp_result + + +def test_format_request_message_content_unsupported_type(): + content = {"unsupported": {}} + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_message_tool_call(): + tool_use = { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + } + + tru_result = OpenAIResponsesModel._format_request_message_tool_call(tool_use) + exp_result = { + "type": "function_call", + "call_id": "c1", + "name": "calculator", + "arguments": '{"expression": "2+2"}', + } + assert tru_result == exp_result + + +def test_format_request_tool_message(): + tool_result = { + "content": [{"text": "4"}, {"json": ["4"]}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + exp_result = { + "type": "function_call_output", + "call_id": "c1", + "output": '4\n["4"]', + } + assert tru_result == exp_result + + +def test_format_request_tool_message_with_image(): + """Test that tool results with images return an array output.""" + tool_result = { + "content": [ + {"text": "Here is the image:"}, + {"image": {"format": "png", "source": {"bytes": b"fake_image_data"}}}, + ], + "status": "success", + "toolUseId": "c2", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + + assert tru_result["type"] == "function_call_output" + assert tru_result["call_id"] == "c2" + # When images are present, output should be an array + assert isinstance(tru_result["output"], list) + assert len(tru_result["output"]) == 2 + assert tru_result["output"][0]["type"] == "input_text" + assert tru_result["output"][0]["text"] == "Here is the image:" + assert tru_result["output"][1]["type"] == "input_image" + assert "image_url" in tru_result["output"][1] + + +def test_format_request_tool_message_with_document(): + """Test that tool results with documents return an array output.""" + tool_result = { + "content": [ + {"document": {"format": "pdf", "name": "test.pdf", "source": {"bytes": b"fake_pdf_data"}}}, + ], + "status": "success", + "toolUseId": "c3", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + + assert tru_result["type"] == "function_call_output" + assert tru_result["call_id"] == "c3" + # When documents are present, output should be an array + assert isinstance(tru_result["output"], list) + assert len(tru_result["output"]) == 1 + assert tru_result["output"][0]["type"] == "input_file" + assert "file_url" in tru_result["output"][0] + + +def test_format_request_messages(system_prompt): + messages = [ + { + "content": [], + "role": "user", + }, + { + "content": [{"text": "hello"}], + "role": "user", + }, + { + "content": [ + {"text": "call tool"}, + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + "role": "assistant", + }, + { + "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"text": "4"}]}}], + "role": "user", + }, + ] + + tru_result = OpenAIResponsesModel._format_request_messages(messages) + exp_result = [ + { + "role": "user", + "content": [{"type": "input_text", "text": "hello"}], + }, + { + "role": "assistant", + "content": [{"type": "output_text", "text": "call tool"}], + }, + { + "type": "function_call", + "call_id": "c1", + "name": "calculator", + "arguments": '{"expression": "2+2"}', + }, + { + "type": "function_call_output", + "call_id": "c1", + "output": "4", + }, + ] + assert tru_result == exp_result + + +def test_format_request_messages_assistant_text_uses_output_text(): + """Assistant text content must use output_text, not input_text. + + Regression test for multi-turn conversations failing because the OpenAI + Responses API rejects input_text in assistant messages. + See: https://github.com/strands-agents/sdk-python/issues/1850 + """ + messages = [ + { + "content": [{"text": "Say hello"}], + "role": "user", + }, + { + "content": [{"text": "Hello!"}], + "role": "assistant", + }, + { + "content": [{"text": "Say goodbye"}], + "role": "user", + }, + ] + + result = OpenAIResponsesModel._format_request_messages(messages) + + assert result[0] == { + "role": "user", + "content": [{"type": "input_text", "text": "Say hello"}], + } + assert result[1] == { + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello!"}], + } + assert result[2] == { + "role": "user", + "content": [{"type": "input_text", "text": "Say goodbye"}], + } + + +def test_format_request_message_content_role_assistant(): + """_format_request_message_content uses output_text for assistant role.""" + content = {"text": "response text"} + result = OpenAIResponsesModel._format_request_message_content(content, role="assistant") + assert result == {"type": "output_text", "text": "response text"} + + +def test_format_request_message_content_role_user(): + """_format_request_message_content uses input_text for user role (default).""" + content = {"text": "question"} + result = OpenAIResponsesModel._format_request_message_content(content, role="user") + assert result == {"type": "input_text", "text": "question"} + + +def test_format_request(model, messages, tool_specs, system_prompt): + tru_request = model._format_request(messages, tool_specs, system_prompt) + exp_request = { + "model": "gpt-4o", + "input": [ + { + "role": "user", + "content": [{"type": "input_text", "text": "test"}], + } + ], + "stream": True, + "store": False, + "instructions": system_prompt, + "tools": [ + { + "type": "function", + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + ], + "max_output_tokens": 100, + } + assert tru_request == exp_request + + +@pytest.mark.parametrize( + ("event", "exp_chunk"), + [ + # Message start + ( + {"chunk_type": "message_start"}, + {"messageStart": {"role": "assistant"}}, + ), + # Content Start - Tool Use + ( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": unittest.mock.Mock(**{"function.name": "calculator", "id": "c1"}), + }, + {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, + ), + # Content Start - Text + ( + {"chunk_type": "content_start", "data_type": "text"}, + {"contentBlockStart": {"start": {}}}, + ), + # Content Delta - Tool Use + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + ), + # Content Delta - Tool Use - None + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}}, + ), + # Content Delta - Reasoning Text + ( + {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "I'm thinking"}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "I'm thinking"}}}}, + ), + # Content Delta - Citation + ( + { + "chunk_type": "content_delta", + "data_type": "citation", + "data": {"type": "url_citation", "title": "Example", "url": "https://example.com"}, + }, + { + "contentBlockDelta": { + "delta": {"citation": {"title": "Example", "location": {"web": {"url": "https://example.com"}}}} + } + }, + ), + # Content Delta - Text + ( + {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, + {"contentBlockDelta": {"delta": {"text": "hello"}}}, + ), + # Content Stop + ( + {"chunk_type": "content_stop"}, + {"contentBlockStop": {}}, + ), + # Message Stop - Tool Use + ( + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"messageStop": {"stopReason": "tool_use"}}, + ), + # Message Stop - Max Tokens + ( + {"chunk_type": "message_stop", "data": "length"}, + {"messageStop": {"stopReason": "max_tokens"}}, + ), + # Message Stop - End Turn + ( + {"chunk_type": "message_stop", "data": "stop"}, + {"messageStop": {"stopReason": "end_turn"}}, + ), + # Metadata + ( + { + "chunk_type": "metadata", + "data": unittest.mock.Mock(input_tokens=100, output_tokens=50, total_tokens=150), + }, + { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + }, + "metrics": { + "latencyMs": 0, + }, + }, + }, + ), + ], +) +def test_format_chunk(event, exp_chunk, model): + tru_chunk = model._format_chunk(event) + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown_type(model): + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model._format_chunk(event) + + +@pytest.mark.asyncio +async def test_stream(openai_client, model_id, model, agenerator, alist): + # Mock response events + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "Hello"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 0}, + } + }, + ] + + assert len(tru_events) == len(exp_events) + expected_request = { + "model": model_id, + "input": [{"role": "user", "content": [{"type": "input_text", "text": "test"}]}], + "stream": True, + "store": False, + "max_output_tokens": 100, + } + openai_client.responses.create.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_stream_with_tool_calls(openai_client, model, agenerator, alist): + # Mock tool call events + mock_tool_event = unittest.mock.Mock( + type="response.output_item.added", + item=unittest.mock.Mock(type="function_call", call_id="call_123", name="calculator", id="item_456"), + ) + mock_args_event = unittest.mock.Mock( + type="response.function_call_arguments.delta", delta='{"expression": "2+2"}', item_id="item_456" + ) + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_tool_event, mock_args_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + # Should include tool call events + assert any("toolUse" in str(event) for event in tru_events) + assert {"messageStop": {"stopReason": "tool_use"}} in tru_events + + +@pytest.mark.asyncio +async def test_stream_with_tool_calls_done_event(openai_client, model, agenerator, alist): + """Test that response.function_call_arguments.done overwrites accumulated deltas.""" + mock_tool_event = unittest.mock.Mock( + type="response.output_item.added", + item=unittest.mock.Mock(type="function_call", call_id="call_1", name="calculator", id="item_1"), + ) + # Simulate partial delta that would produce incomplete JSON + mock_args_delta = unittest.mock.Mock( + type="response.function_call_arguments.delta", delta='{"expr', item_id="item_1" + ) + # The done event provides the complete, correct arguments + mock_args_done = unittest.mock.Mock( + type="response.function_call_arguments.done", arguments='{"expression": "2+2"}', item_id="item_1" + ) + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_tool_event, mock_args_delta, mock_args_done, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}] + tru_events = await alist(model.stream(messages)) + + # Find the tool use delta event and verify it has the final (done) arguments, not the partial delta + tool_deltas = [e for e in tru_events if "contentBlockDelta" in e and "toolUse" in e["contentBlockDelta"]["delta"]] + assert len(tool_deltas) == 1 + assert tool_deltas[0]["contentBlockDelta"]["delta"]["toolUse"]["input"] == '{"expression": "2+2"}' + + +@pytest.mark.asyncio +async def test_stream_response_incomplete(openai_client, model, agenerator, alist): + """Test that response.incomplete sets stop_reason to length when max_output_tokens is reached.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Truncated resp") + mock_incomplete_event = unittest.mock.Mock( + type="response.incomplete", + response=unittest.mock.Mock( + usage=unittest.mock.Mock(input_tokens=10, output_tokens=100, total_tokens=110), + incomplete_details=unittest.mock.Mock(reason="max_output_tokens"), + ), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event, mock_incomplete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "write a long essay"}]}] + tru_events = await alist(model.stream(messages)) + + assert {"messageStop": {"stopReason": "max_tokens"}} in tru_events + # Verify usage was still captured + metadata_events = [e for e in tru_events if "metadata" in e] + assert len(metadata_events) == 1 + assert metadata_events[0]["metadata"]["usage"]["inputTokens"] == 10 + assert metadata_events[0]["metadata"]["usage"]["outputTokens"] == 100 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "event_type", + [ + "response.reasoning_text.delta", + "response.reasoning_summary_text.delta", + ], +) +async def test_stream_reasoning_content(openai_client, model, agenerator, alist, event_type): + """Test that reasoning content is streamed correctly for both full and summary reasoning events.""" + mock_reasoning_event = unittest.mock.Mock(type=event_type, delta="Let me think...") + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="The answer is 42") + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=20, total_tokens=30)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_reasoning_event, mock_text_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "think step by step"}]}] + tru_events = await alist(model.stream(messages)) + + # Verify reasoning content block was emitted + reasoning_deltas = [ + e for e in tru_events if "contentBlockDelta" in e and "reasoningContent" in e["contentBlockDelta"]["delta"] + ] + assert len(reasoning_deltas) == 1 + assert reasoning_deltas[0]["contentBlockDelta"]["delta"]["reasoningContent"]["text"] == "Let me think..." + + # Verify text content block was also emitted + text_deltas = [e for e in tru_events if "contentBlockDelta" in e and "text" in e["contentBlockDelta"]["delta"]] + assert len(text_deltas) == 1 + assert text_deltas[0]["contentBlockDelta"]["delta"]["text"] == "The answer is 42" + + # Verify content blocks were properly opened and closed (reasoning start/stop, then text start/stop) + content_starts = [e for e in tru_events if "contentBlockStart" in e] + content_stops = [e for e in tru_events if "contentBlockStop" in e] + assert len(content_starts) == 2 # one for reasoning, one for text + assert len(content_stops) == 2 + + +@pytest.mark.asyncio +async def test_stream_citation_annotations(openai_client, model, agenerator, alist): + """Test that web search citation annotations are streamed as CitationsDelta events.""" + mock_text_event1 = unittest.mock.Mock(type="response.output_text.delta", delta="The answer is here. ") + mock_text_event2 = unittest.mock.Mock(type="response.output_text.delta", delta="(example.com)") + mock_annotation_event = unittest.mock.Mock( + type="response.output_text.annotation.added", + annotation={ + "type": "url_citation", + "title": "Example Source", + "url": "https://example.com/article", + }, + ) + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event1, mock_text_event2, mock_annotation_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "search something"}]}] + tru_events = await alist(model.stream(messages)) + + citation_deltas = [ + e for e in tru_events if "contentBlockDelta" in e and "citation" in e["contentBlockDelta"]["delta"] + ] + assert len(citation_deltas) == 1 + assert citation_deltas[0] == { + "contentBlockDelta": { + "delta": { + "citation": { + "title": "Example Source", + "location": {"web": {"url": "https://example.com/article"}}, + } + } + } + } + + +@pytest.mark.asyncio +async def test_stream_unsupported_annotation_type(openai_client, model, agenerator, alist, caplog): + """Test that unsupported annotation types log a warning and are not emitted.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Some text") + mock_annotation_event = unittest.mock.Mock( + type="response.output_text.annotation.added", + annotation={"type": "file_citation", "file_id": "file-123", "filename": "doc.pdf"}, + ) + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event, mock_annotation_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "search files"}]}] + tru_events = await alist(model.stream(messages)) + + citation_deltas = [ + e for e in tru_events if "contentBlockDelta" in e and "citation" in e["contentBlockDelta"]["delta"] + ] + assert len(citation_deltas) == 0 + assert "annotation_type= | unsupported annotation type" in caplog.text + + +@pytest.mark.asyncio +async def test_structured_output(openai_client, model, test_output_model_cls, alist): + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + mock_parsed_instance = test_output_model_cls(name="John", age=30) + mock_response = unittest.mock.Mock(output_parsed=mock_parsed_instance) + + openai_client.responses.parse = unittest.mock.AsyncMock(return_value=mock_response) + + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + tru_result = events[-1] + exp_result = {"output": test_output_model_cls(name="John", age=30)} + assert tru_result == exp_result + + +@pytest.mark.asyncio +async def test_stream_context_overflow_exception(openai_client, model, messages): + """Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException.""" + mock_error = openai.BadRequestError( + message="This model's maximum context length is 4096 tokens.", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_stream_context_overflow_exception_api_error_type(openai_client, model, messages): + """Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException.""" + mock_error = openai.APIError( + message="This model's maximum context length is 4096 tokens.", + request=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_stream_rate_limit_as_throttle(openai_client, model, messages): + """Test that rate limit errors are converted to ModelThrottledException.""" + mock_error = openai.RateLimitError( + message="Rate limit exceeded", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + mock_error.code = "rate_limit_exceeded" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Rate limit exceeded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_stream_bad_request_non_context_overflow(openai_client, model, messages): + """Test that non-context-overflow BadRequestErrors are re-raised.""" + mock_error = openai.BadRequestError( + message="Invalid request format", + response=unittest.mock.MagicMock(), + body={"error": {"code": "invalid_request"}}, + ) + mock_error.code = "invalid_request" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(openai.BadRequestError) as exc_info: + async for _ in model.stream(messages): + pass + + assert exc_info.value == mock_error + + +@pytest.mark.asyncio +async def test_stream_error_during_iteration(openai_client, model, messages, agenerator): + """Test that errors during streaming iteration are properly handled.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") + + async def error_generator(): + yield mock_text_event + raise openai.RateLimitError( + message="Rate limit during stream", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + + openai_client.responses.create = unittest.mock.AsyncMock(return_value=error_generator()) + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Rate limit during stream" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_stream_context_overflow_during_iteration(openai_client, model, messages): + """Test that context overflow during streaming iteration is properly handled.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") + + async def error_generator(): + yield mock_text_event + error = openai.BadRequestError( + message="Context length exceeded during stream", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + error.code = "context_length_exceeded" + raise error + + openai_client.responses.create = unittest.mock.AsyncMock(return_value=error_generator()) + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Context length exceeded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls): + """Test that structured output handles context overflow properly.""" + mock_error = openai.BadRequestError( + message="This model's maximum context length is 4096 tokens.", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_rate_limit_as_throttle(openai_client, model, messages, test_output_model_cls): + """Test that structured output handles rate limit errors properly.""" + mock_error = openai.RateLimitError( + message="Rate limit exceeded", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + mock_error.code = "rate_limit_exceeded" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert "Rate limit exceeded" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_bad_request_non_context_overflow( + openai_client, model, messages, test_output_model_cls +): + """Test that structured output re-raises non-context-overflow BadRequestErrors.""" + mock_error = openai.BadRequestError( + message="Invalid request format", + response=unittest.mock.MagicMock(), + body={"error": {"code": "invalid_request"}}, + ) + mock_error.code = "invalid_request" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(openai.BadRequestError) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert exc_info.value == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_no_parsed_output(openai_client, model, messages, test_output_model_cls, alist): + """Test that structured output raises ValueError when output_parsed is None.""" + mock_response = unittest.mock.Mock(output_parsed=None) + openai_client.responses.parse = unittest.mock.AsyncMock(return_value=mock_response) + + with pytest.raises(ValueError, match="No valid parsed output"): + await alist(model.structured_output(test_output_model_cls, messages)) + + +@pytest.mark.asyncio +async def test_stream_with_empty_tool_result_content(model): + """Test formatting tool result with empty content list.""" + tool_result = { + "content": [], + "status": "success", + "toolUseId": "c1", + } + + result = OpenAIResponsesModel._format_request_tool_message(tool_result) + assert result["output"] == "" + + +def test_config_validation_warns_on_unknown_keys(openai_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + OpenAIResponsesModel({"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) + + +@pytest.mark.parametrize( + ("tool_choice", "expected"), + [ + (None, {}), + ({"auto": {}}, {"tool_choice": "auto"}), + ({"any": {}}, {"tool_choice": "required"}), + ({"tool": {"name": "calculator"}}, {"tool_choice": {"type": "function", "name": "calculator"}}), + ({"unknown": {}}, {"tool_choice": "auto"}), # Test default fallback + ], +) +def test_format_request_tool_choice(tool_choice, expected): + """Test that tool_choice is properly formatted for the Responses API.""" + result = OpenAIResponsesModel._format_request_tool_choice(tool_choice) + assert result == expected + + +def test_format_request_with_tool_choice(model, messages, tool_specs): + """Test that tool_choice is properly included in the request.""" + tool_choice = {"tool": {"name": "test_tool"}} + request = model._format_request(messages, tool_specs, tool_choice=tool_choice) + + assert "tool_choice" in request + assert request["tool_choice"] == {"type": "function", "name": "test_tool"} + + +def test_format_request_merges_builtin_tools_with_function_tools(messages, tool_specs): + """Test that built-in tools from params are merged with function tools.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "web_search"}]}, + ) + request = model._format_request(messages, tool_specs) + + assert request["tools"] == [ + {"type": "web_search"}, + { + "type": "function", + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": {"input": {"type": "string"}}, + "required": ["input"], + }, + }, + ] + + +def test_format_request_builtin_tools_without_function_tools(messages): + """Test that built-in tools from params are preserved when no function tools are provided.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "web_search"}]}, + ) + request = model._format_request(messages) + + assert request["tools"] == [{"type": "web_search"}] + + +def test_format_request_messages_with_citations_content(): + """Test that citationsContent blocks are converted to text in the request.""" + messages = [ + {"role": "user", "content": [{"text": "search something"}]}, + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "title": "Example", + "location": {"web": {"url": "https://example.com", "domain": "example.com"}}, + "sourceContent": [{"text": "cited text"}], + } + ], + "content": [{"text": "The answer with citations."}], + } + } + ], + }, + ] + formatted = OpenAIResponsesModel._format_request_messages(messages) + + assistant_msg = [m for m in formatted if m.get("role") == "assistant"][0] + assert assistant_msg == { + "role": "assistant", + "content": [{"type": "output_text", "text": "The answer with citations."}], + } + + +def test_format_request_message_content_image_size_limit(): + """Test that oversized images raise ValueError.""" + oversized_data = b"x" * (_MAX_MEDIA_SIZE_BYTES + 1) + content = {"image": {"format": "png", "source": {"bytes": oversized_data}}} + + with pytest.raises(ValueError, match="Image size .* exceeds maximum"): + OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_message_content_document_size_limit(): + """Test that oversized documents raise ValueError.""" + oversized_data = b"x" * (_MAX_MEDIA_SIZE_BYTES + 1) + content = {"document": {"format": "pdf", "name": "large.pdf", "source": {"bytes": oversized_data}}} + + with pytest.raises(ValueError, match="Document size .* exceeds maximum"): + OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_tool_message_image_size_limit(): + """Test that oversized images in tool results raise ValueError.""" + oversized_data = b"x" * (_MAX_MEDIA_SIZE_BYTES + 1) + tool_result = { + "content": [{"image": {"format": "png", "source": {"bytes": oversized_data}}}], + "status": "success", + "toolUseId": "c1", + } + + with pytest.raises(ValueError, match="Image size .* exceeds maximum"): + OpenAIResponsesModel._format_request_tool_message(tool_result) + + +def test_format_request_tool_message_document_size_limit(): + """Test that oversized documents in tool results raise ValueError.""" + oversized_data = b"x" * (_MAX_MEDIA_SIZE_BYTES + 1) + tool_result = { + "content": [{"document": {"format": "pdf", "name": "large.pdf", "source": {"bytes": oversized_data}}}], + "status": "success", + "toolUseId": "c1", + } + + with pytest.raises(ValueError, match="Document size .* exceeds maximum"): + OpenAIResponsesModel._format_request_tool_message(tool_result) + + +def test_openai_version_check(): + """Test that module import fails with old OpenAI SDK version.""" + import importlib + + import strands.models.openai_responses as openai_responses_module + + def mock_old_version(package_name: str) -> str: + if package_name == "openai": + return "1.99.0" + from importlib.metadata import version + + return version(package_name) + + def mock_valid_version(package_name: str) -> str: + if package_name == "openai": + return "2.0.0" + from importlib.metadata import version + + return version(package_name) + + with unittest.mock.patch("importlib.metadata.version", mock_old_version): + with pytest.raises(ImportError, match="OpenAIResponsesModel requires openai>=2.0.0"): + importlib.reload(openai_responses_module) + + # Reload with valid version to restore module state + with unittest.mock.patch("importlib.metadata.version", mock_valid_version): + importlib.reload(openai_responses_module) + + +@pytest.mark.parametrize("stateful", [True, False]) +def test_stateful(model_id, stateful): + """Model.stateful reflects the stateful config option.""" + model = OpenAIResponsesModel(model_id=model_id, stateful=stateful) + assert model.stateful is stateful + + +@pytest.mark.asyncio +async def test_stream_stateful(openai_client, model_id, agenerator, alist): + """When stateful is enabled, model writes response_id to model_state from response.created.""" + model = OpenAIResponsesModel(model_id=model_id, stateful=True) + mock_events = [ + unittest.mock.Mock( + type="response.created", + response=unittest.mock.Mock(id="resp_abc123"), + ), + unittest.mock.Mock(type="response.output_text.delta", delta="Hi"), + unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock( + id="resp_abc123", + usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15), + ), + ), + ] + + openai_client.responses.create = unittest.mock.AsyncMock(return_value=agenerator(mock_events)) + + model_state = {"response_id": "resp_previous"} + events = await alist( + model.stream( + [{"role": "user", "content": [{"text": "Hello"}]}], + model_state=model_state, + ) + ) + + call_kwargs = openai_client.responses.create.call_args[1] + assert call_kwargs["previous_response_id"] == "resp_previous" + + assert model_state["response_id"] == "resp_abc123" + + metadata_events = [e for e in events if "metadata" in e] + assert len(metadata_events) == 1 + assert metadata_events[0]["metadata"] == { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 0}, + } + + +def test_format_request_messages_excludes_reasoning_content(caplog): + """Test that reasoningContent blocks are filtered from messages with a warning.""" + messages = [ + { + "content": [{"text": "Hello"}], + "role": "user", + }, + { + "content": [ + {"reasoningContent": {"reasoningText": {"text": "Let me think..."}}}, + {"text": "The answer is 42"}, + ], + "role": "assistant", + }, + { + "content": [{"text": "Thanks"}], + "role": "user", + }, + ] + + with caplog.at_level("WARNING"): + result = OpenAIResponsesModel._format_request_messages(messages) + + assert result == [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "The answer is 42"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "Thanks"}]}, + ] + assert "reasoningContent is not yet supported" in caplog.text + + +class TestCountTokens: + """Tests for OpenAIResponsesModel.count_tokens native token counting.""" + + @pytest.fixture + def openai_client(self): + with unittest.mock.patch.object(strands.models.openai_responses.openai, "AsyncOpenAI") as mock_client_cls: + mock_client = unittest.mock.AsyncMock() + mock_client_cls.return_value.__aenter__.return_value = mock_client + yield mock_client + + @pytest.fixture + def model(self, openai_client): + _ = openai_client + return OpenAIResponsesModel(model_id="gpt-4o", use_native_token_count=True) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model, openai_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.input_tokens = 42 + openai_client.responses.input_tokens.count.return_value = mock_response + + result = await model.count_tokens(messages=messages) + + assert result == 42 + openai_client.responses.input_tokens.count.assert_called_once() + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model, openai_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.input_tokens = 55 + openai_client.responses.input_tokens.count.return_value = mock_response + + result = await model.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result == 55 + call_kwargs = openai_client.responses.input_tokens.count.call_args[1] + assert call_kwargs["instructions"] == "Be helpful." + + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model, openai_client, messages, tool_specs): + mock_response = unittest.mock.AsyncMock() + mock_response.input_tokens = 100 + openai_client.responses.input_tokens.count.return_value = mock_response + + result = await model.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result == 100 + call_kwargs = openai_client.responses.input_tokens.count.call_args[1] + assert "tools" in call_kwargs + + @pytest.mark.asyncio + async def test_stream_and_store_stripped(self, model, openai_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.input_tokens = 10 + openai_client.responses.input_tokens.count.return_value = mock_response + + await model.count_tokens(messages=messages) + + call_kwargs = openai_client.responses.input_tokens.count.call_args[1] + assert "stream" not in call_kwargs + assert "store" not in call_kwargs + + @pytest.mark.asyncio + async def test_fallback_on_api_error(self, model, openai_client, messages): + openai_client.responses.input_tokens.count.side_effect = openai.APIError( + message="Unsupported", request=unittest.mock.MagicMock(), body=None + ) + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_generic_exception(self, model, openai_client, messages): + openai_client.responses.input_tokens.count.side_effect = RuntimeError("Connection failed") + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_debug(self, model, openai_client, messages, caplog): + import logging + + openai_client.responses.input_tokens.count.side_effect = RuntimeError("API down") + + with caplog.at_level(logging.DEBUG, logger="strands.models.openai_responses"): + await model.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false(self, openai_client, messages): + _ = openai_client + model = OpenAIResponsesModel(model_id="gpt-4o", use_native_token_count=False) + + result = await model.count_tokens(messages=messages) + + openai_client.responses.input_tokens.count.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, openai_client, messages): + _ = openai_client + model = OpenAIResponsesModel(model_id="gpt-4o") + + result = await model.count_tokens(messages=messages) + + openai_client.responses.input_tokens.count.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + +# ============================================================================= +# Bedrock Mantle (bedrock_mantle_config) integration with OpenAIResponsesModel +# ============================================================================= + + +class TestOpenAIResponsesModelBedrockMantleConfig: + @pytest.fixture + def mock_provide_token(self): + with unittest.mock.patch("aws_bedrock_token_generator.provide_token") as mock: + mock.return_value = "bedrock-api-key-deadbeef&Version=1" + yield mock + + def test_bedrock_mantle_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token): + _ = openai_client + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + mock_provide_token.assert_called_once_with(region="us-east-1") + + def test_bedrock_mantle_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token): + _ = openai_client + from datetime import timedelta + + provider = unittest.mock.Mock() + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={ + "region": "us-west-2", + "credentials_provider": provider, + "expiry": timedelta(minutes=15), + }, + ) + model._resolve_client_args() + mock_provide_token.assert_called_once_with( + region="us-west-2", + aws_credentials_provider=provider, + expiry=timedelta(minutes=15), + ) + + def test_bedrock_mantle_config_mints_token_per_request(self, openai_client, mock_provide_token): + _ = openai_client + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + model._resolve_client_args() + model._resolve_client_args() + assert mock_provide_token.call_count == 2 + + def test_bedrock_mantle_config_merges_with_client_args(self, openai_client, mock_provide_token): + """bedrock_mantle_config composes with client_args; transport options are preserved.""" + _ = openai_client + sentinel_http_client = unittest.mock.Mock() + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + client_args={ + "timeout": 42, + "http_client": sentinel_http_client, + }, + bedrock_mantle_config={"region": "us-east-1"}, + ) + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + assert resolved["timeout"] == 42 + assert resolved["http_client"] is sentinel_http_client + + def test_bedrock_mantle_config_rejects_base_url_in_client_args(self, openai_client): + """client_args must not contain base_url or api_key when bedrock_mantle_config is set.""" + _ = openai_client + with pytest.raises(ValueError, match="client_args must not contain"): + OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + client_args={"api_key": "should-not-be-here"}, + bedrock_mantle_config={"region": "us-east-1"}, + ) + + def test_bedrock_mantle_config_requires_region(self, openai_client): + """bedrock_mantle_config raises when no region can be resolved from config, session, or env.""" + _ = openai_client + with ( + unittest.mock.patch("boto3.Session") as mock_session_cls, + unittest.mock.patch.dict(os.environ, {}, clear=True), + ): + mock_session_cls.return_value.region_name = None + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + with pytest.raises(ValueError, match="Could not resolve an AWS region"): + model._resolve_client_args() + + def test_bedrock_mantle_config_region_resolved_from_boto3_default(self, openai_client, mock_provide_token): + """When region is omitted, the default boto3 session chain resolves it.""" + _ = openai_client + with unittest.mock.patch("boto3.Session") as mock_session_cls: + mock_session_cls.return_value.region_name = "eu-west-1" + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.eu-west-1.api.aws/v1" + mock_provide_token.assert_called_once_with(region="eu-west-1") + + def test_bedrock_mantle_config_region_resolved_from_boto_session(self, openai_client, mock_provide_token): + """An explicit ``boto_session`` supplies the region when ``region`` is omitted.""" + _ = openai_client + session = unittest.mock.Mock() + session.region_name = "ap-southeast-2" + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={"boto_session": session}, + ) + + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.ap-southeast-2.api.aws/v1" + mock_provide_token.assert_called_once_with(region="ap-southeast-2") + + def test_bedrock_mantle_config_wraps_token_failures_with_context(self, openai_client, mock_provide_token): + """provide_token failures are wrapped in a RuntimeError with actionable context.""" + _ = openai_client + mock_provide_token.side_effect = RuntimeError("no credentials in chain") + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + with pytest.raises(RuntimeError, match="Bedrock Mantle bearer token.*us-east-1"): + model._resolve_client_args() diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index 72ebf01c6..5d6d6869a 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -2,7 +2,7 @@ import json import unittest.mock -from typing import Any, Dict, List +from typing import Any import boto3 import pytest @@ -32,7 +32,7 @@ def sagemaker_client(boto_session): @pytest.fixture -def endpoint_config() -> Dict[str, Any]: +def endpoint_config() -> dict[str, Any]: """Default endpoint configuration for tests.""" return { "endpoint_name": "test-endpoint", @@ -42,7 +42,7 @@ def endpoint_config() -> Dict[str, Any]: @pytest.fixture -def payload_config() -> Dict[str, Any]: +def payload_config() -> dict[str, Any]: """Default payload configuration for tests.""" return { "max_tokens": 1024, @@ -64,7 +64,7 @@ def messages() -> Messages: @pytest.fixture -def tool_specs() -> List[ToolSpec]: +def tool_specs() -> list[ToolSpec]: """Sample tool specifications for testing.""" return [ { @@ -405,8 +405,8 @@ async def test_stream_with_partial_json(self, sagemaker_client, model, messages, # Mock the response from SageMaker with split JSON mock_response = { "Body": [ - {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, - {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": b'{"choices": [{"delta": {"content": "Paris is'}}, + {"PayloadPart": {"Bytes": b' the capital of France."}, "finish_reason": "stop"}]}'}}, ] } sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response @@ -444,8 +444,8 @@ async def test_tool_choice_not_supported_warns(self, sagemaker_client, model, me # Mock the response from SageMaker with split JSON mock_response = { "Body": [ - {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, - {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": b'{"choices": [{"delta": {"content": "Paris is'}}, + {"PayloadPart": {"Bytes": b' the capital of France."}, "finish_reason": "stop"}]}'}}, ] } sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response diff --git a/tests/strands/models/test_strict_schema.py b/tests/strands/models/test_strict_schema.py new file mode 100644 index 000000000..4e69f767d --- /dev/null +++ b/tests/strands/models/test_strict_schema.py @@ -0,0 +1,302 @@ +from strands.models._strict_schema import ensure_strict_json_schema + + +def test_basic_object(): + schema = { + "type": "object", + "properties": {"x": {"type": "string"}}, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": False, + } + assert "additionalProperties" not in schema + + +def test_nested_objects(): + schema = { + "type": "object", + "properties": { + "outer": { + "type": "object", + "properties": {"inner": {"type": "integer"}}, + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": { + "outer": { + "type": "object", + "properties": {"inner": {"type": "integer"}}, + "additionalProperties": False, + } + }, + "additionalProperties": False, + } + + +def test_defs(): + schema = { + "type": "object", + "properties": {"item": {"$ref": "#/$defs/MyItem"}}, + "$defs": { + "MyItem": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result["additionalProperties"] is False + assert result["$defs"]["MyItem"] == { + "type": "object", + "properties": {"name": {"type": "string"}}, + "additionalProperties": False, + } + + +def test_definitions(): + schema = { + "type": "object", + "properties": {"item": {"$ref": "#/definitions/MyItem"}}, + "definitions": { + "MyItem": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result["additionalProperties"] is False + assert result["definitions"]["MyItem"] == { + "type": "object", + "properties": {"name": {"type": "string"}}, + "additionalProperties": False, + } + + +def test_ref_inline(): + schema = { + "type": "object", + "properties": { + "item": { + "$ref": "#/$defs/MyItem", + "description": "An item", + } + }, + "$defs": { + "MyItem": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result["properties"]["item"] == { + "type": "object", + "properties": {"name": {"type": "string"}}, + "description": "An item", + "additionalProperties": False, + } + + +def test_ref_inline_uses_deep_copy(): + """Two properties referencing the same $def get independent copies.""" + schema = { + "type": "object", + "properties": { + "a": {"$ref": "#/$defs/Shared", "description": "first"}, + "b": {"$ref": "#/$defs/Shared", "description": "second"}, + }, + "$defs": { + "Shared": { + "type": "object", + "properties": {"val": {"type": "string"}}, + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result["properties"]["a"]["description"] == "first" + assert result["properties"]["b"]["description"] == "second" + assert result["properties"]["a"] is not result["properties"]["b"] + + +def test_arrays_anyof_allof(): + schema = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": {"type": "object", "properties": {"a": {"type": "string"}}}, + }, + "union": { + "anyOf": [ + {"type": "object", "properties": {"b": {"type": "string"}}}, + {"type": "null"}, + ] + }, + "intersection": { + "allOf": [ + {"type": "object", "properties": {"c": {"type": "string"}}}, + ] + }, + }, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": {"a": {"type": "string"}}, + "additionalProperties": False, + }, + }, + "union": { + "anyOf": [ + { + "type": "object", + "properties": {"b": {"type": "string"}}, + "additionalProperties": False, + }, + {"type": "null"}, + ] + }, + "intersection": { + "allOf": [ + { + "type": "object", + "properties": {"c": {"type": "string"}}, + "additionalProperties": False, + }, + ] + }, + }, + "additionalProperties": False, + } + + +def test_oneof(): + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [ + {"type": "object", "properties": {"a": {"type": "string"}}}, + {"type": "object", "properties": {"b": {"type": "integer"}}}, + ] + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": { + "value": { + "oneOf": [ + {"type": "object", "properties": {"a": {"type": "string"}}, "additionalProperties": False}, + {"type": "object", "properties": {"b": {"type": "integer"}}, "additionalProperties": False}, + ] + } + }, + "additionalProperties": False, + } + + +def test_require_all_properties(): + schema = { + "type": "object", + "properties": { + "required_field": {"type": "string"}, + "optional_field": {"type": "string"}, + }, + "required": ["required_field"], + } + + without = ensure_strict_json_schema(schema) + assert without["required"] == ["required_field"] + + with_all = ensure_strict_json_schema(schema, require_all_properties=True) + assert set(with_all["required"]) == {"required_field", "optional_field"} + + +def test_preserves_additional_properties_true(): + schema = { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": True, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": True, + } + + +def test_preserves_additional_properties_false(): + schema = { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": False, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": False, + } + + +def test_non_object_type_unchanged(): + schema = {"type": "string"} + result = ensure_strict_json_schema(schema) + + assert result == {"type": "string"} + + +def test_ref_with_invalid_format_is_ignored(): + """A $ref that doesn't start with #/ is silently skipped.""" + schema = { + "type": "object", + "properties": { + "item": {"$ref": "external.json#/Foo", "description": "ext"}, + }, + } + result = ensure_strict_json_schema(schema) + + # $ref is not resolved, but additionalProperties is still added to root + assert result["additionalProperties"] is False + assert result["properties"]["item"]["$ref"] == "external.json#/Foo" + + +def test_ref_with_missing_path_is_ignored(): + """A $ref pointing to a non-existent path is silently skipped.""" + schema = { + "type": "object", + "properties": { + "item": {"$ref": "#/$defs/Missing", "description": "gone"}, + }, + "$defs": {}, + } + result = ensure_strict_json_schema(schema) + + assert result["additionalProperties"] is False + # $ref stays because resolution failed + assert "$ref" in result["properties"]["item"] diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py index 8cf64a39a..81745f412 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -1,5 +1,6 @@ +import logging import unittest.mock -from typing import Any, List +from typing import Any import pytest @@ -266,7 +267,7 @@ def test_format_request_with_unsupported_type(model, content, content_type): class AsyncStreamWrapper: - def __init__(self, items: List[Any]): + def __init__(self, items: list[Any]): self.items = items def __aiter__(self): @@ -277,7 +278,7 @@ async def _generator(self): yield item -async def mock_streaming_response(items: List[Any]): +async def mock_streaming_response(items: list[Any]): return AsyncStreamWrapper(items) @@ -435,3 +436,69 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.writer") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by Writer" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.writer") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by Writer" in caplog.text diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py new file mode 100644 index 000000000..fff48653b --- /dev/null +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -0,0 +1,528 @@ +"""Tests for A2A converter functions.""" + +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart + +from strands.agent.agent_result import AgentResult +from strands.multiagent.a2a._converters import ( + convert_content_blocks_to_parts, + convert_input_to_message, + convert_response_to_agent_result, +) + + +def test_convert_string_input(): + """Test converting string input to A2A message.""" + message = convert_input_to_message("Hello") + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + assert message.parts[0].root.text == "Hello" + + +def test_convert_message_list_input(): + """Test converting message list input to A2A message.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + message = convert_input_to_message(messages) + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + + +def test_convert_content_blocks_input(): + """Test converting content blocks input to A2A message.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + message = convert_input_to_message(content_blocks) + + assert isinstance(message, A2AMessage) + assert len(message.parts) == 2 + + +def test_convert_unsupported_input(): + """Test that unsupported input types raise ValueError.""" + with pytest.raises(ValueError, match="Unsupported input type"): + convert_input_to_message(123) + + +def test_convert_interrupt_response_raises_error(): + """Test that InterruptResponseContent raises explicit error.""" + interrupt_responses = [{"interruptResponse": {"interruptId": "123", "response": "A"}}] + + with pytest.raises(ValueError, match="InterruptResponseContent is not supported for A2AAgent"): + convert_input_to_message(interrupt_responses) + + +def test_convert_content_blocks_to_parts(): + """Test converting content blocks to A2A parts.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + parts = convert_content_blocks_to_parts(content_blocks) + + assert len(parts) == 2 + assert parts[0].root.text == "Hello" + assert parts[1].root.text == "World" + + +def test_convert_a2a_message_response(): + """Test converting A2A message response to AgentResult.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + result = convert_response_to_agent_result(a2a_message) + + assert isinstance(result, AgentResult) + assert result.message["role"] == "assistant" + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Response" + + +def test_convert_task_response(): + """Test converting task response to AgentResult.""" + mock_task = MagicMock() + mock_artifact = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Task response" + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + result = convert_response_to_agent_result((mock_task, None)) + + assert isinstance(result, AgentResult) + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Task response" + + +def test_convert_multiple_parts_response(): + """Test converting response with multiple parts to separate content blocks.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[ + Part(TextPart(kind="text", text="First")), + Part(TextPart(kind="text", text="Second")), + ], + ) + + result = convert_response_to_agent_result(a2a_message) + + assert len(result.message["content"]) == 2 + assert result.message["content"][0]["text"] == "First" + assert result.message["content"][1]["text"] == "Second" + + +# --- New tests for coverage --- + + +def test_convert_message_list_finds_last_user_message(): + """Test that message list conversion finds the last user message.""" + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Second"}]}, + ] + + message = convert_input_to_message(messages) + + assert message.parts[0].root.text == "Second" + + +def test_convert_content_blocks_skips_non_text(): + """Test that non-text content blocks are skipped.""" + content_blocks = [{"text": "Hello"}, {"image": "data"}, {"text": "World"}] + + parts = convert_content_blocks_to_parts(content_blocks) + + assert len(parts) == 2 + + +def test_convert_task_artifact_update_event(): + """Test converting TaskArtifactUpdateEvent to AgentResult.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Streamed artifact" + mock_artifact = MagicMock() + mock_artifact.parts = [mock_part] + + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + mock_event.artifact = mock_artifact + + result = convert_response_to_agent_result((mock_task, mock_event)) + + assert result.message["content"][0]["text"] == "Streamed artifact" + + +def test_convert_task_status_update_event(): + """Test converting TaskStatusUpdateEvent to AgentResult.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Status message" + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_status = MagicMock() + mock_status.message = mock_message + + mock_event = MagicMock(spec=TaskStatusUpdateEvent) + mock_event.status = mock_status + + result = convert_response_to_agent_result((mock_task, mock_event)) + + assert result.message["content"][0]["text"] == "Status message" + + +def test_convert_task_status_update_event_no_message_falls_back_to_task_artifacts(): + """Test that TaskStatusUpdateEvent with no message falls back to task.artifacts.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Artifact content" + mock_artifact = MagicMock() + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + mock_event = MagicMock(spec=TaskStatusUpdateEvent) + mock_status = MagicMock() + mock_status.message = None + mock_event.status = mock_status + + result = convert_response_to_agent_result((mock_task, mock_event)) + + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Artifact content" + + +def test_convert_task_artifact_update_event_empty_parts_falls_back_to_task_artifacts(): + """Test that TaskArtifactUpdateEvent with empty parts falls back to task.artifacts.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Full artifact content" + mock_artifact = MagicMock() + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + mock_event_artifact = MagicMock() + mock_event_artifact.parts = [] + mock_event.artifact = mock_event_artifact + + result = convert_response_to_agent_result((mock_task, mock_event)) + + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Full artifact content" + + +def test_convert_response_handles_missing_data(): + """Test that response conversion handles missing/malformed data gracefully.""" + # TaskArtifactUpdateEvent with no artifact + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + mock_event.artifact = None + result = convert_response_to_agent_result((MagicMock(), mock_event)) + assert len(result.message["content"]) == 0 + + # TaskStatusUpdateEvent with no status + mock_event = MagicMock(spec=TaskStatusUpdateEvent) + mock_event.status = None + result = convert_response_to_agent_result((MagicMock(), mock_event)) + assert len(result.message["content"]) == 0 + + # Task artifact without parts attribute + mock_task = MagicMock() + mock_artifact = MagicMock(spec=[]) + del mock_artifact.parts + mock_task.artifacts = [mock_artifact] + result = convert_response_to_agent_result((mock_task, None)) + assert len(result.message["content"]) == 0 + + +# ========================================================================= +# NEW TESTS: Lifecycle State Mapping +# ========================================================================= + + +def test_convert_response_completed_state_maps_to_end_turn(): + """Test that completed state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.completed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + + +def test_convert_response_failed_state_maps_to_end_turn(): + """Test that failed state maps to end_turn stop_reason with error content.""" + from unittest.mock import MagicMock + + from a2a.types import Message, TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + # Create a status message with error info + error_part = MagicMock() + error_part.root = MagicMock() + error_part.root.text = "Agent execution failed: timeout" + + error_message = MagicMock(spec=Message) + error_message.parts = [error_part] + + status = TaskStatus(state=TaskState.failed, message=error_message) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "failed" + assert "Agent execution failed" in result.message["content"][0]["text"] + + +def test_convert_response_input_required_maps_to_interrupt(): + """Test that input_required state maps to interrupt stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import Message, TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + input_part = MagicMock() + input_part.root = MagicMock() + input_part.root.text = "Agent requires input:\n- approval: Need confirmation" + + input_message = MagicMock(spec=Message) + input_message.parts = [input_part] + + status = TaskStatus(state=TaskState.input_required, message=input_message) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "interrupt" + assert result.state.get("a2a_task_state") == "input-required" + assert "approval" in result.message["content"][0]["text"] + + +def test_convert_response_canceled_state_maps_to_end_turn(): + """Test that canceled state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.canceled, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "canceled" + + +def test_convert_response_rejected_state_maps_to_end_turn(): + """Test that rejected state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.rejected, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "rejected" + + +def test_convert_response_auth_required_maps_to_interrupt(): + """Test that auth_required state maps to interrupt stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.auth_required, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "interrupt" + assert result.state.get("a2a_task_state") == "auth-required" + + +def test_extract_task_state_from_status_update(): + """Test _extract_task_state helper.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + from strands.multiagent.a2a._converters import _extract_task_state + + task = MagicMock() + status = TaskStatus(state=TaskState.failed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + state = _extract_task_state((task, update_event)) + assert state == TaskState.failed + + +def test_extract_task_state_from_message_returns_none(): + """Test _extract_task_state returns None for Message responses.""" + from unittest.mock import MagicMock + + from a2a.types import Message + + from strands.multiagent.a2a._converters import _extract_task_state + + message = MagicMock(spec=Message) + state = _extract_task_state(message) + assert state is None + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +def test_convert_response_completed_state_includes_state_metadata(): + """Major Finding 3: The completed state test was missing state assertion. + + Every other state test asserts both stop_reason AND result.state, but the most + important one (completed — the happy path) was missing the state check. This ensures + downstream consumers relying on result.state["a2a_task_state"] won't break silently. + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.completed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "completed" # THIS WAS MISSING + + +def test_convert_response_unknown_state_defaults_to_end_turn(): + """Major Finding 4: TaskState.unknown should default to end_turn. + + The a2a-sdk has a TaskState.unknown value. Our code handles it via the .get() + default ("end_turn"). This test documents that this is an intentional design + decision: unknown states are treated as terminal completions rather than errors. + + Rationale: An unknown state from a remote server is ambiguous. Treating it as + end_turn (completed) is the safest default — the client won't hang waiting for + more events, and the result content (if any) is still accessible. + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.unknown, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + # unknown is NOT in _STATE_TO_STOP_REASON, so defaults to "end_turn" + assert result.stop_reason == "end_turn" + # state metadata should reflect the actual state value + assert result.state.get("a2a_task_state") == "unknown" + + +def test_convert_response_working_state_defaults_to_end_turn(): + """Test that working state (not in mapping) defaults to end_turn. + + This covers the edge case where a TaskStatusUpdateEvent with state=working + somehow reaches the converter (shouldn't normally happen since _is_complete_event + filters these out, but defense-in-depth). + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.working, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "working" + + +def test_extract_task_state_from_artifact_update_returns_none(): + """Minor Finding 5: _extract_task_state with TaskArtifactUpdateEvent returns None. + + This is the untested path where the update event is an artifact (not status). + """ + from unittest.mock import MagicMock + + from a2a.types import TaskArtifactUpdateEvent + + from strands.multiagent.a2a._converters import _extract_task_state + + task = MagicMock() + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + + state = _extract_task_state((task, mock_event)) + assert state is None + + +def test_state_to_stop_reason_covers_all_lifecycle_states(): + """Verify _STATE_TO_STOP_REASON has mappings for all documented lifecycle states. + + Guards against future additions to the a2a-sdk that we miss. + """ + from a2a.types import TaskState + + from strands.multiagent.a2a._converters import _STATE_TO_STOP_REASON + + # These are the states we explicitly handle + expected_mapped = { + TaskState.completed, + TaskState.failed, + TaskState.canceled, + TaskState.rejected, + TaskState.input_required, + TaskState.auth_required, + } + assert set(_STATE_TO_STOP_REASON.keys()) == expected_mapped + + # These should NOT be in the mapping (they're non-terminal progress states) + assert TaskState.working not in _STATE_TO_STOP_REASON + assert TaskState.submitted not in _STATE_TO_STOP_REASON + assert TaskState.unknown not in _STATE_TO_STOP_REASON diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 1463d3f48..940d26f8c 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -1,6 +1,7 @@ """Tests for the StrandsA2AExecutor class.""" import base64 +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -11,6 +12,9 @@ from strands.multiagent.a2a.executor import StrandsA2AExecutor from strands.types.content import ContentBlock +# Suppress A2A compliance warnings for legacy streaming mode tests +pytestmark = pytest.mark.filterwarnings("ignore:The default A2A response stream.*:UserWarning") + # Test data constants VALID_PNG_BYTES = b"fake_png_data" VALID_MP4_BYTES = b"fake_mp4_data" @@ -579,7 +583,7 @@ async def mock_stream(content_blocks): async def test_execute_streaming_mode_handles_agent_exception( mock_strands_agent, mock_request_context, mock_event_queue ): - """Test that execute handles agent exceptions correctly in streaming mode.""" + """Test that execute transitions to failed state when agent raises exception.""" # Setup mock agent to raise exception when stream_async is called mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) @@ -604,18 +608,25 @@ async def test_execute_streaming_mode_handles_agent_exception( mock_message.parts = [part] mock_request_context.message = mock_message - with pytest.raises(ServerError): - await executor.execute(mock_request_context, mock_event_queue) + # Should NOT raise - instead transitions to failed state + await executor.execute(mock_request_context, mock_event_queue) # Verify agent was called mock_strands_agent.stream_async.assert_called_once() + # Verify a failed status event was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + from a2a.types import TaskState, TaskStatusUpdateEvent -@pytest.mark.asyncio -async def test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that cancel raises UnsupportedOperationError.""" + failed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + ] + assert len(failed_events) == 1 + assert "Agent execution failed" in failed_events[0].status.message.parts[0].root.text executor = StrandsA2AExecutor(mock_strands_agent) + # Cancel with no current_task raises UnsupportedOperationError + mock_request_context.current_task = None with pytest.raises(ServerError) as excinfo: await executor.cancel(mock_request_context, mock_event_queue) @@ -1020,3 +1031,913 @@ def test_default_formats_modularization(): assert executor._get_file_format_from_mime_type("", "document") == "txt" assert executor._get_file_format_from_mime_type("", "image") == "png" assert executor._get_file_format_from_mime_type("", "video") == "mp4" + + +# Tests for enable_a2a_compliant_streaming parameter + + +@pytest.mark.asyncio +async def test_legacy_mode_emits_deprecation_warning(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that legacy streaming (default) emits deprecation warning.""" + from a2a.types import TextPart + + executor = StrandsA2AExecutor(mock_strands_agent) # Default is False + + # Mock stream_async + async def mock_stream(content_blocks): + yield {"result": None} + + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) + + # Mock task + mock_task = MagicMock() + mock_task.id = "test-task-id" + mock_task.context_id = "test-context-id" + mock_request_context.current_task = mock_task + + # Mock message + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + with pytest.warns(UserWarning, match="does not conform to what is expected in the A2A spec"): + await executor.execute(mock_request_context, mock_event_queue) + + +@pytest.mark.asyncio +async def test_a2a_compliant_mode_no_warning(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that A2A-compliant mode does not emit warning.""" + import warnings + + from a2a.types import TextPart + + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + + # Mock stream_async + async def mock_stream(content_blocks): + yield {"result": None} + + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) + + # Mock task + mock_task = MagicMock() + mock_task.id = "test-task-id" + mock_task.context_id = "test-context-id" + mock_request_context.current_task = mock_task + + # Mock message + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + with warnings.catch_warnings(): + warnings.simplefilter("error") + try: + await executor.execute(mock_request_context, mock_event_queue) + except UserWarning: + pytest.fail("Should not emit warning") + + +@pytest.mark.asyncio +async def test_a2a_compliant_mode_uses_add_artifact(mock_strands_agent): + """Test that A2A-compliant mode uses add_artifact with artifact_id.""" + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + executor._current_artifact_id = "artifact-123" + executor._is_first_chunk = True + + mock_updater = MagicMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.update_status = AsyncMock() + + event = {"data": "content"} + await executor._handle_streaming_event(event, mock_updater) + + mock_updater.add_artifact.assert_called_once() + assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-123" + assert mock_updater.add_artifact.call_args[1]["append"] is False + mock_updater.update_status.assert_not_called() + + +@pytest.mark.asyncio +async def test_a2a_compliant_handle_result_first_chunk_with_content(mock_strands_agent): + """Test that A2A-compliant mode sends a TextPart with content when first chunk and result has content.""" + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + executor._current_artifact_id = "artifact-456" + executor._is_first_chunk = True + + mock_updater = MagicMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.complete = AsyncMock() + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.__str__ = MagicMock(return_value="Final response") + + await executor._handle_agent_result(mock_result, mock_updater) + + mock_updater.add_artifact.assert_called_once() + parts = mock_updater.add_artifact.call_args[0][0] + assert len(parts) == 1 + assert parts[0].root.text == "Final response" + assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-456" + assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True + mock_updater.complete.assert_called_once() + + +@pytest.mark.asyncio +async def test_a2a_compliant_handle_result_first_chunk_with_none_result(mock_strands_agent): + """Test that A2A-compliant mode sends a TextPart with empty string when first chunk and result is None. + + Per the A2A spec, parts must contain at least one part, so even with no result + we should send a TextPart with an empty string rather than an empty list. + """ + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + executor._current_artifact_id = "artifact-789" + executor._is_first_chunk = True + + mock_updater = MagicMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.complete = AsyncMock() + + await executor._handle_agent_result(None, mock_updater) + + mock_updater.add_artifact.assert_called_once() + parts = mock_updater.add_artifact.call_args[0][0] + assert len(parts) == 1 + assert parts[0].root.text == "" + assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-789" + assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True + mock_updater.complete.assert_called_once() + + +@pytest.mark.asyncio +async def test_a2a_compliant_handle_result_not_first_chunk(mock_strands_agent): + """Test that A2A-compliant mode sends a TextPart with empty string when not the first chunk. + + Per the A2A spec, parts must contain at least one part, so the final marker + chunk should include a TextPart with an empty string rather than an empty list. + """ + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + executor._current_artifact_id = "artifact-abc" + executor._is_first_chunk = False + + mock_updater = MagicMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.complete = AsyncMock() + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.__str__ = MagicMock(return_value="Some content") + + await executor._handle_agent_result(mock_result, mock_updater) + + mock_updater.add_artifact.assert_called_once() + parts = mock_updater.add_artifact.call_args[0][0] + assert len(parts) == 1 + assert parts[0].root.text == "" + assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-abc" + assert mock_updater.add_artifact.call_args[1]["append"] is True + assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True + + +# Tests for invocation state propagation from A2A request context + + +def _setup_streaming_context( + mock_strands_agent: MagicMock, + mock_request_context: MagicMock, +) -> None: + """Set up common mocks for invocation state streaming tests. + + Args: + mock_strands_agent: The mock Strands Agent. + mock_request_context: The mock RequestContext. + """ + + async def mock_stream(content_blocks: list, **kwargs: Any) -> Any: + yield {"result": MagicMock(spec=SAAgentResult)} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + # Set up message with a text part + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test input" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + +@pytest.mark.asyncio +async def test_invocation_state_contains_request_context(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that the full RequestContext is passed as a2a_request_context in invocation state.""" + mock_task = MagicMock() + mock_task.id = "task-42" + mock_task.context_id = "ctx-99" + mock_request_context.current_task = mock_task + mock_request_context.metadata = {"caller": "test-client"} + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + mock_strands_agent.stream_async.assert_called_once() + call_kwargs = mock_strands_agent.stream_async.call_args[1] + invocation_state = call_kwargs["invocation_state"] + + assert invocation_state is not None + assert invocation_state["a2a_request_context"] is mock_request_context + + +@pytest.mark.asyncio +async def test_invocation_state_context_exposes_metadata(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that metadata is accessible through the RequestContext in invocation state.""" + test_metadata = {"caller": "test-client", "session": "abc-123"} + mock_request_context.metadata = test_metadata + mock_task = MagicMock() + mock_task.id = "task-1" + mock_task.context_id = "ctx-1" + mock_request_context.current_task = mock_task + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + context = call_kwargs["invocation_state"]["a2a_request_context"] + + assert context.metadata == test_metadata + + +@pytest.mark.asyncio +async def test_invocation_state_context_exposes_task_info(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that task info is accessible through the RequestContext in invocation state.""" + mock_task = MagicMock() + mock_task.id = "task-100" + mock_task.context_id = "ctx-200" + mock_request_context.current_task = mock_task + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + context = call_kwargs["invocation_state"]["a2a_request_context"] + + assert context.current_task.id == "task-100" + assert context.current_task.context_id == "ctx-200" + + +@pytest.mark.asyncio +async def test_invocation_state_context_when_no_task(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that RequestContext is passed even when there is no current task.""" + mock_request_context.current_task = None + mock_request_context.metadata = {} + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + + with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task: + mock_new_task.return_value = MagicMock(id="generated-id", context_id="generated-ctx") + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + invocation_state = call_kwargs["invocation_state"] + + assert invocation_state["a2a_request_context"] is mock_request_context + + +@pytest.mark.asyncio +async def test_invocation_state_with_a2a_compliant_streaming( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that invocation state is passed correctly in A2A-compliant streaming mode.""" + mock_task = MagicMock() + mock_task.id = "task-compliant" + mock_task.context_id = "ctx-compliant" + mock_request_context.current_task = mock_task + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + invocation_state = call_kwargs["invocation_state"] + + assert invocation_state is not None + assert invocation_state["a2a_request_context"] is mock_request_context + + +# ========================================================================= +# NEW TESTS: A2A Lifecycle State Support +# ========================================================================= + + +@pytest.mark.asyncio +async def test_execute_transitions_to_failed_on_streaming_error( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that errors during streaming transition task to failed state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + async def mock_stream(content_blocks, **kwargs): + """Mock streaming that raises mid-stream.""" + yield {"data": "partial output"} + raise RuntimeError("Connection lost") + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-fail" + mock_task.context_id = "ctx-fail" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Should not raise + await executor.execute(mock_request_context, mock_event_queue) + + # Verify failed state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + failed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + ] + assert len(failed_events) == 1 + assert "Agent execution failed" in failed_events[0].status.message.parts[0].root.text + + +@pytest.mark.asyncio +async def test_cancel_with_valid_task(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel transitions task to canceled state when task exists.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel" + mock_task.context_id = "ctx-cancel" + mock_request_context.current_task = mock_task + + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify canceled state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + assert "cancelled" in canceled_events[0].status.message.parts[0].root.text.lower() + + +@pytest.mark.asyncio +async def test_cancel_without_task_raises_unsupported(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel raises UnsupportedOperationError when no task exists.""" + executor = StrandsA2AExecutor(mock_strands_agent) + mock_request_context.current_task = None + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, UnsupportedOperationError) + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that agent interrupts map to input_required state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + from strands.interrupt import Interrupt + + # Create a mock result with interrupts + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_interrupt = Interrupt(id="int-1", name="approval", reason="Need user approval") + mock_result.interrupts = [mock_interrupt] + + async def mock_stream(content_blocks, **kwargs): + yield {"data": "Processing..."} + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-interrupt" + mock_task.context_id = "ctx-interrupt" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "delete file X" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Verify input_required state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + msg_text = input_required_events[0].status.message.parts[0].root.text + assert "approval" in msg_text + assert "Need user approval" in msg_text + + +@pytest.mark.asyncio +async def test_execute_with_multiple_interrupts(mock_strands_agent, mock_request_context, mock_event_queue): + """Test handling of multiple interrupts in a single result.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + from strands.interrupt import Interrupt + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = [ + Interrupt(id="int-1", name="confirm_delete", reason="Confirm deletion of file X"), + Interrupt(id="int-2", name="select_backup", reason="Choose backup location"), + ] + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-multi-int" + mock_task.context_id = "ctx-multi-int" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "delete with backup" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + msg_text = input_required_events[0].status.message.parts[0].root.text + assert "confirm_delete" in msg_text + assert "select_backup" in msg_text + assert "Confirm deletion of file X" in msg_text + assert "Choose backup location" in msg_text + + +@pytest.mark.asyncio +async def test_execute_normal_completion_no_interrupts(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that normal completion (no interrupts) still works as before.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "end_turn" + mock_result.interrupts = None + mock_result.__str__ = MagicMock(return_value="Task completed successfully") + + async def mock_stream(content_blocks, **kwargs): + yield {"data": "Working..."} + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-normal" + mock_task.context_id = "ctx-normal" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Verify completed state was enqueued (not input_required) + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + completed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed + ] + assert len(completed_events) == 1 + + # Verify no input_required events + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 0 + + +@pytest.mark.asyncio +async def test_execute_setup_failure_raises_server_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that setup failures (missing message) still raise ServerError.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-setup-fail" + mock_task.context_id = "ctx-setup-fail" + mock_request_context.current_task = mock_task + + # No message at all + mock_request_context.message = None + + with pytest.raises(ServerError) as excinfo: + await executor.execute(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, InternalError) + + +@pytest.mark.asyncio +async def test_execute_error_when_task_already_terminal(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that error during execution is handled gracefully when task is already in terminal state.""" + from a2a.types import TextPart + + # Make stream_async raise to trigger the error path + mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-already-done" + mock_task.context_id = "ctx-already-done" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Patch TaskUpdater.failed to raise RuntimeError (simulating task already in terminal state) + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.failed = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + MockTaskUpdater.return_value = mock_updater + + # Should NOT raise - handles RuntimeError gracefully + await executor.execute(mock_request_context, mock_event_queue) + + # Verify failed() was attempted + mock_updater.failed.assert_called_once() + + +@pytest.mark.asyncio +async def test_cancel_calls_agent_cancel_method(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() attempts to call agent.cancel() if available.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Give the agent a cancel method + mock_strands_agent.cancel = MagicMock() + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel-agent" + mock_task.context_id = "ctx-cancel-agent" + mock_request_context.current_task = mock_task + + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify agent.cancel() was called + mock_strands_agent.cancel.assert_called_once() + + # Verify task state is canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_handles_agent_cancel_exception(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() gracefully handles agent.cancel() raising an exception.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Give the agent a cancel method that raises + mock_strands_agent.cancel = MagicMock(side_effect=RuntimeError("Cannot cancel")) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel-err" + mock_task.context_id = "ctx-cancel-err" + mock_request_context.current_task = mock_task + + # Should still succeed (agent cancel is best-effort) + await executor.cancel(mock_request_context, mock_event_queue) + + # Task should still be transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_raises_when_task_already_terminal(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() raises ServerError when task is already in a terminal state.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-terminal" + mock_task.context_id = "ctx-terminal" + mock_request_context.current_task = mock_task + + # Patch TaskUpdater.cancel to raise RuntimeError (task already completed/failed) + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.cancel = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + MockTaskUpdater.return_value = mock_updater + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, UnsupportedOperationError) + mock_updater.cancel.assert_called_once() + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +@pytest.mark.asyncio +async def test_execute_handles_asyncio_cancelled_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Critical Finding 1: asyncio.CancelledError transitions task to canceled state. + + asyncio.CancelledError is a BaseException (not Exception). It's raised when an asyncio + task is cancelled — e.g., HTTP client disconnect, server shutdown, task group cancellation. + Without explicit handling, the task would remain stuck in 'working' state forever (zombie). + + This test verifies the task transitions to 'canceled' before re-raising CancelledError. + """ + import asyncio + + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + async def mock_stream(content_blocks, **kwargs): + """Mock streaming that gets cancelled mid-stream.""" + yield {"data": "partial output"} + raise asyncio.CancelledError() + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancelled" + mock_task.context_id = "ctx-cancelled" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # CancelledError should be re-raised (framework needs to know task was cancelled) + with pytest.raises(asyncio.CancelledError): + await executor.execute(mock_request_context, mock_event_queue) + + # But BEFORE re-raising, the task should have been transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + assert ( + "cancelled" in canceled_events[0].status.message.parts[0].root.text.lower() + or "connection termination" in canceled_events[0].status.message.parts[0].root.text.lower() + ) + + +@pytest.mark.asyncio +async def test_execute_asyncio_cancelled_when_task_already_terminal( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test CancelledError handling when task is already in a terminal state. + + If the task completed right before the cancellation arrives, the updater.cancel() + will raise RuntimeError. We should handle this gracefully and still re-raise CancelledError. + """ + import asyncio + + from a2a.types import TextPart + + async def mock_stream(content_blocks, **kwargs): + """Async generator that immediately raises CancelledError.""" + yield {"data": "partial"} # Must yield to be async generator + raise asyncio.CancelledError() + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancelled-terminal" + mock_task.context_id = "ctx-cancelled-terminal" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Patch TaskUpdater to simulate task already in terminal state + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.cancel = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.update_status = AsyncMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + mock_updater.context_id = "ctx-cancelled-terminal" + mock_updater.task_id = "task-cancelled-terminal" + MockTaskUpdater.return_value = mock_updater + + # Should still re-raise CancelledError + with pytest.raises(asyncio.CancelledError): + await executor.execute(mock_request_context, mock_event_queue) + + # cancel() was attempted + mock_updater.cancel.assert_called_once() + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_empty_list_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Critical Finding 2: stop_reason='interrupt' with empty interrupts list. + + The agent explicitly signaled it needs input (stop_reason="interrupt") but provided + no interrupt details. This should STILL transition to input_required — the stop_reason + is the authoritative signal. Previously this would silently complete the task. + """ + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = [] # Empty list — previously this was falsy and caused completion! + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-empty-interrupts" + mock_task.context_id = "ctx-empty-interrupts" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Should transition to input_required, NOT completed + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + completed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed + ] + + assert len(input_required_events) == 1, "Empty interrupts list should still trigger input_required" + assert len(completed_events) == 0, "Should NOT complete when stop_reason='interrupt'" + # Verify the fallback message is used + assert "additional input" in input_required_events[0].status.message.parts[0].root.text.lower() + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_none_list_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Edge case: stop_reason='interrupt' with interrupts=None. + + Same logic — the stop_reason is authoritative. None interrupts should + still result in input_required transition. + """ + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = None # None, not empty list + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-none-interrupts" + mock_task.context_id = "ctx-none-interrupts" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_without_hasattr_cancel(mock_strands_agent, mock_request_context, mock_event_queue): + """Test cancel works when agent doesn't have cancel() method (AttributeError).""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Remove cancel method entirely + del mock_strands_agent.cancel + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-no-cancel-method" + mock_task.context_id = "ctx-no-cancel-method" + mock_request_context.current_task = mock_task + + # Should succeed — AttributeError from agent.cancel() is caught + await executor.cancel(mock_request_context, mock_event_queue) + + # Task should still be transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py index 00dd164b5..aeb882b19 100644 --- a/tests/strands/multiagent/a2a/test_server.py +++ b/tests/strands/multiagent/a2a/test_server.py @@ -852,3 +852,174 @@ def test_serve_at_root_edge_cases(mock_strands_agent): ) assert server3.mount_path == "" assert server3.http_url == "http://api.example.com/v1/agents/team1/agent1/" + + +def test_to_starlette_app_with_app_kwargs(mock_strands_agent): + """Test that to_starlette_app passes app_kwargs to the Starlette constructor.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + app = a2a_agent.to_starlette_app(app_kwargs={"debug": True}) + + assert isinstance(app, Starlette) + assert app.debug is True + + +def test_to_fastapi_app_with_app_kwargs(mock_strands_agent): + """Test that to_fastapi_app passes app_kwargs to the FastAPI constructor.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + app = a2a_agent.to_fastapi_app(app_kwargs={"title": "Custom Agent Title"}) + + assert isinstance(app, FastAPI) + assert app.title == "Custom Agent Title" + + +@patch("uvicorn.run") +def test_serve_with_overridden_host_port_updates_agent_card_url(mock_run, mock_strands_agent): + """Test that serve() with host/port overrides updates the agent card URL. + + This test verifies the fix for issue #1258 where specifying host/port in serve() + did not update the agent card URL, causing clients to fail when trying to connect. + """ + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + # Verify initial URL from constructor defaults + assert a2a_agent.http_url == "http://127.0.0.1:9000/" + assert a2a_agent.public_base_url == "http://127.0.0.1:9000" + + # Call serve with different host and port + a2a_agent.serve(host="localhost", port=9210) + + # Verify URL was updated to match the actual serve parameters + assert a2a_agent.http_url == "http://localhost:9210/" + assert a2a_agent.public_base_url == "http://localhost:9210" + assert a2a_agent.host == "localhost" + assert a2a_agent.port == 9210 + + # Verify the agent card reflects the updated URL + card = a2a_agent.public_agent_card + assert card.url == "http://localhost:9210/" + + # Verify uvicorn was called with the overridden parameters + mock_run.assert_called_once() + _, kwargs = mock_run.call_args + assert kwargs["host"] == "localhost" + assert kwargs["port"] == 9210 + + +@patch("uvicorn.run") +def test_serve_with_overridden_port_only_updates_url(mock_run, mock_strands_agent): + """Test that serve() with only port override updates the agent card URL.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + # Call serve with different port only + a2a_agent.serve(port=8080) + + # Verify URL was updated with the new port + assert a2a_agent.http_url == "http://127.0.0.1:8080/" + assert a2a_agent.port == 8080 + + # Verify uvicorn was called with the correct parameters + mock_run.assert_called_once() + _, kwargs = mock_run.call_args + assert kwargs["host"] == "127.0.0.1" + assert kwargs["port"] == 8080 + + +@patch("uvicorn.run") +def test_serve_with_overridden_host_only_updates_url(mock_run, mock_strands_agent): + """Test that serve() with only host override updates the agent card URL.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + # Call serve with different host only + a2a_agent.serve(host="0.0.0.0") + + # Verify URL was updated with the new host + assert a2a_agent.http_url == "http://0.0.0.0:9000/" + assert a2a_agent.host == "0.0.0.0" + + # Verify uvicorn was called with the correct parameters + mock_run.assert_called_once() + _, kwargs = mock_run.call_args + assert kwargs["host"] == "0.0.0.0" + assert kwargs["port"] == 9000 + + +@patch("uvicorn.run") +def test_serve_with_explicit_http_url_does_not_override_url(mock_run, mock_strands_agent): + """Test that serve() with host/port does not override explicitly set http_url. + + When a user explicitly sets http_url in the constructor (e.g., for load balancer scenarios), + the serve() method should NOT override the URL even if host/port are provided. + """ + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Create server with explicit http_url (simulating load balancer scenario) + a2a_agent = A2AServer( + mock_strands_agent, + host="0.0.0.0", + port=8080, + http_url="https://my-alb.amazonaws.com/agent1", + skills=[], + ) + + # Verify initial URL is the explicit one + assert a2a_agent.http_url == "https://my-alb.amazonaws.com/agent1/" + assert a2a_agent._http_url_explicit is True + + # Call serve with different host/port (the local binding) + a2a_agent.serve(host="0.0.0.0", port=9000) + + # Verify URL was NOT changed (explicit http_url should be preserved) + assert a2a_agent.http_url == "https://my-alb.amazonaws.com/agent1/" + assert a2a_agent.public_base_url == "https://my-alb.amazonaws.com" + + # But host/port should still be updated for the actual binding + assert a2a_agent.host == "0.0.0.0" + assert a2a_agent.port == 9000 + + # Verify the agent card still shows the public URL + card = a2a_agent.public_agent_card + assert card.url == "https://my-alb.amazonaws.com/agent1/" + + +@patch("uvicorn.run") +def test_serve_without_overrides_does_not_change_url(mock_run, mock_strands_agent): + """Test that serve() without host/port parameters does not modify the URL.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, host="localhost", port=8000, skills=[]) + + # Verify initial URL + assert a2a_agent.http_url == "http://localhost:8000/" + + # Call serve without overrides + a2a_agent.serve() + + # Verify URL was NOT changed + assert a2a_agent.http_url == "http://localhost:8000/" + assert a2a_agent.host == "localhost" + assert a2a_agent.port == 8000 + + +def test_http_url_explicit_flag_set_correctly(mock_strands_agent): + """Test that _http_url_explicit flag is set correctly during initialization.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Without explicit http_url + server1 = A2AServer(mock_strands_agent, skills=[]) + assert server1._http_url_explicit is False + + # With explicit http_url + server2 = A2AServer(mock_strands_agent, http_url="http://example.com/agent", skills=[]) + assert server2._http_url_explicit is True diff --git a/tests/strands/multiagent/conftest.py b/tests/strands/multiagent/conftest.py index 85e0ef7fc..190dc4a91 100644 --- a/tests/strands/multiagent/conftest.py +++ b/tests/strands/multiagent/conftest.py @@ -1,16 +1,22 @@ import pytest -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import AfterNodeCallEvent, BeforeNodeCallEvent, HookProvider @pytest.fixture def interrupt_hook(): class Hook(HookProvider): + def __init__(self): + self.after_count = 0 + def register_hooks(self, registry): registry.add_callback(BeforeNodeCallEvent, self.interrupt) + registry.add_callback(AfterNodeCallEvent, self.cleanup) def interrupt(self, event): return event.interrupt("test_name", reason="test_reason") + def cleanup(self, event): + self.after_count += 1 + return Hook() diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 4e8a5dd06..2fb2cc617 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -156,6 +156,7 @@ def deserialize_state(self, payload: dict) -> None: assert isinstance(agent, MultiAgentBase) +@pytest.mark.filterwarnings("ignore:`\\*\\*kwargs` parameter is deprecating:UserWarning") def test_multi_agent_base_call_method(): """Test that __call__ method properly delegates to invoke_async.""" diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 4875d1bec..a6085627c 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1,14 +1,14 @@ import asyncio import time -from unittest.mock import AsyncMock, MagicMock, Mock, call, patch +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch import pytest -from strands.agent import Agent, AgentResult +from strands.agent import Agent, AgentBase, AgentResult from strands.agent.state import AgentState -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import AgentInitializedEvent +from strands.hooks import AgentInitializedEvent, BeforeNodeCallEvent from strands.hooks.registry import HookProvider, HookRegistry +from strands.interrupt import Interrupt, _InterruptState from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status from strands.session.file_session_manager import FileSessionManager @@ -23,6 +23,10 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.id = agent_id or f"{name}_id" agent._session_manager = None agent.hooks = HookRegistry() + agent.state = AgentState() + agent.messages = [] + agent._interrupt_state = _InterruptState() + agent._model_state = {} if metrics is None: metrics = Mock( @@ -1100,9 +1104,6 @@ async def test_state_reset_only_with_cycles_enabled(): # Create GraphNode node = GraphNode("test_node", agent) - # Simulate agent being in completed_nodes (as if revisited) - from strands.multiagent.graph import GraphState - state = GraphState() state.completed_nodes.add(node) @@ -1986,7 +1987,10 @@ async def stream_without_result(*args, **kwargs): @pytest.mark.asyncio async def test_graph_persisted(mock_strands_tracer, mock_use_span): - """Test graph persistence functionality.""" + """Test graph persistence functionality with multimodal input containing binary bytes.""" + import base64 + import json + # Create mock session manager session_manager = Mock(spec=FileSessionManager) session_manager.read_multi_agent().return_value = None @@ -2004,23 +2008,78 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): state = graph.serialize_state() assert state["type"] == "graph" assert state["id"] == "default_graph" + assert state["_internal_state"] == { + "interrupt_state": {"activated": False, "context": {}, "interrupts": {}}, + } assert "status" in state assert "completed_nodes" in state assert "node_results" in state - # Test apply_state_from_dict with persisted state + # Build a multimodal prompt with inline binary PDF bytes (the problematic case) + pdf_bytes = b"%PDF-1.4 binary content" + multimodal_task = [ + {"text": "Analyze this PDF"}, + { + "document": { + "format": "pdf", + "name": "document.pdf", + "source": { + "bytes": pdf_bytes, + }, + } + }, + ] + + # Simulate graph having executed with a multimodal task + graph.state.task = multimodal_task + + # serialize_state must not raise TypeError for bytes + serialized = graph.serialize_state() + assert json.dumps(serialized) # must be JSON-serializable + + # The bytes should be encoded in the serialized form + encoded_bytes = serialized["current_task"][1]["document"]["source"]["bytes"] + assert encoded_bytes == {"__bytes_encoded__": True, "data": base64.b64encode(pdf_bytes).decode()} + + # deserialize_state must restore bytes back to original + serialized["next_nodes_to_execute"] = ["test_node"] + serialized["status"] = "executing" + graph.deserialize_state(serialized) + restored_bytes = graph.state.task[1]["document"]["source"]["bytes"] + assert restored_bytes == pdf_bytes + + # Test apply_state_from_dict with plain string persisted state (backward compat) persisted_state = { "status": "executing", "completed_nodes": [], "failed_nodes": [], + "interrupted_nodes": [], "node_results": {}, "current_task": "persisted task", "execution_order": [], "next_nodes_to_execute": ["test_node"], + "_internal_state": { + "interrupt_state": { + "activated": False, + "context": {"a": 1}, + "interrupts": { + "i1": { + "id": "i1", + "name": "test_name", + "reason": "test_reason", + }, + }, + }, + }, } graph.deserialize_state(persisted_state) assert graph.state.task == "persisted task" + assert graph._interrupt_state == _InterruptState( + activated=False, + context={"a": 1}, + interrupts={"i1": Interrupt(id="i1", name="test_name", reason="test_reason")}, + ) # Execute graph to test persistence integration result = await graph.invoke_async("Test persistence") @@ -2068,3 +2127,354 @@ def cancel_callback(event): tru_status = graph.state.status exp_status = Status.FAILED assert tru_status == exp_status + + +def test_graph_interrupt_on_before_node_call_event(interrupt_hook): + agent = create_mock_agent("test_agent", "Task completed") + + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + builder.set_hook_providers([interrupt_hook]) + graph = builder.build() + + multiagent_result = graph("Test task") + + first_execution_time = multiagent_result.execution_time + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_agent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_name", + reason="test_reason", + ), + ] + assert tru_interrupts == exp_interrupts + + tru_after_count = interrupt_hook.after_count + exp_after_count = 0 + assert tru_after_count == exp_after_count + + interrupt = multiagent_result.interrupts[0] + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + agent_result = multiagent_result.results["test_agent"] + + tru_message = agent_result.result.message["content"][0]["text"] + exp_message = "Task completed" + assert tru_message == exp_message + + tru_after_count = interrupt_hook.after_count + exp_after_count = 1 + assert tru_after_count == exp_after_count + + assert multiagent_result.execution_time >= first_execution_time + + +def test_graph_interrupt_on_agent(agenerator): + exp_interrupts = [ + Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ) + ] + + agent = create_mock_agent("test_agent", "Task completed") + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="interrupt", + state={}, + metrics=None, + interrupts=exp_interrupts, + ), + }, + ], + ) + + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + graph = builder.build() + + multiagent_result = graph("Test task") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_agent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="end_turn", + state={}, + metrics=None, + ), + }, + ], + ) + graph._interrupt_state.context["test_agent"] = { + "from_hook": False, + "interrupt_ids": [interrupt.id], + "interrupt_state": { + "activated": True, + "context": {}, + "interrupts": {interrupt.id: interrupt.to_dict()}, + }, + "messages": [], + "state": {}, + "model_state": {}, + } + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + + agent.stream_async.assert_called_once_with(responses, invocation_state={}) + + +def test_graph_interrupt_on_multiagent(agenerator): + exp_interrupts = [ + Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ) + ] + + multiagent = create_mock_multi_agent("test_multiagent", "Multi-agent completed") + multiagent.stream_async = Mock() + multiagent.stream_async.return_value = agenerator( + [ + { + "result": MultiAgentResult( + results={}, + status=Status.INTERRUPTED, + interrupts=exp_interrupts, + ), + }, + ], + ) + + builder = GraphBuilder() + builder.add_node(multiagent, "test_multiagent") + graph = builder.build() + + multiagent_result = graph("Test task") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_multiagent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + multiagent.stream_async = Mock() + multiagent.stream_async.return_value = agenerator( + [ + { + "result": MultiAgentResult( + results={ + "inner_node": NodeResult( + result=AgentResult( + message={"role": "assistant", "content": [{"text": "Inner completed"}]}, + stop_reason="end_turn", + state={}, + metrics={}, + ) + ) + }, + status=Status.COMPLETED, + ), + }, + ], + ) + graph._interrupt_state.context["test_multiagent"] = { + "from_hook": False, + "interrupt_ids": [interrupt.id], + } + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + + multiagent.stream_async.assert_called_once_with(responses, {}) + + +@pytest.mark.asyncio +async def test_graph_with_agentbase_implementation(mock_strands_tracer, mock_use_span): + """Test that Graph accepts any AgentBase implementation (not just Agent).""" + + # Create a minimal AgentBase implementation + class CustomAgentBase: + """Custom AgentBase implementation for testing.""" + + def __init__(self, name: str, response_text: str): + self.name = name + self.id = f"{name}_id" + self._response_text = response_text + + def __call__(self, prompt=None, **kwargs): + return AgentResult( + message={"role": "assistant", "content": [{"text": self._response_text}]}, + stop_reason="end_turn", + state={}, + metrics=Mock( + accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + accumulated_metrics={"latencyMs": 100.0}, + ), + ) + + async def invoke_async(self, prompt=None, **kwargs): + return self(prompt, **kwargs) + + async def stream_async(self, prompt=None, **kwargs): + yield {"start": True} + yield {"result": self(prompt, **kwargs)} + + # Verify it satisfies AgentBase protocol + custom_agent = CustomAgentBase("custom", "Custom response") + assert isinstance(custom_agent, AgentBase) + + # Create a regular mock agent + regular_agent = create_mock_agent("regular", "Regular response") + + # Build graph with both + builder = GraphBuilder() + builder.add_node(custom_agent, "custom_node") + builder.add_node(regular_agent, "regular_node") + builder.add_edge("custom_node", "regular_node") + builder.set_entry_point("custom_node") + graph = builder.build() + + result = await graph.invoke_async("Test task") + + assert result.status == Status.COMPLETED + assert result.completed_nodes == 2 + assert "custom_node" in result.results + assert "regular_node" in result.results + + +def test_find_newly_ready_nodes_only_evaluates_outbound_edges(): + """Verify _find_newly_ready_nodes only checks destinations of outbound edges from completed batch. + + Previously, it iterated over ALL nodes, which could cause nodes to fire + before their actual dependencies completed. + + See: https://github.com/strands-agents/sdk-python/issues/685 + """ + # Build a graph: A -> B -> C, D -> E (independent chain) + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) + node_d = GraphNode(node_id="D", executor=create_mock_agent("D")) + node_e = GraphNode(node_id="E", executor=create_mock_agent("E")) + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b, "C": node_c, "D": node_d, "E": node_e} + graph.edges = [ + GraphEdge(from_node=node_a, to_node=node_b), + GraphEdge(from_node=node_b, to_node=node_c), + GraphEdge(from_node=node_d, to_node=node_e), + ] + graph.state = GraphState() + + # When A completes, only B should be ready (not E) + ready = graph._find_newly_ready_nodes([node_a]) + ready_ids = {n.node_id for n in ready} + assert ready_ids == {"B"}, f"Expected only B, got {ready_ids}" + + # When D completes, only E should be ready (not B or C) + ready = graph._find_newly_ready_nodes([node_d]) + ready_ids = {n.node_id for n in ready} + assert ready_ids == {"E"}, f"Expected only E, got {ready_ids}" diff --git a/tests/strands/multiagent/test_multiagent_plugins.py b/tests/strands/multiagent/test_multiagent_plugins.py new file mode 100644 index 000000000..85cc8d817 --- /dev/null +++ b/tests/strands/multiagent/test_multiagent_plugins.py @@ -0,0 +1,283 @@ +"""Tests for MultiAgentPlugin integration with Swarm and Graph.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from strands.hooks import BeforeNodeCallEvent +from strands.hooks.registry import HookProvider +from strands.multiagent import GraphBuilder, Swarm +from strands.multiagent.graph import Graph, GraphNode +from strands.plugins import MultiAgentPlugin, hook + +# --- Fixtures --- + + +@pytest.fixture +def mock_swarm_agent(): + """Create a mock agent suitable for Swarm construction.""" + agent = MagicMock() + agent.name = "agent1" + agent.description = "Test agent" + agent.messages = [] + agent.state = MagicMock() + agent.state.get.return_value = {} + agent._model_state = {} + agent._session_manager = None + agent.tool_registry = MagicMock() + agent.tool_registry.get_all_tools_config.return_value = {} + return agent + + +@pytest.fixture +def mock_graph_agent(): + """Create a mock agent suitable for Graph construction.""" + agent = MagicMock() + agent.name = "agent1" + agent.messages = [] + agent.state = MagicMock() + agent.state.get.return_value = {} + agent._model_state = {} + agent._session_manager = None + return agent + + +def _make_swarm(agent, **kwargs): + """Helper to construct a Swarm with tracer patched out.""" + with patch("strands.multiagent.swarm.get_tracer"): + return Swarm(nodes=[agent], **kwargs) + + +def _make_graph(agent, **kwargs): + """Helper to construct a Graph with tracer patched out.""" + with patch("strands.multiagent.graph.get_tracer"): + node = GraphNode(node_id="agent1", executor=agent) + return Graph(nodes={"agent1": node}, edges=set(), entry_points={node}, **kwargs) + + +# --- Swarm plugin integration tests --- + + +def test_swarm_accepts_plugins_parameter(mock_swarm_agent): + """Test that Swarm constructor accepts a plugins parameter.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + swarm = _make_swarm(mock_swarm_agent, plugins=[MyPlugin()]) + assert swarm._plugin_registry is not None + + +def test_swarm_initializes_plugins(mock_swarm_agent): + """Test that Swarm calls init_multi_agent on plugins during construction.""" + init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator): + nonlocal init_called + init_called = True + + _make_swarm(mock_swarm_agent, plugins=[MyPlugin()]) + assert init_called + + +def test_swarm_registers_plugin_hooks(mock_swarm_agent): + """Test that Swarm registers plugin hooks with its hook registry.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + swarm = _make_swarm(mock_swarm_agent, plugins=[MyPlugin()]) + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_swarm_plugins_coexist_with_hooks(mock_swarm_agent): + """Test that plugins and legacy hooks parameter work together.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + class MyHookProvider(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.on_before_node) + + def on_before_node(self, event): + pass + + swarm = _make_swarm(mock_swarm_agent, plugins=[MyPlugin()], hooks=[MyHookProvider()]) + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 2 + + +def test_swarm_duplicate_plugin_raises_error(mock_swarm_agent): + """Test that duplicate plugin names raise an error in Swarm.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + with pytest.raises(ValueError, match="plugin already registered"): + _make_swarm(mock_swarm_agent, plugins=[MyPlugin(), MyPlugin()]) + + +def test_swarm_no_plugins_parameter(mock_swarm_agent): + """Test that Swarm works without plugins parameter (backward compat).""" + swarm = _make_swarm(mock_swarm_agent) + assert swarm._plugin_registry is not None + + +# --- Graph plugin integration tests --- + + +def test_graph_builder_accepts_plugins(): + """Test that GraphBuilder has a set_plugins method.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + builder = GraphBuilder() + result = builder.set_plugins([MyPlugin()]) + assert result is builder + + +def test_graph_accepts_plugins_parameter(mock_graph_agent): + """Test that Graph constructor accepts a plugins parameter.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + graph = _make_graph(mock_graph_agent, plugins=[MyPlugin()]) + assert graph._plugin_registry is not None + + +def test_graph_initializes_plugins(mock_graph_agent): + """Test that Graph calls init_multi_agent on plugins during construction.""" + init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator): + nonlocal init_called + init_called = True + + _make_graph(mock_graph_agent, plugins=[MyPlugin()]) + assert init_called + + +def test_graph_registers_plugin_hooks(mock_graph_agent): + """Test that Graph registers plugin hooks with its hook registry.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + graph = _make_graph(mock_graph_agent, plugins=[MyPlugin()]) + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_graph_plugins_coexist_with_hooks(mock_graph_agent): + """Test that plugins and legacy hooks parameter work together in Graph.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + class MyHookProvider(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.on_before_node) + + def on_before_node(self, event): + pass + + graph = _make_graph(mock_graph_agent, plugins=[MyPlugin()], hooks=[MyHookProvider()]) + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 2 + + +def test_graph_builder_passes_plugins_to_graph(mock_graph_agent): + """Test that GraphBuilder.build() passes plugins to the Graph constructor.""" + init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator): + nonlocal init_called + init_called = True + + with patch("strands.multiagent.graph.get_tracer"): + builder = GraphBuilder() + builder.add_node(mock_graph_agent, node_id="agent1") + builder.set_entry_point("agent1") + builder.set_plugins([MyPlugin()]) + graph = builder.build() + + assert init_called + assert graph._plugin_registry is not None + + +# --- add_hook method tests --- + + +def test_swarm_add_hook_registers_callback(mock_swarm_agent): + """Test that Swarm.add_hook registers a callback directly.""" + events_received = [] + + def on_before_node(event: BeforeNodeCallEvent): + events_received.append(event) + + swarm = _make_swarm(mock_swarm_agent) + swarm.add_hook(on_before_node, BeforeNodeCallEvent) + + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_graph_add_hook_registers_callback(mock_graph_agent): + """Test that Graph.add_hook registers a callback directly.""" + events_received = [] + + def on_before_node(event: BeforeNodeCallEvent): + events_received.append(event) + + graph = _make_graph(mock_graph_agent) + graph.add_hook(on_before_node, BeforeNodeCallEvent) + + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_swarm_add_hook_infers_event_type(mock_swarm_agent): + """Test that Swarm.add_hook can infer event type from type hint.""" + + def on_before_node(event: BeforeNodeCallEvent): + pass + + swarm = _make_swarm(mock_swarm_agent) + swarm.add_hook(on_before_node) + + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_graph_add_hook_infers_event_type(mock_graph_agent): + """Test that Graph.add_hook can infer event type from type hint.""" + + def on_before_node(event: BeforeNodeCallEvent): + pass + + graph = _make_graph(mock_graph_agent) + graph.add_hook(on_before_node) + + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index f2abed9f7..cb0414b42 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -6,7 +6,7 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import BeforeNodeCallEvent from strands.hooks.registry import HookRegistry from strands.interrupt import Interrupt, _InterruptState from strands.multiagent.base import Status @@ -25,6 +25,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.messages = [] agent.state = AgentState() # Add state attribute agent._interrupt_state = _InterruptState() # Add interrupt state + agent._model_state = {} # Add model state agent.tool_registry = Mock() agent.tool_registry.registry = {} agent.tool_registry.process_tools = Mock() @@ -1106,7 +1107,10 @@ async def failing_execute_swarm(*args, **kwargs): @pytest.mark.asyncio async def test_swarm_persistence(mock_strands_tracer, mock_use_span): - """Test swarm persistence functionality.""" + """Test swarm persistence functionality with multimodal input containing binary bytes.""" + import base64 + import json + # Create mock session manager session_manager = Mock(spec=FileSessionManager) session_manager.read_multi_agent.return_value = None @@ -1127,7 +1131,40 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): assert "node_results" in state assert "context" in state - # Test apply_state_from_dict with persisted state + # Build a multimodal prompt with inline binary PDF bytes (the problematic case) + pdf_bytes = b"%PDF-1.4 binary content" + multimodal_task = [ + {"text": "Analyze this PDF"}, + { + "document": { + "format": "pdf", + "name": "document.pdf", + "source": { + "bytes": pdf_bytes, + }, + } + }, + ] + + # Simulate swarm having executed with a multimodal task + swarm.state.task = multimodal_task + + # serialize_state must not raise TypeError for bytes + serialized = swarm.serialize_state() + assert json.dumps(serialized) # must be JSON-serializable + + # The bytes should be encoded in the serialized form + encoded_bytes = serialized["current_task"][1]["document"]["source"]["bytes"] + assert encoded_bytes == {"__bytes_encoded__": True, "data": base64.b64encode(pdf_bytes).decode()} + + # deserialize_state must restore bytes back to original + serialized["next_nodes_to_execute"] = ["test_agent"] + serialized["status"] = "executing" + swarm.deserialize_state(serialized) + restored_bytes = swarm.state.task[1]["document"]["source"]["bytes"] + assert restored_bytes == pdf_bytes + + # Test apply_state_from_dict with plain string persisted state (backward compat) persisted_state = { "status": "executing", "node_history": [], @@ -1243,6 +1280,8 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): multiagent_result = swarm("Test task") + first_execution_time = multiagent_result.execution_time + tru_status = multiagent_result.status exp_status = Status.INTERRUPTED assert tru_status == exp_status @@ -1257,6 +1296,10 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): ] assert tru_interrupts == exp_interrupts + tru_after_count = interrupt_hook.after_count + exp_after_count = 0 + assert tru_after_count == exp_after_count + interrupt = multiagent_result.interrupts[0] responses = [ { @@ -1279,6 +1322,12 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): exp_message = "Task completed" assert tru_message == exp_message + tru_after_count = interrupt_hook.after_count + exp_after_count = 1 + assert tru_after_count == exp_after_count + + assert multiagent_result.execution_time >= first_execution_time + def test_swarm_interrupt_on_agent(agenerator): exp_interrupts = [ diff --git a/tests/strands/plugins/__init__.py b/tests/strands/plugins/__init__.py new file mode 100644 index 000000000..6b722411e --- /dev/null +++ b/tests/strands/plugins/__init__.py @@ -0,0 +1 @@ +"""Tests for the plugins module.""" diff --git a/tests/strands/plugins/test_hook_decorator.py b/tests/strands/plugins/test_hook_decorator.py new file mode 100644 index 000000000..d05e79edb --- /dev/null +++ b/tests/strands/plugins/test_hook_decorator.py @@ -0,0 +1,243 @@ +"""Tests for the @hook decorator.""" + +import unittest.mock + +import pytest + +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, +) +from strands.plugins.decorator import hook + + +class TestHookDecoratorBasic: + """Tests for basic @hook decorator functionality.""" + + def test_hook_decorator_marks_method(self): + """Test that @hook marks a method with hook metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_with_parentheses(self): + """Test that @hook() syntax also works.""" + + @hook() + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_preserves_function_metadata(self): + """Test that @hook preserves the original function's metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + """Docstring for the hook.""" + pass + + assert on_before_model_call.__name__ == "on_before_model_call" + assert on_before_model_call.__doc__ == "Docstring for the hook." + + def test_hook_decorator_function_still_callable(self): + """Test that decorated function can still be called normally.""" + call_count = 0 + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + on_before_model_call(mock_event) + assert call_count == 1 + + +class TestHookDecoratorEventTypeInference: + """Tests for event type inference from type hints.""" + + def test_hook_infers_event_type_from_type_hint(self): + """Test that @hook infers event type from the first parameter's type hint.""" + + @hook + def handler(event: BeforeInvocationEvent): + pass + + assert BeforeInvocationEvent in handler._hook_event_types + + def test_hook_infers_different_event_types(self): + """Test that different event types are correctly inferred.""" + + @hook + def handler1(event: BeforeModelCallEvent): + pass + + @hook + def handler2(event: AfterModelCallEvent): + pass + + @hook + def handler3(event: AfterInvocationEvent): + pass + + assert BeforeModelCallEvent in handler1._hook_event_types + assert AfterModelCallEvent in handler2._hook_event_types + assert AfterInvocationEvent in handler3._hook_event_types + + def test_hook_skips_cls_parameter(self): + """Test that @hook skips 'cls' parameter for classmethods.""" + + class MyClass: + @classmethod + @hook + def handler(cls, event: BeforeModelCallEvent): + pass + + assert BeforeModelCallEvent in MyClass.handler._hook_event_types + + +class TestHookDecoratorUnionTypes: + """Tests for union type support in @hook decorator.""" + + def test_hook_supports_union_types_with_pipe(self): + """Test that @hook supports union types using | syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_union_types_with_typing_union(self): + """Test that @hook supports Union[] syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_multiple_union_types(self): + """Test that @hook supports unions with more than two types.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + assert BeforeInvocationEvent in handler._hook_event_types + + +class TestHookDecoratorErrorHandling: + """Tests for error handling in @hook decorator.""" + + def test_hook_raises_error_without_type_hint(self): + """Test that @hook raises error when no type hint is provided.""" + with pytest.raises(ValueError, match="cannot infer event type"): + + @hook + def handler(event): + pass + + def test_hook_raises_error_with_non_hook_event_type(self): + """Test that @hook raises error when type hint is not a HookEvent subclass.""" + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + + @hook + def handler(event: str): + pass + + def test_hook_raises_error_with_none_in_union(self): + """Test that @hook raises error when union contains None.""" + with pytest.raises(ValueError, match="None is not a valid event type"): + + @hook + def handler(event: BeforeModelCallEvent | None): + pass + + +class TestHookDecoratorWithMethods: + """Tests for @hook decorator on class methods.""" + + def test_hook_works_on_instance_method(self): + """Test that @hook works correctly on instance methods.""" + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + pass + + instance = MyClass() + assert hasattr(instance.handler, "_hook_event_types") + assert BeforeModelCallEvent in instance.handler._hook_event_types + + def test_hook_instance_method_is_callable(self): + """Test that decorated instance method can be called.""" + call_count = 0 + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert call_count == 1 + + def test_hook_method_accesses_self(self): + """Test that decorated method can access self.""" + + class MyClass: + def __init__(self): + self.events_received = [] + + @hook + def handler(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert len(instance.events_received) == 1 + assert instance.events_received[0] is mock_event + + +class TestHookDecoratorAsync: + """Tests for async functions with @hook decorator.""" + + def test_hook_works_on_async_function(self): + """Test that @hook works on async functions.""" + + @hook + async def handler(event: BeforeModelCallEvent): + pass + + assert hasattr(handler, "_hook_event_types") + assert BeforeModelCallEvent in handler._hook_event_types + + @pytest.mark.asyncio + async def test_hook_async_function_is_callable(self): + """Test that decorated async function can be awaited.""" + call_count = 0 + + @hook + async def handler(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + await handler(mock_event) + assert call_count == 1 diff --git a/tests/strands/plugins/test_multiagent_plugin.py b/tests/strands/plugins/test_multiagent_plugin.py new file mode 100644 index 000000000..b7e16c9eb --- /dev/null +++ b/tests/strands/plugins/test_multiagent_plugin.py @@ -0,0 +1,563 @@ +"""Tests for the MultiAgentPlugin base class and registry.""" + +import gc +import unittest.mock + +import pytest + +from strands.hooks import AfterNodeCallEvent, BeforeNodeCallEvent, HookRegistry +from strands.plugins import Plugin, hook +from strands.plugins.multiagent_plugin import MultiAgentPlugin +from strands.plugins.multiagent_registry import _MultiAgentPluginRegistry +from strands.plugins.registry import _PluginRegistry + +# --- Fixtures --- + + +@pytest.fixture +def mock_orchestrator(): + """Create a mock orchestrator with a working hook registry.""" + orch = unittest.mock.MagicMock() + orch.hooks = HookRegistry() + orch.add_hook = unittest.mock.Mock( + side_effect=lambda callback, event_type=None: orch.hooks.add_callback(event_type, callback) + ) + return orch + + +@pytest.fixture +def registry(mock_orchestrator): + """Create a _MultiAgentPluginRegistry backed by the mock orchestrator.""" + return _MultiAgentPluginRegistry(mock_orchestrator) + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with a working hook registry for dual-plugin tests.""" + agent = unittest.mock.MagicMock() + agent.hooks = HookRegistry() + agent.add_hook = unittest.mock.Mock( + side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback) + ) + agent.tool_registry = unittest.mock.MagicMock() + return agent + + +# --- MultiAgentPlugin base class tests --- + + +def test_multiagent_plugin_is_class_not_protocol(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert isinstance(MyPlugin(), MultiAgentPlugin) + + +def test_multiagent_plugin_requires_name_attribute(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert MyPlugin().name == "my-plugin" + + +def test_multiagent_plugin_name_as_property(): + class MyPlugin(MultiAgentPlugin): + @property + def name(self) -> str: + return "property-plugin" + + assert MyPlugin().name == "property-plugin" + + +def test_multiagent_plugin_requires_name(): + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + + class PluginWithoutName(MultiAgentPlugin): + def init_multi_agent(self, orchestrator): + pass + + PluginWithoutName() + + +def test_multiagent_plugin_provides_default_init_multi_agent(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert MyPlugin().init_multi_agent(unittest.mock.MagicMock()) is None + + +# --- Auto-discovery tests --- + + +def test_discovers_hook_decorated_methods(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "on_before_node" + + +def test_discovers_multiple_hooks(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeNodeCallEvent): + pass + + @hook + def hook2(self, event: AfterNodeCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + assert {h.__name__ for h in plugin.hooks} == {"hook1", "hook2"} + + +def test_hooks_preserve_definition_order(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def z_last(self, event: BeforeNodeCallEvent): + pass + + @hook + def a_first(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + assert [h.__name__ for h in plugin.hooks] == ["z_last", "a_first"] + + +def test_ignores_non_decorated_methods(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def regular_method(self): + pass + + @hook + def decorated_hook(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "decorated_hook" + + +def test_no_tool_support(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert not hasattr(MyPlugin(), "tools") + + +# --- Registry tests --- + + +def test_registry_add_and_init_calls_init_multi_agent(registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + def __init__(self): + super().__init__() + self.initialized = False + + def init_multi_agent(self, orchestrator): + self.initialized = True + + plugin = TestPlugin() + registry.add_and_init(plugin) + assert plugin.initialized + + +def test_registry_add_duplicate_raises_error(registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + registry.add_and_init(TestPlugin()) + with pytest.raises(ValueError, match="plugin_name= | plugin already registered"): + registry.add_and_init(TestPlugin()) + + +def test_registry_registers_discovered_hooks(mock_orchestrator, registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + registry.add_and_init(TestPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_registry_registers_multiple_hooks(mock_orchestrator, registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + @hook + def on_after_node(self, event: AfterNodeCallEvent): + pass + + registry.add_and_init(TestPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(AfterNodeCallEvent, [])) == 1 + + +def test_registry_async_init_multi_agent_supported(registry): + async_init_called = False + + class AsyncPlugin(MultiAgentPlugin): + name = "async-plugin" + + async def init_multi_agent(self, orchestrator): + nonlocal async_init_called + async_init_called = True + + registry.add_and_init(AsyncPlugin()) + assert async_init_called + + +def test_registry_hooks_are_bound_to_instance(mock_orchestrator, registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + def __init__(self): + super().__init__() + self.events_received = [] + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + self.events_received.append(event) + + plugin = TestPlugin() + registry.add_and_init(plugin) + + mock_event = unittest.mock.MagicMock(spec=BeforeNodeCallEvent) + mock_orchestrator.hooks._registered_callbacks[BeforeNodeCallEvent][0](mock_event) + + assert plugin.events_received == [mock_event] + + +def test_registry_raises_reference_error_after_orchestrator_collected(): + orch = unittest.mock.MagicMock() + orch.hooks = HookRegistry() + reg = _MultiAgentPluginRegistry(orch) + del orch + gc.collect() + + with pytest.raises(ReferenceError, match="Orchestrator has been garbage collected"): + _ = reg._orchestrator + + +def test_registry_init_multi_agent_called_before_hook_registration(mock_orchestrator): + call_order = [] + + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + def init_multi_agent(self, orchestrator): + call_order.append("init") + + original = mock_orchestrator.hooks.add_callback + + def tracking(event_type, callback): + call_order.append("hook") + return original(event_type, callback) + + mock_orchestrator.hooks.add_callback = tracking + + registry = _MultiAgentPluginRegistry(mock_orchestrator) + registry.add_and_init(TestPlugin()) + + assert call_order == ["init", "hook"] + + +# --- Union type tests --- + + +def test_registers_hook_for_union_types(mock_orchestrator, registry): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_node_events(self, event: BeforeNodeCallEvent | AfterNodeCallEvent): + pass + + registry.add_and_init(MyPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(AfterNodeCallEvent, [])) == 1 + + +# --- Subclass override tests --- + + +def test_subclass_can_override_init_multi_agent(mock_orchestrator, registry): + custom_init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + def init_multi_agent(self, orchestrator): + nonlocal custom_init_called + custom_init_called = True + + registry.add_and_init(MyPlugin()) + assert custom_init_called + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_subclass_can_add_manual_hooks_in_init(mock_orchestrator, registry): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def auto_hook(self, event: BeforeNodeCallEvent): + pass + + def manual_hook(self, event: AfterNodeCallEvent): + pass + + def init_multi_agent(self, orchestrator): + orchestrator.hooks.add_callback(AfterNodeCallEvent, self.manual_hook) + + registry.add_and_init(MyPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(AfterNodeCallEvent, [])) == 1 + + +# --- Inheritance tests --- + + +def test_child_inherits_parent_hooks(): + class ParentPlugin(MultiAgentPlugin): + name = "parent-plugin" + + @hook + def parent_hook(self, event: BeforeNodeCallEvent): + pass + + class ChildPlugin(ParentPlugin): + name = "child-plugin" + + @hook + def child_hook(self, event: AfterNodeCallEvent): + pass + + plugin = ChildPlugin() + assert len(plugin.hooks) == 2 + assert {h.__name__ for h in plugin.hooks} == {"parent_hook", "child_hook"} + + +def test_child_can_override_parent_hook(): + class ParentPlugin(MultiAgentPlugin): + name = "parent-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + class ChildPlugin(ParentPlugin): + name = "child-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + assert len(ChildPlugin().hooks) == 1 + + +# --- Dual plugin tests --- + + +def test_dual_plugin_isinstance_checks(): + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + plugin = DualPlugin() + assert isinstance(plugin, Plugin) + assert isinstance(plugin, MultiAgentPlugin) + + +def test_dual_plugin_discovers_hooks_once(): + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + assert len(DualPlugin().hooks) == 1 + + +def test_dual_plugin_discover_hooks_called_once(monkeypatch): + """Verify the hasattr guard prevents discover_hooks from running twice in dual inheritance.""" + import strands.plugins.plugin as plugin_mod + + call_count = 0 + original = plugin_mod.discover_hooks + + def counting_discover_hooks(instance, plugin_name): + nonlocal call_count + call_count += 1 + return original(instance, plugin_name) + + monkeypatch.setattr(plugin_mod, "discover_hooks", counting_discover_hooks) + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + DualPlugin() + # Plugin.__init__ calls discover_hooks once; MultiAgentPlugin.__init__ skips due to hasattr guard + assert call_count == 1 + + +def test_dual_plugin_has_both_init_methods(mock_agent, mock_orchestrator): + agent_init_called = False + multi_agent_init_called = False + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + def init_agent(self, agent): + nonlocal agent_init_called + agent_init_called = True + + def init_multi_agent(self, orchestrator): + nonlocal multi_agent_init_called + multi_agent_init_called = True + + _PluginRegistry(mock_agent).add_and_init(DualPlugin()) + assert agent_init_called + + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) + assert multi_agent_init_called + + +def test_dual_plugin_registers_hooks_in_both_contexts(mock_agent, mock_orchestrator): + from strands.hooks import BeforeModelCallEvent + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + pass + + @hook + def on_node_call(self, event: BeforeNodeCallEvent): + pass + + _PluginRegistry(mock_agent).add_and_init(DualPlugin()) + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_dual_plugin_shared_state(mock_agent, mock_orchestrator): + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + def __init__(self): + super().__init__() + self.call_count = 0 + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + self.call_count += 1 + + def init_agent(self, agent): + self.call_count += 10 + + def init_multi_agent(self, orchestrator): + self.call_count += 100 + + plugin = DualPlugin() + _PluginRegistry(mock_agent).add_and_init(plugin) + assert plugin.call_count == 10 + + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(plugin) + assert plugin.call_count == 110 + + +def test_dual_plugin_tools_only_for_agent(mock_agent, mock_orchestrator): + from strands.tools.decorator import tool + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + _PluginRegistry(mock_agent).add_and_init(DualPlugin()) + mock_agent.tool_registry.process_tools.assert_called_once() + + # Orchestrator has no tool registration + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) + + +# --- Double-discovery guard tests --- + + +def test_dual_plugin_hasattr_guard_prevents_double_discovery(): + """Test that the hasattr guard in __init__ prevents hooks from being discovered twice.""" + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def shared_hook(self, event: BeforeNodeCallEvent): + pass + + plugin = DualPlugin() + # If double-discovery occurred, we'd see 2 hooks instead of 1 + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "shared_hook" + + +def test_multiagent_plugin_hasattr_guard_with_pre_set_hooks(): + """Test that MultiAgentPlugin.__init__ skips discovery if _hooks already set.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def __init__(self): + # Pre-set _hooks before super().__init__ + self._hooks = [] + super().__init__() + + @hook + def should_not_be_discovered(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + # The guard should have skipped discovery since _hooks was already set + assert len(plugin.hooks) == 0 diff --git a/tests/strands/plugins/test_plugin_base_class.py b/tests/strands/plugins/test_plugin_base_class.py new file mode 100644 index 000000000..dab3e7210 --- /dev/null +++ b/tests/strands/plugins/test_plugin_base_class.py @@ -0,0 +1,553 @@ +"""Tests for the Plugin base class with auto-discovery.""" + +import unittest.mock + +import pytest + +from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, HookRegistry +from strands.plugins import Plugin, hook +from strands.plugins.registry import _PluginRegistry +from strands.tools.decorator import tool + + +def _configure_mock_agent_with_hooks(): + """Helper to create a mock agent with working add_hook.""" + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.add_hook.side_effect = lambda callback, event_type=None: mock_agent.hooks.add_callback( + event_type, callback + ) + return mock_agent + + +class TestPluginBaseClass: + """Tests for Plugin base class basics.""" + + def test_plugin_is_class_not_protocol(self): + """Test that Plugin is now a class, not a Protocol.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert isinstance(plugin, Plugin) + + def test_plugin_requires_name_attribute(self): + """Test that Plugin subclass must have name attribute.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert plugin.name == "my-plugin" + + def test_plugin_name_as_property(self): + """Test that Plugin name can be a property.""" + + class MyPlugin(Plugin): + @property + def name(self) -> str: + return "property-plugin" + + plugin = MyPlugin() + assert plugin.name == "property-plugin" + + +class TestPluginAutoDiscovery: + """Tests for automatic discovery of decorated methods.""" + + def test_plugin_discovers_hook_decorated_methods(self): + """Test that Plugin.__init__ discovers @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "on_before_model" + + def test_plugin_discovers_multiple_hooks(self): + """Test that Plugin discovers multiple @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeModelCallEvent): + pass + + @hook + def hook2(self, event: BeforeInvocationEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + hook_names = {h.__name__ for h in plugin.hooks} + assert "hook1" in hook_names + assert "hook2" in hook_names + + def test_hooks_preserve_definition_order(self): + """Test that hooks are discovered in definition order, not alphabetical.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def z_last_alphabetically(self, event: BeforeModelCallEvent): + pass + + @hook + def a_first_alphabetically(self, event: BeforeModelCallEvent): + pass + + @hook + def m_middle_alphabetically(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 3 + # Should be in definition order, not alphabetical + assert plugin.hooks[0].__name__ == "z_last_alphabetically" + assert plugin.hooks[1].__name__ == "a_first_alphabetically" + assert plugin.hooks[2].__name__ == "m_middle_alphabetically" + + def test_plugin_discovers_tool_decorated_methods(self): + """Test that Plugin.__init__ discovers @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "my_tool" + + def test_plugin_discovers_both_hooks_and_tools(self): + """Test that Plugin discovers both @hook and @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert len(plugin.tools) == 1 + + def test_plugin_ignores_non_decorated_methods(self): + """Test that Plugin doesn't discover non-decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def regular_method(self): + pass + + @hook + def decorated_hook(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "decorated_hook" + + def test_hooks_property_returns_list(self): + """Test that hooks property returns a mutable list.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert isinstance(plugin.hooks, list) + + def test_tools_property_returns_list(self): + """Test that tools property returns a mutable list.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert isinstance(plugin.tools, list) + + def test_hooks_can_be_filtered(self): + """Test that hooks list can be modified before registration.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeModelCallEvent): + pass + + @hook + def hook2(self, event: BeforeInvocationEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + + # Filter out hook1 + plugin.hooks[:] = [h for h in plugin.hooks if h.__name__ != "hook1"] + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "hook2" + + def test_tools_can_be_filtered(self): + """Test that tools list can be modified before registration.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def tool1(self, param: str) -> str: + """Tool 1.""" + return param + + @tool + def tool2(self, param: str) -> str: + """Tool 2.""" + return param + + plugin = MyPlugin() + assert len(plugin.tools) == 2 + + # Filter out tool1 + plugin.tools[:] = [t for t in plugin.tools if t.tool_name != "tool1"] + assert len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "tool2" + + +class TestPluginRegistryAutoRegistration: + """Tests for auto-registration via _PluginRegistry.""" + + def test_registry_registers_hooks_with_agent(self): + """Test that _PluginRegistry registers discovered hooks with agent.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_registry_registers_tools_with_agent(self): + """Test that _PluginRegistry adds discovered tools to agent's tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify tool was added to agent + mock_agent.tool_registry.process_tools.assert_called_once() + + def test_registry_registers_both_hooks_and_tools(self): + """Test that _PluginRegistry registers both hooks and tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify both registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + mock_agent.tool_registry.process_tools.assert_called_once() + + def test_registry_calls_init_agent_before_registration(self): + """Test that _PluginRegistry calls init_agent for custom logic.""" + init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + def init_agent(self, agent): + nonlocal init_called + init_called = True + # Custom logic - no super() needed + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert init_called + # Verify auto-registration still happened + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginHookWithUnionTypes: + """Tests for Plugin hooks with union types.""" + + def test_registry_registers_hook_for_union_types(self): + """Test that hooks with union types are registered for all event types.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_events(self, event: BeforeModelCallEvent | BeforeInvocationEvent): + pass + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify hook was registered for both event types + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginMultipleAgents: + """Tests for plugin reuse with multiple agents.""" + + def test_plugin_can_be_attached_to_multiple_agents(self): + """Test that the same plugin instance can be used with multiple agents.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + + mock_agent1 = _configure_mock_agent_with_hooks() + mock_agent2 = _configure_mock_agent_with_hooks() + + # Note: In practice, different registries would be used for each agent + # Here we simulate attaching to multiple agents directly + registry1 = _PluginRegistry(mock_agent1) + registry1.add_and_init(plugin) + + # Create new plugin instance for second agent (same class) + plugin2 = MyPlugin() + registry2 = _PluginRegistry(mock_agent2) + registry2.add_and_init(plugin2) + + # Verify both agents have the hook registered + assert len(mock_agent1.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent2.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginSubclassOverride: + """Tests for subclass overriding init_agent.""" + + def test_subclass_can_override_init_agent_without_super(self): + """Test that subclass can override init_agent without calling super().""" + custom_init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + def init_agent(self, agent): + nonlocal custom_init_called + custom_init_called = True + # No super() needed - registry handles auto-registration + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert custom_init_called + # Verify auto-registration still happened via registry + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_subclass_can_add_manual_hooks(self): + """Test that subclass can manually add hooks in addition to decorated ones.""" + manual_hook_added = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def auto_hook(self, event: BeforeModelCallEvent): + pass + + def manual_hook(self, event: BeforeInvocationEvent): + pass + + def init_agent(self, agent): + nonlocal manual_hook_added + # Add manual hook - no super() needed + agent.hooks.add_callback(BeforeInvocationEvent, self.manual_hook) + manual_hook_added = True + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert manual_hook_added + # Verify both hooks registered (1 manual + 1 auto) + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginAsyncInitPlugin: + """Tests for async init_agent support.""" + + @pytest.mark.asyncio + async def test_async_init_agent_supported(self): + """Test that async init_agent is supported.""" + async_init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + async def init_agent(self, agent): + nonlocal async_init_called + async_init_called = True + # No super() needed - registry handles auto-registration + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify async init was called (run_async handles it) + assert async_init_called + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginBoundMethods: + """Tests for bound method registration.""" + + def test_hooks_are_bound_to_instance(self): + """Test that registered hooks are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.events_received = [] + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Call the registered hook and verify it accesses the correct instance + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + callbacks = list(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) + callbacks[0](mock_event) + + assert len(plugin.events_received) == 1 + assert plugin.events_received[0] is mock_event + + def test_tools_are_bound_to_instance(self): + """Test that registered tools are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.tool_called = False + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + self.tool_called = True + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Get the tool that was registered and call it + call_args = mock_agent.tool_registry.process_tools.call_args + registered_tools = call_args[0][0] + assert len(registered_tools) == 1 + + # Call the tool - it should be bound to the instance + result = registered_tools[0]("test") + assert plugin.tool_called + assert result == "test" diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py new file mode 100644 index 000000000..88ed41f8d --- /dev/null +++ b/tests/strands/plugins/test_plugins.py @@ -0,0 +1,209 @@ +"""Tests for the plugin system.""" + +import gc +import unittest.mock + +import pytest + +from strands import Agent +from strands.hooks import HookRegistry +from strands.plugins import Plugin +from strands.plugins.registry import _PluginRegistry + +# Plugin Base Class Tests + + +def test_plugin_base_class_isinstance_check(): + """Test that Plugin subclass passes isinstance check.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert isinstance(plugin, Plugin) + + +def test_plugin_base_class_sync_implementation(): + """Test Plugin base class works with synchronous init_agent.""" + + class SyncPlugin(Plugin): + name = "sync-plugin" + + def init_agent(self, agent): + # No super() needed - registry handles auto-registration + agent.custom_attribute = "initialized by plugin" + + plugin = SyncPlugin() + mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + + # Verify the plugin is an instance + assert isinstance(plugin, Plugin) + assert plugin.name == "sync-plugin" + + # Execute init_agent synchronously + plugin.init_agent(mock_agent) + assert mock_agent.custom_attribute == "initialized by plugin" + + +@pytest.mark.asyncio +async def test_plugin_base_class_async_implementation(): + """Test Plugin base class works with asynchronous init_agent.""" + + class AsyncPlugin(Plugin): + name = "async-plugin" + + async def init_agent(self, agent): + # No super() needed - registry handles auto-registration + agent.custom_attribute = "initialized by async plugin" + + plugin = AsyncPlugin() + mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + + # Verify the plugin is an instance + assert isinstance(plugin, Plugin) + assert plugin.name == "async-plugin" + + # Execute init_agent asynchronously + await plugin.init_agent(mock_agent) + assert mock_agent.custom_attribute == "initialized by async plugin" + + +def test_plugin_class_requires_name(): + """Test that Plugin class requires a name property.""" + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + + class PluginWithoutName(Plugin): + def init_agent(self, agent): + pass + + PluginWithoutName() + + +def test_plugin_base_class_requires_init_agent_method(): + """Test that Plugin base class provides default init_agent.""" + + class PluginWithoutOverride(Plugin): + name = "no-override-plugin" + + plugin = PluginWithoutOverride() + # Plugin base class provides default init_agent + assert hasattr(plugin, "init_agent") + assert callable(plugin.init_agent) + + +def test_plugin_base_class_with_class_attribute_name(): + """Test Plugin base class works when name is a class attribute.""" + + class PluginWithClassAttribute(Plugin): + name: str = "class-attr-plugin" + + plugin = PluginWithClassAttribute() + assert isinstance(plugin, Plugin) + assert plugin.name == "class-attr-plugin" + + +def test_plugin_base_class_with_property_name(): + """Test Plugin base class works when name is a property.""" + + class PluginWithProperty(Plugin): + @property + def name(self) -> str: + return "property-plugin" + + plugin = PluginWithProperty() + assert isinstance(plugin, Plugin) + assert plugin.name == "property-plugin" + + +# _PluginRegistry Tests + + +@pytest.fixture +def mock_agent(): + """Create a mock agent for testing.""" + agent = unittest.mock.Mock() + agent.hooks = HookRegistry() + agent.tool_registry = unittest.mock.MagicMock() + agent.add_hook = unittest.mock.Mock() + return agent + + +@pytest.fixture +def registry(mock_agent): + """Create a fresh _PluginRegistry for each test.""" + return _PluginRegistry(mock_agent) + + +def test_plugin_registry_add_and_init_calls_init_agent(registry, mock_agent): + """Test adding a plugin calls its init_agent method.""" + + class TestPlugin(Plugin): + name = "test-plugin" + + def __init__(self): + super().__init__() + self.initialized = False + + def init_agent(self, agent): + # No super() needed - registry handles auto-registration + self.initialized = True + agent.plugin_initialized = True + + plugin = TestPlugin() + registry.add_and_init(plugin) + + assert plugin.initialized + assert mock_agent.plugin_initialized + + +def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent): + """Test that adding a duplicate plugin raises an error.""" + + class TestPlugin(Plugin): + name = "test-plugin" + + plugin1 = TestPlugin() + plugin2 = TestPlugin() + + registry.add_and_init(plugin1) + + with pytest.raises(ValueError, match="plugin_name= | plugin already registered"): + registry.add_and_init(plugin2) + + +def test_plugin_registry_add_and_init_with_async_plugin(registry, mock_agent): + """Test that add_and_init handles async plugins using run_async.""" + + class AsyncPlugin(Plugin): + name = "async-plugin" + + def __init__(self): + super().__init__() + self.initialized = False + + async def init_agent(self, agent): + # No super() needed - registry handles auto-registration + self.initialized = True + agent.async_plugin_initialized = True + + plugin = AsyncPlugin() + registry.add_and_init(plugin) + + assert plugin.initialized + assert mock_agent.async_plugin_initialized + + +def test_plugin_registry_raises_reference_error_after_agent_collected(): + """Verify _PluginRegistry raises ReferenceError when the Agent has been garbage collected.""" + agent = Agent() + registry = agent._plugin_registry + del agent + gc.collect() + + with pytest.raises(ReferenceError, match="Agent has been garbage collected"): + _ = registry._agent diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 7e28be998..8e14c9adc 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -82,7 +82,7 @@ def test_create_session(file_manager, sample_session): assert os.path.exists(session_file) # Verify content - with open(session_file, "r") as f: + with open(session_file) as f: data = json.load(f) assert data["session_id"] == sample_session.session_id assert data["session_type"] == sample_session.session_type @@ -144,7 +144,7 @@ def test_create_agent(file_manager, sample_session, sample_agent): assert os.path.exists(agent_file) # Verify content - with open(agent_file, "r") as f: + with open(agent_file) as f: data = json.load(f) assert data["agent_id"] == sample_agent.agent_id assert data["state"] == sample_agent.state @@ -210,7 +210,7 @@ def test_create_message(file_manager, sample_session, sample_agent, sample_messa assert os.path.exists(message_path) # Verify content - with open(message_path, "r") as f: + with open(message_path) as f: data = json.load(f) assert data["message_id"] == sample_message.message_id @@ -439,7 +439,7 @@ def test_create_multi_agent(multi_agent_manager, sample_session, mock_multi_agen assert os.path.exists(multi_agent_file) # Verify content - with open(multi_agent_file, "r") as f: + with open(multi_agent_file) as f: data = json.load(f) assert data["id"] == mock_multi_agent.id assert data["state"] == mock_multi_agent.state diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 22de9f964..1d5048113 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -5,6 +5,7 @@ import pytest from strands.agent.agent import Agent +from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager from strands.agent.state import AgentState @@ -28,6 +29,15 @@ def session_manager(mock_repository): return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) +@pytest.fixture +def existing_session_manager(mock_repository): + """Create a session manager with a pre-existing session in the repository.""" + # Create session first so the manager sees it as existing + session = Session(session_id="test-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + @pytest.fixture def agent(): """Create a mock agent.""" @@ -100,7 +110,7 @@ def test_initialize_multiple_agents_without_id(session_manager, agent): session_manager.initialize(agent2) -def test_initialize_restores_existing_agent(session_manager, agent): +def test_initialize_restores_existing_agent(existing_session_manager, agent): """Test that initializing an existing agent restores its state.""" # Set agent ID agent.agent_id = "existing-agent" @@ -112,7 +122,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): conversation_manager_state=SlidingWindowConversationManager().get_state(), _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Create some messages message = SessionMessage( @@ -122,10 +132,10 @@ def test_initialize_restores_existing_agent(session_manager, agent): }, message_id=0, ) - session_manager.session_repository.create_message("test-session", "existing-agent", message) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) # Initialize agent - session_manager.initialize(agent) + existing_session_manager.initialize(agent) # Verify agent state restored assert agent.state.get("key") == "value" @@ -135,7 +145,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert agent._interrupt_state == _InterruptState(interrupts={}, context={"test": "init"}, activated=False) -def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): +def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(existing_session_manager): """Test that initializing an existing agent restores its state.""" conversation_manager = SummarizingConversationManager() conversation_manager.removed_message_count = 1 @@ -147,7 +157,7 @@ def test_initialize_restores_existing_agent_with_summarizing_conversation_manage state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Create some messages message = SessionMessage( @@ -158,13 +168,13 @@ def test_initialize_restores_existing_agent_with_summarizing_conversation_manage message_id=0, ) # Create two messages as one will be removed by the conversation manager - session_manager.session_repository.create_message("test-session", "existing-agent", message) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) message.message_id = 1 - session_manager.session_repository.create_message("test-session", "existing-agent", message) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) # Initialize agent agent = Agent(agent_id="existing-agent", conversation_manager=SummarizingConversationManager()) - session_manager.initialize(agent) + existing_session_manager.initialize(agent) # Verify agent state restored assert agent.state.get("key") == "value" @@ -217,26 +227,52 @@ def test_initialize_multi_agent_new(session_manager, mock_multi_agent): assert state["state"] == {"key": "value"} -def test_initialize_multi_agent_existing(session_manager, mock_multi_agent): +def test_initialize_multi_agent_existing(existing_session_manager, mock_multi_agent): """Test initializing existing multi-agent state.""" # Create existing state first - session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent) + existing_session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent) # Create a mock with updated state for the update call updated_mock = Mock() updated_mock.id = "test-multi-agent" existing_state = {"id": "test-multi-agent", "state": {"restored": "data"}} updated_mock.serialize_state.return_value = existing_state - session_manager.session_repository.update_multi_agent("test-session", updated_mock) + existing_session_manager.session_repository.update_multi_agent("test-session", updated_mock) # Initialize multi-agent - session_manager.initialize_multi_agent(mock_multi_agent) + existing_session_manager.initialize_multi_agent(mock_multi_agent) # Verify deserialize_state was called with existing state mock_multi_agent.deserialize_state.assert_called_once_with(existing_state) -def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): +def test_initialize_skips_message_restore_for_server_managed_conversation(existing_session_manager): + """Test that messages are not restored when model manages conversation server-side.""" + session_agent = SessionAgent( + agent_id="existing-agent", + state={}, + conversation_manager_state=NullConversationManager().get_state(), + _internal_state={ + "interrupt_state": {"interrupts": {}, "context": {}, "activated": False}, + "model_state": {"response_id": "resp_abc123"}, + }, + ) + existing_session_manager.session_repository.create_agent("test-session", session_agent) + + message = SessionMessage.from_message({"role": "user", "content": [{"text": "Hello"}]}, 0) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) + + mock_model = Mock() + mock_model.stateful = True + agent = Agent(agent_id="existing-agent", model=mock_model) + existing_session_manager.initialize(agent) + + assert agent.messages == [] + assert agent._model_state == {"response_id": "resp_abc123"} + assert existing_session_manager.session_repository.list_messages("test-session", "existing-agent") == [message] + + +def test_fix_broken_tool_use_adds_missing_tool_results(existing_session_manager): """Test that _fix_broken_tool_use adds missing toolResult messages.""" conversation_manager = SlidingWindowConversationManager() @@ -246,7 +282,7 @@ def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) broken_messages = [ { @@ -261,11 +297,13 @@ def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): message=broken_message, message_id=index, ) - session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + existing_session_manager.session_repository.create_message( + "test-session", "existing-agent", broken_session_message + ) # Initialize agent agent = Agent(agent_id="existing-agent") - session_manager.initialize(agent) + existing_session_manager.initialize(agent) fixed_messages = agent.messages @@ -277,7 +315,7 @@ def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): assert fixed_messages[1]["content"][0]["toolResult"]["content"][0]["text"] == "Tool was interrupted." -def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): +def test_fix_broken_tool_use_extends_partial_tool_results(existing_session_manager): """Test fixing messages where some toolResults are missing.""" conversation_manager = SlidingWindowConversationManager() # Create agent in repository first @@ -286,7 +324,7 @@ def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) broken_messages = [ { @@ -309,11 +347,13 @@ def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): message=broken_message, message_id=index, ) - session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + existing_session_manager.session_repository.create_message( + "test-session", "existing-agent", broken_session_message + ) # Initialize agent agent = Agent(agent_id="existing-agent") - session_manager.initialize(agent) + existing_session_manager.initialize(agent) fixed_messages = agent.messages @@ -330,7 +370,7 @@ def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): assert missing_result["toolResult"]["content"][0]["text"] == "Tool was interrupted." -def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): +def test_fix_broken_tool_use_handles_multiple_orphaned_tools(existing_session_manager): """Test fixing multiple orphaned toolUse messages.""" conversation_manager = SlidingWindowConversationManager() @@ -340,7 +380,7 @@ def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) broken_messages = [ { @@ -358,11 +398,13 @@ def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): message=broken_message, message_id=index, ) - session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + existing_session_manager.session_repository.create_message( + "test-session", "existing-agent", broken_session_message + ) # Initialize agent agent = Agent(agent_id="existing-agent") - session_manager.initialize(agent) + existing_session_manager.initialize(agent) fixed_messages = agent.messages @@ -449,7 +491,7 @@ def test_initialize_bidi_agent_creates_new(session_manager, mock_bidi_agent): assert messages[0].message["role"] == "user" -def test_initialize_bidi_agent_restores_existing(session_manager, mock_bidi_agent): +def test_initialize_bidi_agent_restores_existing(existing_session_manager, mock_bidi_agent): """Test initializing BidiAgent restores from existing session.""" # Create existing session data session_agent = SessionAgent( @@ -457,16 +499,16 @@ def test_initialize_bidi_agent_restores_existing(session_manager, mock_bidi_agen state={"restored": "state"}, conversation_manager_state={}, # Empty for BidiAgent ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Add messages msg1 = SessionMessage.from_message({"role": "user", "content": [{"text": "Message 1"}]}, 0) msg2 = SessionMessage.from_message({"role": "assistant", "content": [{"text": "Response 1"}]}, 1) - session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg1) - session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg2) + existing_session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg1) + existing_session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg2) # Initialize agent - session_manager.initialize_bidi_agent(mock_bidi_agent) + existing_session_manager.initialize_bidi_agent(mock_bidi_agent) # Verify state restored assert mock_bidi_agent.state.get() == {"restored": "state"} @@ -532,7 +574,7 @@ def test_bidi_agent_unique_id_constraint(session_manager, mock_bidi_agent): session_manager.initialize_bidi_agent(agent2) -def test_bidi_agent_messages_with_offset_zero(session_manager, mock_bidi_agent): +def test_bidi_agent_messages_with_offset_zero(existing_session_manager, mock_bidi_agent): """Test that BidiAgent uses offset=0 for message restoration (no conversation_manager).""" # Create session with messages session_agent = SessionAgent( @@ -540,15 +582,15 @@ def test_bidi_agent_messages_with_offset_zero(session_manager, mock_bidi_agent): state={}, conversation_manager_state={}, ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Add 5 messages for i in range(5): msg = SessionMessage.from_message({"role": "user", "content": [{"text": f"Message {i}"}]}, i) - session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg) + existing_session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg) # Initialize agent - session_manager.initialize_bidi_agent(mock_bidi_agent) + existing_session_manager.initialize_bidi_agent(mock_bidi_agent) # Verify all messages restored (offset=0, no removed_message_count) assert len(mock_bidi_agent.messages) == 5 @@ -595,3 +637,453 @@ def test_fix_broken_tool_use_does_not_affect_normal_conversations(session_manage # Should remain unchanged assert fixed_messages == messages + + +# ============================================================================ +# Conditional Sync Tests +# ============================================================================ + + +def test_sync_agent_skips_update_when_state_not_dirty_and_internal_state_unchanged(mock_repository): + """Test that sync_agent() skips update_agent() when state is not dirty and internal state unchanged.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync should update (to establish baseline) + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + # Clear tracking + update_agent_calls.clear() + + # Second sync without changes should skip update + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 0 + + +def test_sync_agent_calls_update_when_state_is_dirty(mock_repository): + """Test that sync_agent() calls update_agent() when agent.state is dirty.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync to establish baseline + session_manager.sync_agent(agent) + update_agent_calls.clear() + + # Modify state (makes it dirty) + agent.state.set("key", "value") + + # Sync should call update_agent because state is dirty + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +def test_sync_agent_calls_update_when_internal_state_changed(mock_repository): + """Test that sync_agent() calls update_agent() when internal state (interrupt_state) is dirty.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync to establish baseline + session_manager.sync_agent(agent) + update_agent_calls.clear() + + # Modify internal state (activate interrupt state which sets dirty flag) + agent._interrupt_state.activate() + + # Sync should call update_agent because internal state is dirty + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +def test_sync_agent_calls_update_when_conversation_manager_state_changed(mock_repository): + """Test that sync_agent() calls update_agent() when conversation manager state changed.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync to establish baseline + session_manager.sync_agent(agent) + update_agent_calls.clear() + + # Modify conversation manager state + agent.conversation_manager.removed_message_count = 5 + + # Sync should call update_agent because conversation manager state changed + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +def test_sync_agent_calls_update_when_model_state_changed(mock_repository): + """Test that sync_agent() calls update_agent() when model state changed.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync to establish baseline + session_manager.sync_agent(agent) + update_agent_calls.clear() + + # Modify model state + agent._model_state["response_id"] = "resp_abc123" + + # Sync should call update_agent because model state changed + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +def test_sync_agent_tracks_version_after_successful_sync(mock_repository): + """Test that sync_agent() tracks version after successful sync.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # First sync to establish baseline + session_manager.sync_agent(agent) + initial_version = agent.state._get_version() + + # Modify state (increments version) + agent.state.set("key", "value") + assert agent.state._get_version() == initial_version + 1 + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # Sync should update because version changed + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + # Second sync without changes should skip + update_agent_calls.clear() + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 0 + + +def test_sync_agent_retries_on_failure(mock_repository): + """Test that sync_agent() retries on next call if update_agent() fails.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # First sync to establish baseline + session_manager.sync_agent(agent) + + # Modify state (increments version) + agent.state.set("key", "value") + + # Make update_agent fail + def failing_update_agent(session_id, session_agent): + raise SessionException("Update failed") + + mock_repository.update_agent = failing_update_agent + + # Sync should fail + with pytest.raises(SessionException, match="Update failed"): + session_manager.sync_agent(agent) + + # Restore working update_agent + update_agent_calls = [] + original_update_agent = MockedSessionRepository.update_agent + + def tracking_update_agent(self, session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(self, session_id, session_agent) + + mock_repository.update_agent = lambda sid, sa: tracking_update_agent(mock_repository, sid, sa) + + # Retry should work because version wasn't updated on failure + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +def test_sync_agent_first_sync_always_updates(mock_repository): + """Test that the first sync_agent() call always updates (no previous state to compare).""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync should always update (no previous state) + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +# ============================================================================ +# New Session Optimization Tests (Issue #1828) +# ============================================================================ + + +def test_is_new_session_true_when_session_created(mock_repository): + """Test that _is_new_session is True when creating a new session.""" + # Session doesn't exist yet + assert mock_repository.read_session("new-session") is None + + # Creating manager should set _is_new_session to True + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + + assert manager._is_new_session is True + + +def test_is_new_session_false_when_session_exists(mock_repository): + """Test that _is_new_session is False when using an existing session.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Creating manager should set _is_new_session to False + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + + assert manager._is_new_session is False + + +def test_initialize_skips_read_agent_for_new_session(mock_repository): + """Test that initialize() skips read_agent() call when _is_new_session is True.""" + # Create manager (new session) + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + assert manager._is_new_session is True + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Initialize agent + agent = Agent(agent_id="test-agent") + manager.initialize(agent) + + # read_agent should NOT be called for new session + assert len(read_agent_calls) == 0 + + +def test_initialize_calls_read_agent_for_existing_session(mock_repository): + """Test that initialize() calls read_agent() when _is_new_session is False.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Create manager (existing session) + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + assert manager._is_new_session is False + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Initialize agent + agent = Agent(agent_id="test-agent") + manager.initialize(agent) + + # read_agent should be called for existing session + assert len(read_agent_calls) == 1 + assert read_agent_calls[0] == ("existing-session", "test-agent") + + +def test_initialize_bidi_agent_skips_read_agent_for_new_session(mock_repository): + """Test that initialize_bidi_agent() skips read_agent() call when _is_new_session is True.""" + # Create manager (new session) + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + assert manager._is_new_session is True + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Create mock BidiAgent + bidi_agent = Mock() + bidi_agent.agent_id = "bidi-agent-1" + bidi_agent.messages = [{"role": "user", "content": [{"text": "Hello!"}]}] + bidi_agent.state = AgentState({}) + + # Initialize bidi agent + manager.initialize_bidi_agent(bidi_agent) + + # read_agent should NOT be called for new session + assert len(read_agent_calls) == 0 + + +def test_initialize_bidi_agent_calls_read_agent_for_existing_session(mock_repository): + """Test that initialize_bidi_agent() calls read_agent() when _is_new_session is False.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Create manager (existing session) + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + assert manager._is_new_session is False + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Create mock BidiAgent + bidi_agent = Mock() + bidi_agent.agent_id = "bidi-agent-1" + bidi_agent.messages = [{"role": "user", "content": [{"text": "Hello!"}]}] + bidi_agent.state = AgentState({}) + + # Initialize bidi agent + manager.initialize_bidi_agent(bidi_agent) + + # read_agent should be called for existing session + assert len(read_agent_calls) == 1 + assert read_agent_calls[0] == ("existing-session", "bidi-agent-1") + + +def test_initialize_multi_agent_skips_read_for_new_session(mock_repository): + """Test that initialize_multi_agent() skips read_multi_agent() call when _is_new_session is True.""" + # Create manager (new session) + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + assert manager._is_new_session is True + + # Track read_multi_agent calls + read_multi_agent_calls = [] + original_read_multi_agent = mock_repository.read_multi_agent + + def tracking_read_multi_agent(session_id, multi_agent_id, **kwargs): + read_multi_agent_calls.append((session_id, multi_agent_id)) + return original_read_multi_agent(session_id, multi_agent_id, **kwargs) + + mock_repository.read_multi_agent = tracking_read_multi_agent + + # Create mock multi-agent + multi_agent = Mock() + multi_agent.id = "test-multi-agent" + multi_agent.serialize_state.return_value = {"id": "test-multi-agent", "state": {}} + + # Initialize multi-agent + manager.initialize_multi_agent(multi_agent) + + # read_multi_agent should NOT be called for new session + assert len(read_multi_agent_calls) == 0 + + +def test_initialize_multi_agent_calls_read_for_existing_session(mock_repository): + """Test that initialize_multi_agent() calls read_multi_agent() when _is_new_session is False.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Create manager (existing session) + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + assert manager._is_new_session is False + + # Track read_multi_agent calls + read_multi_agent_calls = [] + original_read_multi_agent = mock_repository.read_multi_agent + + def tracking_read_multi_agent(session_id, multi_agent_id, **kwargs): + read_multi_agent_calls.append((session_id, multi_agent_id)) + return original_read_multi_agent(session_id, multi_agent_id, **kwargs) + + mock_repository.read_multi_agent = tracking_read_multi_agent + + # Create mock multi-agent + multi_agent = Mock() + multi_agent.id = "test-multi-agent" + multi_agent.serialize_state.return_value = {"id": "test-multi-agent", "state": {}} + + # Initialize multi-agent + manager.initialize_multi_agent(multi_agent) + + # read_multi_agent should be called for existing session + assert len(read_multi_agent_calls) == 1 + assert read_multi_agent_calls[0] == ("existing-session", "test-multi-agent") diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 719fbc2c9..29bc97ab5 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -89,6 +89,17 @@ def test_init_s3_session_manager_with_existing_user_agent(mocked_aws, s3_bucket) assert "strands-agents" in session_manager.client.meta.config.user_agent_extra +def test_empty_prefix_session_roundtrip(mocked_aws, s3_bucket, sample_session, sample_agent): + """Test that session data can be written and read back with default empty prefix.""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, prefix="", region_name="us-west-2") + manager.create_session(sample_session) + manager.create_agent(sample_session.session_id, sample_agent) + + result = manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result is not None + assert result.agent_id == sample_agent.agent_id + + def test_create_session(s3_manager, sample_session): """Test creating a session in S3.""" result = s3_manager.create_session(sample_session) @@ -282,6 +293,40 @@ def test_list_messages_all(s3_manager, sample_session, sample_agent): assert len(result) == 5 +def test_list_messages_single_message(s3_manager, sample_session, sample_agent): + """Test listing all messages from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Create single message + message = SessionMessage( + { + "role": "user", + "content": [ContentBlock(text="Single Message")], + }, + 0, + ) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List all messages + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 1 + + +def test_list_no_messages(s3_manager, sample_session, sample_agent): + """Test listing all messages from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # List all messages + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 0 + + def test_list_messages_with_pagination(s3_manager, sample_session, sample_agent): """Test listing messages with pagination in S3.""" # Create session and agent @@ -335,6 +380,24 @@ def test_update_nonexistent_message(s3_manager, sample_session, sample_agent, sa s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) +@pytest.mark.parametrize( + "prefix, expected_path", + [ + ("", "session_test-id/"), + ("sessions", "sessions/session_test-id/"), + ("sessions/", "sessions/session_test-id/"), + ("/sessions", "sessions/session_test-id/"), + ("/sessions/", "sessions/session_test-id/"), + ("a/b/c", "a/b/c/session_test-id/"), + ("a/b/c/", "a/b/c/session_test-id/"), + ], +) +def test__get_session_path_prefix_normalization(mocked_aws, s3_bucket, prefix, expected_path): + """Test that _get_session_path normalizes prefix to avoid leading or double slashes.""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, prefix=prefix, region_name="us-west-2") + assert manager._get_session_path("test-id") == expected_path + + @pytest.mark.parametrize( "session_id", [ diff --git a/tests/strands/telemetry/test_config.py b/tests/strands/telemetry/test_config.py index 658d4d08a..cc08c295c 100644 --- a/tests/strands/telemetry/test_config.py +++ b/tests/strands/telemetry/test_config.py @@ -2,6 +2,7 @@ import pytest +import strands.telemetry.config as telemetry_config from strands.telemetry import StrandsTelemetry @@ -212,3 +213,21 @@ def test_setup_otlp_exporter_exception(mock_resource, mock_tracer_provider, mock telemetry.setup_otlp_exporter() mock_otlp_exporter.assert_called_once() + + +def test_get_otel_resource_uses_default_service_name(monkeypatch): + monkeypatch.delenv("OTEL_SERVICE_NAME", raising=False) + monkeypatch.setattr(telemetry_config, "version", lambda _: "0.0.0") + + resource = telemetry_config.get_otel_resource() + + assert resource.attributes.get("service.name") == "strands-agents" + + +def test_get_otel_resource_respects_otel_service_name(monkeypatch): + monkeypatch.setenv("OTEL_SERVICE_NAME", "my-service") + monkeypatch.setattr(telemetry_config, "version", lambda _: "0.0.0") + + resource = telemetry_config.get_otel_resource() + + assert resource.attributes.get("service.name") == "my-service" diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index e87277eed..7d54c0cc6 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -240,9 +240,15 @@ def test_tool_metrics_add_call(success, tool, tool_metrics, mock_get_meter_provi @unittest.mock.patch.object(strands.telemetry.metrics.uuid, "uuid4") def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - mock_uuid4.return_value = "i1" + mock_event_loop_cycle_id = "i1" + mock_uuid4.return_value = mock_event_loop_cycle_id - tru_start_time, tru_cycle_trace = event_loop_metrics.start_cycle() + # Reset must be called first + event_loop_metrics.reset_usage_metrics() + + tru_start_time, tru_cycle_trace = event_loop_metrics.start_cycle( + attributes={"event_loop_cycle_id": mock_event_loop_cycle_id} + ) exp_start_time, exp_cycle_trace = 1, strands.telemetry.metrics.Trace("Cycle 1") tru_attrs = {"cycle_count": event_loop_metrics.cycle_count, "traces": event_loop_metrics.traces} @@ -256,6 +262,13 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric and tru_attrs == exp_attrs ) + assert len(event_loop_metrics.agent_invocations) == 1 + assert len(event_loop_metrics.agent_invocations[0].cycles) == 1 + assert event_loop_metrics.agent_invocations[0].cycles[0].event_loop_cycle_id == "i1" + assert event_loop_metrics.agent_invocations[0].cycles[0].usage["inputTokens"] == 0 + assert event_loop_metrics.agent_invocations[0].cycles[0].usage["outputTokens"] == 0 + assert event_loop_metrics.agent_invocations[0].cycles[0].usage["totalTokens"] == 0 + @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics, mock_get_meter_provider): @@ -324,6 +337,9 @@ def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_me def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_meter_provider): + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "test-cycle"}) + for _ in range(3): event_loop_metrics.update_usage(usage) @@ -331,6 +347,14 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_met exp_usage = Usage(inputTokens=3, outputTokens=6, totalTokens=9, cacheWriteInputTokens=6) assert tru_usage == exp_usage + + assert event_loop_metrics.latest_agent_invocation.usage == exp_usage + + assert len(event_loop_metrics.agent_invocations) == 1 + assert len(event_loop_metrics.agent_invocations[0].cycles) == 1 + assert event_loop_metrics.agent_invocations[0].cycles[0].event_loop_cycle_id == "test-cycle" + assert event_loop_metrics.agent_invocations[0].cycles[0].usage == exp_usage + mock_get_meter_provider.return_value.get_meter.assert_called() metrics_client = event_loop_metrics._metrics_client metrics_client.event_loop_input_tokens.record.assert_called() @@ -370,6 +394,7 @@ def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_ge "outputTokens": 0, "totalTokens": 0, }, + "agent_invocations": [], "average_cycle_time": 0, "tool_usage": { "tool1": { @@ -476,3 +501,156 @@ def test_use_ProxyMeter_if_no_global_meter_provider(): # Verify it's using a _ProxyMeter assert isinstance(metrics_client.meter, _ProxyMeter) + + +def test_latest_agent_invocation_property(usage, event_loop_metrics, mock_get_meter_provider): + """Test the latest_agent_invocation property getter""" + # Initially, no invocations exist + assert event_loop_metrics.latest_agent_invocation is None + + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "cycle-1"}) + event_loop_metrics.update_usage(usage) + + # latest_agent_invocation should return the first invocation + current = event_loop_metrics.latest_agent_invocation + assert current is not None + assert current.usage["inputTokens"] == 1 + assert len(current.cycles) == 1 + + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "cycle-2"}) + usage2 = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + event_loop_metrics.update_usage(usage2) + + # Should return the second invocation + current = event_loop_metrics.latest_agent_invocation + assert current is not None + assert current.usage["inputTokens"] == 10 + assert len(current.cycles) == 1 + + assert len(event_loop_metrics.agent_invocations) == 2 + + assert current is event_loop_metrics.agent_invocations[-1] + + +def test_reset_usage_metrics(usage, event_loop_metrics, mock_get_meter_provider): + """Test that reset_usage_metrics creates a new agent invocation but preserves accumulated_usage""" + # Add some usage across multiple cycles in first invocation + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "cycle-1"}) + event_loop_metrics.update_usage(usage) + + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "cycle-2"}) + usage2 = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + event_loop_metrics.update_usage(usage2) + + assert len(event_loop_metrics.agent_invocations) == 1 + assert event_loop_metrics.latest_agent_invocation.usage["inputTokens"] == 11 + assert len(event_loop_metrics.latest_agent_invocation.cycles) == 2 + assert event_loop_metrics.accumulated_usage["inputTokens"] == 11 + + # Reset - creates a new invocation + event_loop_metrics.reset_usage_metrics() + + assert len(event_loop_metrics.agent_invocations) == 2 + + assert event_loop_metrics.latest_agent_invocation.usage["inputTokens"] == 0 + assert event_loop_metrics.latest_agent_invocation.usage["outputTokens"] == 0 + assert event_loop_metrics.latest_agent_invocation.usage["totalTokens"] == 0 + assert len(event_loop_metrics.latest_agent_invocation.cycles) == 0 + + # Verify first invocation data is preserved + assert event_loop_metrics.agent_invocations[0].usage["inputTokens"] == 11 + assert len(event_loop_metrics.agent_invocations[0].cycles) == 2 + + # Verify accumulated_usage is NOT cleared + assert event_loop_metrics.accumulated_usage["inputTokens"] == 11 + + +def test_latest_context_size_no_invocations(event_loop_metrics): + assert event_loop_metrics.latest_context_size is None + + +def test_latest_context_size_invocation_with_no_cycles(event_loop_metrics): + event_loop_metrics.reset_usage_metrics() + assert event_loop_metrics.latest_context_size is None + + +def test_latest_context_size_returns_last_cycle(event_loop_metrics, mock_get_meter_provider): + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c1"}) + event_loop_metrics.update_usage(Usage(inputTokens=100, outputTokens=50, totalTokens=150)) + + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c2"}) + event_loop_metrics.update_usage(Usage(inputTokens=250, outputTokens=80, totalTokens=330)) + + assert event_loop_metrics.latest_context_size == 250 + + +def test_latest_context_size_returns_from_latest_invocation(event_loop_metrics, mock_get_meter_provider): + # First invocation + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c1"}) + event_loop_metrics.update_usage(Usage(inputTokens=100, outputTokens=50, totalTokens=150)) + + # Second invocation + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c2"}) + event_loop_metrics.update_usage(Usage(inputTokens=500, outputTokens=80, totalTokens=580)) + + assert event_loop_metrics.latest_context_size == 500 + + +def test_latest_context_size_missing_input_tokens_key(event_loop_metrics): + """Returns None when usage dict is missing inputTokens (e.g. provider bug).""" + event_loop_metrics.reset_usage_metrics() + invocation = event_loop_metrics.agent_invocations[-1] + invocation.cycles.append( + strands.telemetry.metrics.EventLoopCycleMetric( + event_loop_cycle_id="c1", + usage={"outputTokens": 50, "totalTokens": 50}, + ) + ) + assert event_loop_metrics.latest_context_size is None + + +def test_projected_context_size_no_invocations(event_loop_metrics): + assert event_loop_metrics.projected_context_size is None + + +def test_projected_context_size_invocation_with_no_cycles(event_loop_metrics): + event_loop_metrics.reset_usage_metrics() + assert event_loop_metrics.projected_context_size is None + + +def test_projected_context_size_returns_input_plus_output(event_loop_metrics, mock_get_meter_provider): + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c1"}) + event_loop_metrics.update_usage(Usage(inputTokens=100, outputTokens=50, totalTokens=150)) + + assert event_loop_metrics.projected_context_size == 150 + + +def test_projected_context_size_updates_across_cycles(event_loop_metrics, mock_get_meter_provider): + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c1"}) + event_loop_metrics.update_usage(Usage(inputTokens=100, outputTokens=50, totalTokens=150)) + + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c2"}) + event_loop_metrics.update_usage(Usage(inputTokens=200, outputTokens=80, totalTokens=280)) + + assert event_loop_metrics.projected_context_size == 280 + + +def test_projected_context_size_missing_tokens_key(event_loop_metrics): + """Returns None when usage dict is missing inputTokens or outputTokens.""" + event_loop_metrics.reset_usage_metrics() + invocation = event_loop_metrics.agent_invocations[-1] + invocation.cycles.append( + strands.telemetry.metrics.EventLoopCycleMetric( + event_loop_cycle_id="c1", + usage={"outputTokens": 50, "totalTokens": 50}, + ) + ) + assert event_loop_metrics.projected_context_size is None diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 205748956..c7b096a5a 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -1,4 +1,5 @@ import json +import logging import os from datetime import date, datetime, timezone from unittest import mock @@ -79,22 +80,11 @@ def test_start_span(mock_tracer): span = tracer._start_span("test_span", attributes={"key": "value"}) mock_tracer.start_span.assert_called_once_with(name="test_span", context=None, kind=SpanKind.INTERNAL) - mock_span.set_attribute.assert_any_call("key", "value") + # Check that set_attributes was called with the provided attributes + mock_span.set_attributes.assert_called_once_with({"key": "value"}) assert span is not None -def test_set_attributes(mock_span): - """Test setting attributes on a span.""" - tracer = Tracer() - attributes = {"str_attr": "value", "int_attr": 123, "bool_attr": True} - - tracer._set_attributes(mock_span, attributes) - - # Check that set_attribute was called for each attribute - calls = [mock.call(k, v) for k, v in attributes.items()] - mock_span.set_attribute.assert_has_calls(calls, any_order=True) - - def test_end_span_no_span(): """Test ending a span when span is None.""" tracer = Tracer() @@ -109,7 +99,8 @@ def test_end_span(mock_span): tracer._end_span(mock_span, attributes) - mock_span.set_attribute.assert_any_call("key", "value") + # Check that set_attributes was called with the provided attributes + mock_span.set_attributes.assert_called_once_with({"key": "value"}) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -138,6 +129,30 @@ def test_end_span_with_error_message(mock_span): mock_span.end.assert_called_once() +def test_end_span_with_empty_exception_message_uses_exception_name(mock_span): + """Test that empty exception messages fall back to the exception type name.""" + tracer = Tracer() + error = Exception() + + tracer.end_span_with_error(mock_span, "", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_span_with_error_prefers_explicit_message(mock_span): + """Test that an explicit error message takes precedence over the exception text.""" + tracer = Tracer() + error = Exception() + + tracer.end_span_with_error(mock_span, "Explicit error message", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Explicit error message") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -150,22 +165,37 @@ def test_start_model_invoke_span(mock_tracer): messages = [{"role": "user", "content": [{"text": "Hello"}]}] model_id = "test-model" custom_attrs = {"custom_key": "custom_value", "user_id": "12345"} + system_prompt = "You are a helpful assistant" span = tracer.start_model_invoke_span( - messages=messages, agent_name="TestAgent", model_id=model_id, custom_trace_attributes=custom_attrs + messages=messages, + agent_name="TestAgent", + model_id=model_id, + custom_trace_attributes=custom_attrs, + system_prompt=system_prompt, ) mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat") - mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) - mock_span.set_attribute.assert_any_call("custom_key", "custom_value") - mock_span.set_attribute.assert_any_call("user_id", "12345") - mock_span.add_event.assert_called_with( - "gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])} + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "chat", + "gen_ai.system": "strands-agents", + "custom_key": "custom_value", + "user_id": "12345", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + } ) + + calls = mock_span.add_event.call_args_list + assert len(calls) == 2 + assert calls[0] == mock.call( + "gen_ai.system.message", + attributes={"content": serialize([{"text": system_prompt}])}, + ) + assert calls[1] == mock.call("gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])}) assert span is not None @@ -189,16 +219,34 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): }, ] model_id = "test-model" + system_prompt = "You are a calculator assistant" - span = tracer.start_model_invoke_span(messages=messages, agent_name="TestAgent", model_id=model_id) + span = tracer.start_model_invoke_span( + messages=messages, agent_name="TestAgent", model_id=model_id, system_prompt=system_prompt + ) mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat") - mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) - mock_span.add_event.assert_called_with( + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "chat", + "gen_ai.provider.name": "strands-agents", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + } + ) + + calls = mock_span.add_event.call_args_list + assert len(calls) == 2 + assert calls[0] == mock.call( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.system_instructions": serialize([{"type": "text", "content": system_prompt}]), + }, + ) + assert calls[1] == mock.call( "gen_ai.client.inference.operation.details", attributes={ "gen_ai.input.messages": serialize( @@ -225,6 +273,54 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): assert span is not None +def test_start_model_invoke_span_without_system_prompt(mock_tracer): + """Test that no system prompt event is emitted when system_prompt is None.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock() + mock_tracer.start_span.return_value = mock_span + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + span = tracer.start_model_invoke_span(messages=messages, model_id="test-model") + + assert mock_span.add_event.call_count == 1 + mock_span.add_event.assert_called_once_with( + "gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])} + ) + assert span is not None + + +def test_start_model_invoke_span_with_system_prompt_content(mock_tracer): + """Test that system_prompt_content takes priority over system_prompt string.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock() + mock_tracer.start_span.return_value = mock_span + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are helpful"}, {"text": "Be concise"}] + + span = tracer.start_model_invoke_span( + messages=messages, + model_id="test-model", + system_prompt="ignored string", + system_prompt_content=system_prompt_content, + ) + + calls = mock_span.add_event.call_args_list + assert len(calls) == 2 + assert calls[0] == mock.call( + "gen_ai.system.message", + attributes={"content": serialize(system_prompt_content)}, + ) + assert span is not None + + def test_end_model_invoke_span(mock_span): """Test ending a model invoke span.""" tracer = Tracer() @@ -235,13 +331,17 @@ def test_end_model_invoke_span(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 20) - mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 10) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + } + ) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -262,13 +362,17 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 10) - mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 20) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -283,7 +387,6 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): ), }, ) - mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -304,12 +407,17 @@ def test_start_tool_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool") - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool") - mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "123") - mock_span.set_attribute.assert_any_call("session_id", "abc123") - mock_span.set_attribute.assert_any_call("environment", "production") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.tool.name": "test-tool", + "gen_ai.system": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + "session_id": "abc123", + "environment": "production", + } + ) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -332,10 +440,15 @@ def test_start_tool_call_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool") - mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool") - mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "123") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.tool.name": "test-tool", + "gen_ai.provider.name": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -375,11 +488,16 @@ def test_start_swarm_call_span_with_string_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm") - mock_span.set_attribute.assert_any_call("workflow_id", "wf-789") - mock_span.set_attribute.assert_any_call("priority", "high") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + "workflow_id": "wf-789", + "priority": "high", + } + ) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": "Design foo bar"}) assert span is not None @@ -399,9 +517,14 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + } + ) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": '[{"text": "Original Task: foo bar"}]'} ) @@ -452,9 +575,14 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "swarm", + } + ) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -517,10 +645,15 @@ def test_start_graph_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool") - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool") - mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "123") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "execute_tool", + "gen_ai.system": "strands-agents", + "gen_ai.tool.name": "test-tool", + "gen_ai.tool.call.id": "123", + } + ) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -534,7 +667,7 @@ def test_end_tool_call_span(mock_span): tracer.end_tool_call_span(mock_span, tool_result) - mock_span.set_attribute.assert_any_call("gen_ai.tool.status", "success") + mock_span.set_attributes.assert_called_once_with({"gen_ai.tool.status": "success"}) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(tool_result.get("content")), "id": ""}, @@ -551,7 +684,7 @@ def test_end_tool_call_span_latest_conventions(mock_span, monkeypatch): tracer.end_tool_call_span(mock_span, tool_result) - mock_span.set_attribute.assert_any_call("gen_ai.tool.status", "success") + mock_span.set_attributes.assert_called_once_with({"gen_ai.tool.status": "success"}) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -575,6 +708,33 @@ def test_end_tool_call_span_latest_conventions(mock_span, monkeypatch): mock_span.end.assert_called_once() +def test_end_tool_call_span_with_error(mock_span): + """Test ending a tool call span with an explicit error sets StatusCode.ERROR.""" + tracer = Tracer() + error = ValueError("tool exploded") + tool_result = {"status": "error", "content": [{"text": "Error: tool exploded"}]} + + tracer.end_tool_call_span(mock_span, tool_result, error=error) + + mock_span.set_attributes.assert_called_once_with({"gen_ai.tool.status": "error"}) + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "tool exploded") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_tool_call_span_error_result_no_exception(mock_span): + """Test that an error result without an exception still sets StatusCode.ERROR.""" + tracer = Tracer() + tool_result = {"status": "error", "content": [{"text": "tool cancelled by user"}]} + + tracer.end_tool_call_span(mock_span, tool_result) + + mock_span.set_attributes.assert_called_once_with({"gen_ai.tool.status": "error"}) + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "tool cancelled by user") + mock_span.record_exception.assert_not_called() + mock_span.end.assert_called_once() + + def test_start_event_loop_cycle_span(mock_tracer): """Test starting an event loop cycle span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -594,9 +754,16 @@ def test_start_event_loop_cycle_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - mock_span.set_attribute.assert_any_call("event_loop.cycle_id", "cycle-123") - mock_span.set_attribute.assert_any_call("request_id", "req-456") - mock_span.set_attribute.assert_any_call("trace_level", "debug") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "execute_event_loop_cycle", + "gen_ai.system": "strands-agents", + "event_loop.cycle_id": "cycle-123", + "request_id": "req-456", + "trace_level": "debug", + } + ) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": json.dumps([{"text": "Hello"}])} ) @@ -620,7 +787,14 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - mock_span.set_attribute.assert_any_call("event_loop.cycle_id", "cycle-123") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "execute_event_loop_cycle", + "gen_ai.provider.name": "strands-agents", + "event_loop.cycle_id": "cycle-123", + } + ) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -716,10 +890,17 @@ def test_start_agent_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent") - mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) - mock_span.set_attribute.assert_any_call("custom_attr", "value") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_agent", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + } + ) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": json.dumps(content)}) assert span is not None @@ -749,10 +930,17 @@ def test_start_agent_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" - mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent") - mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) - mock_span.set_attribute.assert_any_call("custom_attr", "value") + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_agent", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + } + ) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -779,13 +967,53 @@ def test_end_agent_span(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 0, + "gen_ai.usage.cache_write_input_tokens": 0, + } + ) + mock_span.add_event.assert_any_call( + "gen_ai.choice", + attributes={"message": "Agent response", "finish_reason": "end_turn"}, + ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() + + +def test_end_agent_span_with_langfuse_observation_type(mock_span, monkeypatch): + """Test ending an agent span with Langfuse observation type to prevent double counting the tokens.""" + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://us.cloud.langfuse.com") + tracer = Tracer() + + # Mock AgentResult with metrics + mock_metrics = mock.MagicMock() + mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150} + + mock_response = mock.MagicMock() + mock_response.metrics = mock_metrics + mock_response.stop_reason = "end_turn" + mock_response.__str__ = mock.MagicMock(return_value="Agent response") + + tracer.end_agent_span(mock_span, mock_response) + + mock_span.set_attributes.assert_called_once_with( + { + "langfuse.observation.type": "span", + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 0, + "gen_ai.usage.cache_write_input_tokens": 0, + } + ) mock_span.add_event.assert_any_call( "gen_ai.choice", attributes={"message": "Agent response", "finish_reason": "end_turn"}, @@ -810,13 +1038,17 @@ def test_end_agent_span_latest_conventions(mock_span, monkeypatch): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 0, + "gen_ai.usage.cache_write_input_tokens": 0, + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -835,6 +1067,57 @@ def test_end_agent_span_latest_conventions(mock_span, monkeypatch): mock_span.end.assert_called_once() +def test_end_agent_span_uses_per_invocation_usage_when_opted_in(mock_span, monkeypatch): + """Test that agent span reports per-invocation usage when gen_ai_use_latest_invocation_tokens is set.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_use_latest_invocation_tokens") + tracer = Tracer() + + mock_invocation = mock.MagicMock() + mock_invocation.usage = {"inputTokens": 100, "outputTokens": 50, "totalTokens": 150} + + mock_metrics = mock.MagicMock() + mock_metrics.accumulated_usage = {"inputTokens": 1000, "outputTokens": 500, "totalTokens": 1500} + mock_metrics.latest_agent_invocation = mock_invocation + + mock_response = mock.MagicMock() + mock_response.metrics = mock_metrics + mock_response.stop_reason = "end_turn" + mock_response.__str__ = mock.MagicMock(return_value="Agent response") + + tracer.end_agent_span(mock_span, mock_response) + + call_args = mock_span.set_attributes.call_args[0][0] + assert call_args["gen_ai.usage.input_tokens"] == 100 + assert call_args["gen_ai.usage.output_tokens"] == 50 + assert call_args["gen_ai.usage.total_tokens"] == 150 + assert call_args["gen_ai.usage.prompt_tokens"] == 100 + assert call_args["gen_ai.usage.completion_tokens"] == 50 + + +def test_end_agent_span_warns_when_opted_in_but_no_invocations(mock_span, monkeypatch, caplog): + """Test warning and zero usage when opted in but no agent invocations exist.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_use_latest_invocation_tokens") + tracer = Tracer() + + mock_metrics = mock.MagicMock() + mock_metrics.accumulated_usage = {"inputTokens": 200, "outputTokens": 100, "totalTokens": 300} + mock_metrics.latest_agent_invocation = None + + mock_response = mock.MagicMock() + mock_response.metrics = mock_metrics + mock_response.stop_reason = "end_turn" + mock_response.__str__ = mock.MagicMock(return_value="Agent response") + + with caplog.at_level(logging.WARNING): + tracer.end_agent_span(mock_span, mock_response) + + assert "latest_agent_invocation is None" in caplog.text + call_args = mock_span.set_attributes.call_args[0][0] + assert call_args["gen_ai.usage.input_tokens"] == 0 + assert call_args["gen_ai.usage.output_tokens"] == 0 + assert call_args["gen_ai.usage.total_tokens"] == 0 + + def test_end_model_invoke_span_with_cache_metrics(mock_span): """Test ending a model invoke span with cache metrics.""" tracer = Tracer() @@ -851,15 +1134,19 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 5) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 3) - mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 10) - mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 5) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.usage.cache_read_input_tokens": 5, + "gen_ai.usage.cache_write_input_tokens": 3, + "gen_ai.server.request.duration": 10, + "gen_ai.server.time_to_first_token": 5, + } + ) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -885,13 +1172,17 @@ def test_end_agent_span_with_cache_metrics(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 25) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 10) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 25, + "gen_ai.usage.cache_write_input_tokens": 10, + } + ) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -943,21 +1234,6 @@ def test_end_span_with_exception_handling(mock_span): pytest.fail("_end_span should not raise exceptions") -def test_force_flush_with_error(mock_span, mock_get_tracer_provider): - """Test force flush with error handling.""" - # Setup the tracer with a provider that raises an exception on force_flush - tracer = Tracer() - - mock_tracer_provider = mock_get_tracer_provider.return_value - mock_tracer_provider.force_flush.side_effect = Exception("Force flush error") - - # Should not raise an exception - tracer._end_span(mock_span) - - # Verify force_flush was called - mock_tracer_provider.force_flush.assert_called_once() - - def test_end_tool_call_span_with_none(mock_span): """Test ending a tool call span with None result.""" tracer = Tracer() @@ -1425,3 +1701,129 @@ def test_start_agent_span_includes_tool_definitions_when_enabled(monkeypatch): ] expected_json = serialize(expected_tool_details) assert attributes["gen_ai.tool.definitions"] == expected_json + + +def test_end_model_invoke_span_langfuse_adds_attributes(mock_span, monkeypatch): + """Test that end_model_invoke_span adds attributes via set_attributes for Langfuse.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://us.cloud.langfuse.com") + + tracer = Tracer() + message = {"role": "assistant", "content": [{"text": "Response"}]} + usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) + stop_reason: StopReason = "end_turn" + + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) + + expected_output = serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": "Response"}], + "finish_reason": "end_turn", + } + ] + ) + + assert mock_span.set_attributes.call_count == 2 + mock_span.set_attributes.assert_any_call({"gen_ai.output.messages": expected_output}) + mock_span.set_attributes.assert_any_call( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + } + ) + + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={"gen_ai.output.messages": expected_output}, + ) + + +def test_end_model_invoke_span_non_langfuse_no_extra_attributes(mock_span, monkeypatch): + """Test that end_model_invoke_span doesn't add extra attributes for non-Langfuse endpoints.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://api.honeycomb.io") + + tracer = Tracer() + message = {"role": "assistant", "content": [{"text": "Response"}]} + usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) + stop_reason: StopReason = "end_turn" + + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) + + # Verify that set_attribute was NOT called with gen_ai.output.messages + # (it should only be in the event, not as an attribute) + expected_output = serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": "Response"}], + "finish_reason": "end_turn", + } + ] + ) + + # Check that gen_ai.output.messages was not set as an attribute + set_attribute_calls = [call[0][0] for call in mock_span.set_attribute.call_args_list] + assert "gen_ai.output.messages" not in set_attribute_calls + + # But verify that add_event was still called + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={"gen_ai.output.messages": expected_output}, + ) + + +class TestIsLangfuse: + """Tests for the is_langfuse property.""" + + def test_is_langfuse_with_otel_exporter_otlp_endpoint(self, monkeypatch): + """Test is_langfuse returns True when OTEL_EXPORTER_OTLP_ENDPOINT contains langfuse.""" + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://us.cloud.langfuse.com") + tracer = Tracer() + assert tracer.is_langfuse is True + + def test_is_langfuse_with_otel_exporter_otlp_traces_endpoint(self, monkeypatch): + """Test is_langfuse returns True when OTEL_EXPORTER_OTLP_TRACES_ENDPOINT contains langfuse.""" + monkeypatch.setenv( + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "https://us.cloud.langfuse.com/api/public/otel/v1/traces" + ) + tracer = Tracer() + assert tracer.is_langfuse is True + + def test_is_langfuse_with_langfuse_base_url(self, monkeypatch): + """Test is_langfuse returns True when LANGFUSE_BASE_URL contains langfuse.""" + monkeypatch.setenv("LANGFUSE_BASE_URL", "https://us.cloud.langfuse.com") + tracer = Tracer() + assert tracer.is_langfuse is True + + def test_is_langfuse_false_when_no_langfuse_env_vars(self, monkeypatch): + """Test is_langfuse returns False when no Langfuse-related env vars are set.""" + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False) + monkeypatch.delenv("LANGFUSE_BASE_URL", raising=False) + tracer = Tracer() + assert tracer.is_langfuse is False + + def test_is_langfuse_false_with_non_langfuse_endpoint(self, monkeypatch): + """Test is_langfuse returns False when endpoint is not Langfuse.""" + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://api.honeycomb.io") + monkeypatch.delenv("LANGFUSE_BASE_URL", raising=False) + tracer = Tracer() + assert tracer.is_langfuse is False + + def test_is_langfuse_false_with_non_langfuse_base_url(self, monkeypatch): + """Test is_langfuse returns False when LANGFUSE_BASE_URL doesn't contain langfuse.""" + monkeypatch.setenv("LANGFUSE_BASE_URL", "https://some-other-service.com") + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False) + tracer = Tracer() + assert tracer.is_langfuse is False diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py index 9c14cc63b..5c928cc81 100644 --- a/tests/strands/test_interrupt.py +++ b/tests/strands/test_interrupt.py @@ -127,3 +127,67 @@ def test_interrupt_resume_invalid_id(): exp_message = r"interrupt_id= \| no interrupt found" with pytest.raises(KeyError, match=exp_message): interrupt_state.resume([{"interruptResponse": {"interruptId": "invalid", "response": None}}]) + + +# ============================================================================ +# Version Tracking Tests +# ============================================================================ + + +def test_interrupt_state_version_is_zero_after_initialization(): + """Test that _get_version() returns 0 after initialization.""" + interrupt_state = _InterruptState() + assert interrupt_state._get_version() == 0 + + +def test_interrupt_state_version_increments_after_activate(): + """Test that _get_version() increments after activate() is called.""" + interrupt_state = _InterruptState() + assert interrupt_state._get_version() == 0 + + interrupt_state.activate() + assert interrupt_state._get_version() == 1 + + +def test_interrupt_state_version_increments_after_deactivate(): + """Test that _get_version() increments after deactivate() is called.""" + interrupt_state = _InterruptState(activated=True) + initial_version = interrupt_state._get_version() + + interrupt_state.deactivate() + assert interrupt_state._get_version() == initial_version + 1 + + +def test_interrupt_state_version_increments_after_resume(): + """Test that _get_version() increments after resume() is called.""" + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + activated=True, + ) + initial_version = interrupt_state._get_version() + + prompt = [{"interruptResponse": {"interruptId": "test_id", "response": "test response"}}] + interrupt_state.resume(prompt) + assert interrupt_state._get_version() == initial_version + 1 + + +def test_interrupt_state_version_increments_independently(): + """Test that version increments independently for each operation.""" + interrupt_state = _InterruptState() + assert interrupt_state._get_version() == 0 + + interrupt_state.activate() + assert interrupt_state._get_version() == 1 + + interrupt_state.deactivate() + assert interrupt_state._get_version() == 2 + + +def test_interrupt_state_version_not_in_to_dict(): + """Test that _version is not included in to_dict() output.""" + interrupt_state = _InterruptState() + interrupt_state.activate() + + data = interrupt_state.to_dict() + assert "_version" not in data + assert "version" not in data diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index ad92ba603..8ecbe2f88 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -1,3 +1,4 @@ +import asyncio import threading import unittest.mock @@ -90,13 +91,24 @@ def func(tool_context: ToolContext) -> str: @pytest.fixture -def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, interrupt_tool): +def slow_tool(): + @strands.tool(name="slow_tool") + async def func(): + """A tool that blocks until cancelled.""" + await asyncio.sleep(3) + + return func + + +@pytest.fixture +def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, interrupt_tool, slow_tool): registry = ToolRegistry() registry.register_tool(weather_tool) registry.register_tool(temperature_tool) registry.register_tool(exception_tool) registry.register_tool(thread_tool) registry.register_tool(interrupt_tool) + registry.register_tool(slow_tool) return registry diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index ce07ee4ce..a8ac05830 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,6 +1,6 @@ import pytest -from strands.hooks import BeforeToolCallEvent +from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.interrupt import Interrupt from strands.tools.executors import ConcurrentToolExecutor from strands.tools.structured_output._structured_output_context import StructuredOutputContext @@ -76,3 +76,30 @@ def interrupt_callback(event): tru_results = tool_results exp_results = [exp_events[1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_concurrent_executor_reraises_exceptions( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context, alist +): + """Test that hook re-raised exceptions propagate and cancel remaining tasks.""" + + def reraise_callback(event): + if event.exception is not None: + raise event.exception + + agent.hooks.add_callback(AfterToolCallEvent, reraise_callback) + + tool_uses = [ + {"name": "exception_tool", "toolUseId": "1", "input": {}}, + {"name": "slow_tool", "toolUseId": "2", "input": {}}, + ] + + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) + + with pytest.raises(RuntimeError, match="Tool error"): + await alist(stream) + + assert tool_results == [] diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 8139fbf66..9c38340b9 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -4,6 +4,7 @@ import pytest import strands +from strands.experimental.hooks.events import BidiAfterToolCallEvent from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.interrupt import Interrupt from strands.telemetry.metrics import Trace @@ -188,6 +189,7 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results tool_use=tool_use, invocation_state=invocation_state, result=exp_results[0], + exception=unittest.mock.ANY, ) assert tru_hook_after_event == exp_hook_after_event @@ -215,6 +217,7 @@ async def test_executor_stream_with_trace( tracer.end_tool_call_span.assert_called_once_with( tracer.start_tool_call_span.return_value, {"content": [{"text": "sunny"}], "status": "success", "toolUseId": "1"}, + error=None, ) cycle_trace.add_child.assert_called_once() @@ -463,6 +466,57 @@ async def test_executor_stream_tool_interrupt_resume(executor, agent, tool_resul assert tru_results == exp_results +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_registers_on_agent( + executor, agent, tool_results, invocation_state, alist +): + """ToolInterruptEvent from a tool should register interrupts in the agent's _interrupt_state.""" + # Create a tool that yields a ToolInterruptEvent with an interrupt NOT pre-registered on the agent + # (simulates _AgentAsTool propagating sub-agent interrupts). + foreign_interrupt = Interrupt(id="sub-agent-interrupt-1", name="approval", reason="need approval") + + @strands.tool(name="agent_tool") + def agent_tool_func(): + return "unused" + + async def mock_stream(_tool_use, _invocation_state, **_kwargs): + yield ToolInterruptEvent(_tool_use, [foreign_interrupt]) + + agent_tool_func.stream = mock_stream + agent.tool_registry.register_tool(agent_tool_func) + + tool_use: ToolUse = {"name": "agent_tool", "toolUseId": "test_tool_id", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + events = await alist(stream) + + # Should yield the interrupt event + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + + # The interrupt should now be registered on the agent's _interrupt_state + assert "sub-agent-interrupt-1" in agent._interrupt_state.interrupts + assert agent._interrupt_state.interrupts["sub-agent-interrupt-1"] is foreign_interrupt + + +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_does_not_overwrite_existing( + executor, agent, tool_results, invocation_state, alist +): + """setdefault should not overwrite interrupts already in the agent's state (normal hook case).""" + tool_use = {"name": "interrupt_tool", "toolUseId": "test_tool_id", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + # The interrupt_tool hook registered the interrupt via _Interruptible.interrupt(). + # The executor's setdefault should have been a no-op for this pre-registered interrupt. + registered = agent._interrupt_state.interrupts + assert len(registered) == 1 + interrupt = next(iter(registered.values())) + assert interrupt.name == "test_name" + assert interrupt.reason == "test reason" + + @pytest.mark.asyncio async def test_executor_stream_updates_invocation_state_with_agent( executor, agent, tool_results, invocation_state, weather_tool, alist @@ -479,3 +533,461 @@ async def test_executor_stream_updates_invocation_state_with_agent( # Verify that the invocation_state was updated with the agent assert "agent" in empty_invocation_state assert empty_invocation_state["agent"] is agent + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_exception_in_hook( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that exceptions from @tool-decorated functions reach AfterToolCallEvent.""" + exception = ValueError("decorated tool error") + + @strands.tool(name="decorated_error_tool") + def failing_tool(): + """A tool that raises an exception.""" + raise exception + + agent.tool_registry.register_tool(failing_tool) + tool_use = {"name": "decorated_error_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is exception + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_runtime_error_in_hook( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that RuntimeError from @tool-decorated functions reach AfterToolCallEvent.""" + exception = RuntimeError("runtime error from decorated tool") + + @strands.tool(name="runtime_error_tool") + def runtime_error_tool(): + """A tool that raises a RuntimeError.""" + raise exception + + agent.tool_registry.register_tool(runtime_error_tool) + tool_use = {"name": "runtime_error_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is exception + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_no_exception_on_success( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that AfterToolCallEvent.exception is None when decorated tool succeeds.""" + + @strands.tool(name="success_decorated_tool") + def success_tool(): + """A tool that succeeds.""" + return "success" + + agent.tool_registry.register_tool(success_tool) + tool_use = {"name": "success_decorated_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.result["status"] == "success" + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_error_result_without_exception( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that exception is None when a tool returns an error result without throwing.""" + + @strands.tool(name="error_result_tool") + def error_result_tool(): + """A tool that returns an error result dict without raising.""" + return {"status": "error", "content": [{"text": "something went wrong"}]} + + agent.tool_registry.register_tool(error_result_tool) + tool_use = {"name": "error_result_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.result["status"] == "error" + + +@pytest.mark.asyncio +async def test_executor_stream_no_retry_set(executor, agent, tool_results, invocation_state, alist): + """Test default behavior when retry is not set - tool executes once.""" + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool should be called exactly once + assert call_count["count"] == 1 + + # Single result event with first attempt's content + assert len(tru_events) == 1 + assert tru_events[0].tool_result == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + # tool_results should contain the result + assert len(tool_results) == 1 + assert tool_results[0] == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + +@pytest.mark.asyncio +async def test_executor_stream_retry_true(executor, agent, tool_results, invocation_state, alist): + """Test that retry=True causes tool re-execution.""" + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + # Set retry=True on first call only + def retry_once(event): + if isinstance(event, AfterToolCallEvent) and call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_once) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool should be called twice due to retry + assert call_count["count"] == 2 + + # Only final result is yielded (first attempt's result was discarded) + assert len(tru_events) == 1 + assert tru_events[0].tool_result == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_2"}]} + + # tool_results only contains the final result + assert len(tool_results) == 1 + assert tool_results[0] == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_2"}]} + + +@pytest.mark.asyncio +async def test_executor_stream_retry_true_emits_events_from_both_attempts( + executor, agent, tool_results, invocation_state, alist +): + """Test that ToolStreamEvents from discarded attempt ARE emitted, but ToolResultEvent is NOT. + + This validates the documented behavior: 'Streaming events from the discarded + tool execution will have already been emitted to callers before the retry occurs.' + + Key distinction: + - ToolStreamEvent (intermediate): Yielded immediately, visible from BOTH attempts + - ToolResultEvent (final): Only yielded for the final attempt, discarded on retry + """ + call_count = {"count": 0} + + @strands.tool(name="streaming_tool") + def streaming_tool(): + return "unused" + + # Provide streaming implementation (same pattern as exception_tool fixture) + async def tool_stream(_tool_use, _invocation_state, **kwargs): + call_count["count"] += 1 + yield f"streaming_from_attempt_{call_count['count']}" + yield ToolResultEvent( + {"toolUseId": "1", "status": "success", "content": [{"text": f"result_{call_count['count']}"}]} + ) + + streaming_tool.stream = tool_stream + agent.tool_registry.register_tool(streaming_tool) + + # Set retry=True on first call + def retry_once(event): + if isinstance(event, AfterToolCallEvent) and call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_once) + + tool_use: ToolUse = {"name": "streaming_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool called twice + assert call_count["count"] == 2 + + # Streaming events from BOTH attempts are emitted (documented behavior) + stream_events = [e for e in tru_events if isinstance(e, ToolStreamEvent)] + assert len(stream_events) == 2 + assert stream_events[0] == ToolStreamEvent(tool_use, "streaming_from_attempt_1") + assert stream_events[1] == ToolStreamEvent(tool_use, "streaming_from_attempt_2") + + # Only final ToolResultEvent is emitted + result_events = [e for e in tru_events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0].tool_result["content"][0]["text"] == "result_2" + + +@pytest.mark.asyncio +async def test_executor_stream_retry_false(executor, agent, tool_results, invocation_state, alist): + """Test that explicitly setting retry=False does not retry.""" + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + # Explicitly set retry=False + def no_retry(event): + if isinstance(event, AfterToolCallEvent): + event.retry = False + return event + + agent.hooks.add_callback(AfterToolCallEvent, no_retry) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool should be called exactly once + assert call_count["count"] == 1 + + # Single result event + assert len(tru_events) == 1 + assert tru_events[0].tool_result == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + # tool_results should contain the result + assert len(tool_results) == 1 + assert tool_results[0] == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + +@pytest.mark.asyncio +async def test_executor_stream_bidi_event_no_retry_attribute(executor, agent, tool_results, invocation_state, alist): + """Test that BidiAfterToolCallEvent (which lacks retry attribute) doesn't cause retry. + + This tests the getattr(after_event, "retry", False) fallback for events without retry. + """ + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + result: strands.types.tools.ToolResult = { + "toolUseId": "1", + "status": "success", + "content": [{"text": "attempt_1"}], + } + + # Create a BidiAfterToolCallEvent (which has no retry attribute) + bidi_event = BidiAfterToolCallEvent( + agent=agent, + selected_tool=counting_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + ) + + # Patch _invoke_after_tool_call_hook to return BidiAfterToolCallEvent + async def mock_after_hook(*args, **kwargs): + return bidi_event, [] + + with unittest.mock.patch.object(ToolExecutor, "_invoke_after_tool_call_hook", mock_after_hook): + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + tru_events = await alist(stream) + + # Tool should be called once - no retry since BidiAfterToolCallEvent has no retry attr + assert call_count["count"] == 1 + + # Result should be returned + assert len(tru_events) == 1 + + +@pytest.mark.asyncio +async def test_executor_stream_retry_after_exception(executor, agent, tool_results, invocation_state, alist): + """Test that retry=True works when tool raises an exception. + + Covers the exception path retry check. + """ + call_count = {"count": 0} + + @strands.tool(name="flaky_tool") + def flaky_tool(): + call_count["count"] += 1 + if call_count["count"] == 1: + raise RuntimeError("First call fails") + return "success" + + agent.tool_registry.register_tool(flaky_tool) + + # Retry once on error (check result status, not exception attribute) + def retry_on_error(event): + if isinstance(event, AfterToolCallEvent) and event.result.get("status") == "error" and call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_on_error) + + tool_use: ToolUse = {"name": "flaky_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + tru_events = await alist(stream) + + # Tool called twice (1 exception + 1 success) + assert call_count["count"] == 2 + + # Final result is success + assert len(tru_events) == 1 + assert tru_events[0].tool_result["status"] == "success" + + +@pytest.mark.asyncio +async def test_executor_stream_retry_after_unknown_tool(executor, agent, tool_results, invocation_state, alist): + """Test that retry=True triggers retry loop for unknown tool. + + Covers the unknown tool path retry check. Tool lookup happens before retry loop, + so even after retry the tool remains unknown - this test verifies the retry + mechanism is triggered, not that it resolves the unknown tool. + """ + hook_call_count = {"count": 0} + + # Retry once on first unknown tool error + def retry_once_on_unknown(event): + if isinstance(event, AfterToolCallEvent): + hook_call_count["count"] += 1 + # Retry only on first call + if hook_call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_once_on_unknown) + + tool_use: ToolUse = {"name": "nonexistent_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + tru_events = await alist(stream) + + # Hook called twice (retry was triggered) + assert hook_call_count["count"] == 2 + + # Final result is still error (tool remains unknown after retry) + assert len(tru_events) == 1 + assert tru_events[0].tool_result["status"] == "error" + assert "Unknown tool" in tru_events[0].tool_result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_executor_stream_with_trace_error( + executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + """Test that _stream_with_trace passes the exception to end_tool_call_span when a tool fails.""" + tool_use: ToolUse = {"name": "exception_tool", "toolUseId": "1", "input": {}} + stream = executor._stream_with_trace(agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state) + + await alist(stream) + + tracer.end_tool_call_span.assert_called_once() + call_args = tracer.end_tool_call_span.call_args + assert call_args[0][1]["status"] == "error" + error_arg = call_args[1].get("error") + assert error_arg is not None + assert isinstance(error_arg, RuntimeError) + assert "Tool error" in str(error_arg) + + +@pytest.mark.asyncio +async def test_executor_stream_error_preserves_exception(executor, agent, tool_results, invocation_state, alist): + """Test that _stream yields a ToolResultEvent with the exception preserved.""" + tool_use: ToolUse = {"name": "exception_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + events = await alist(stream) + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.tool_result["status"] == "error" + assert result_event.exception is not None + assert isinstance(result_event.exception, RuntimeError) + assert "Tool error" in str(result_event.exception) + + +@pytest.mark.asyncio +async def test_executor_stream_unknown_tool_has_exception(executor, agent, tool_results, invocation_state, alist): + """Test that _stream yields a ToolResultEvent with exception for unknown tools.""" + tool_use: ToolUse = {"name": "nonexistent_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + events = await alist(stream) + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.tool_result["status"] == "error" + assert result_event.exception is not None + assert "Unknown tool" in str(result_event.exception) + + +@pytest.mark.asyncio +async def test_executor_stream_cancel_no_exception(executor, agent, tool_results, invocation_state, alist): + """Test that _stream yields a ToolResultEvent with no exception for cancelled tools.""" + + def cancel_callback(event): + event.cancel_tool = True + return event + + agent.hooks.add_callback(BeforeToolCallEvent, cancel_callback) + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + events = await alist(stream) + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.tool_result["status"] == "error" + assert result_event.exception is None + + +@pytest.mark.asyncio +async def test_executor_stream_cancel_after_hook_sees_no_exception( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that AfterToolCallEvent.exception is None when a tool is cancelled.""" + + def cancel_callback(event): + event.cancel_tool = "user denied permission" + return event + + agent.hooks.add_callback(BeforeToolCallEvent, cancel_callback) + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.cancel_message == "user denied permission" diff --git a/tests/strands/tools/mcp/conftest.py b/tests/strands/tools/mcp/conftest.py new file mode 100644 index 000000000..d0ac46bdc --- /dev/null +++ b/tests/strands/tools/mcp/conftest.py @@ -0,0 +1,61 @@ +"""Shared fixtures and helpers for MCP client tests.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_transport(): + """Create a mock MCP transport.""" + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_transport_cm = AsyncMock() + mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_transport_callable = MagicMock(return_value=mock_transport_cm) + + return { + "read_stream": mock_read_stream, + "write_stream": mock_write_stream, + "transport_cm": mock_transport_cm, + "transport_callable": mock_transport_callable, + } + + +@pytest.fixture +def mock_session(): + """Create a mock MCP session.""" + mock_session = AsyncMock() + mock_init_result = MagicMock() + mock_init_result.instructions = None + mock_session.initialize = AsyncMock(return_value=mock_init_result) + # Default: no task support (get_server_capabilities is sync, not async!) + mock_session.get_server_capabilities = MagicMock(return_value=None) + + # Create a mock context manager for ClientSession + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__.return_value = mock_session + + # Patch ClientSession to return our mock session + with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + yield mock_session + + +def create_server_capabilities(has_task_support: bool) -> MagicMock: + """Create mock server capabilities. + + Args: + has_task_support: Whether the server should advertise task support. + + Returns: + MagicMock representing server capabilities. + """ + caps = MagicMock() + if has_task_support: + caps.tasks = MagicMock() + caps.tasks.requests = MagicMock() + caps.tasks.requests.tools = MagicMock() + caps.tasks.requests.tools.call = MagicMock() + else: + caps.tasks = None + return caps diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index ec77b48a2..f270fa6fc 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,47 +1,31 @@ import base64 import time -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from mcp import ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult -from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage +from mcp.types import ( + GetPromptResult, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + Prompt, + PromptMessage, + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) from mcp.types import TextContent as MCPTextContent from mcp.types import Tool as MCPTool +from pydantic import AnyUrl from strands.tools.mcp import MCPClient from strands.tools.mcp.mcp_types import MCPToolResult from strands.types.exceptions import MCPClientInitializationError - -@pytest.fixture -def mock_transport(): - mock_read_stream = AsyncMock() - mock_write_stream = AsyncMock() - mock_transport_cm = AsyncMock() - mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) - mock_transport_callable = MagicMock(return_value=mock_transport_cm) - - return { - "read_stream": mock_read_stream, - "write_stream": mock_write_stream, - "transport_cm": mock_transport_cm, - "transport_callable": mock_transport_callable, - } - - -@pytest.fixture -def mock_session(): - mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - - # Create a mock context manager for ClientSession - mock_session_cm = AsyncMock() - mock_session_cm.__aenter__.return_value = mock_session - - # Patch ClientSession to return our mock session - with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): - yield mock_session +# Fixtures mock_transport and mock_session are imported from conftest.py @pytest.fixture @@ -66,6 +50,20 @@ def test_mcp_client_context_manager(mock_transport, mock_session): assert client._background_thread is None +def test_server_instructions_default(mock_transport, mock_session): + """Test that server_instructions defaults to None when server returns None.""" + mock_session.initialize.return_value.instructions = None + with MCPClient(mock_transport["transport_callable"]) as client: + assert client.server_instructions is None + + +def test_server_instructions_from_server(mock_transport, mock_session): + """Test that server_instructions is populated from InitializeResult.""" + mock_session.initialize.return_value.instructions = "Use tool A before tool B." + with MCPClient(mock_transport["transport_callable"]) as client: + assert client.server_instructions == "Use tool A before tool B." + + def test_list_tools_sync(mock_transport, mock_session): """Test that list_tools_sync correctly retrieves and adapts tools.""" mock_tool = MCPTool(name="test_tool", description="A test tool", inputSchema={"type": "object", "properties": {}}) @@ -126,7 +124,7 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) assert result["status"] == expected_status assert result["toolUseId"] == "test-123" @@ -134,6 +132,8 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ assert result["content"][0]["text"] == "Test message" # No structured content should be present when not provided by MCP assert result.get("structuredContent") is None + # isError mirrors the MCP server's explicit value; absent only for protocol/client exceptions + assert result.get("isError") is is_error def test_call_tool_sync_session_not_active(): @@ -155,7 +155,7 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) assert result["status"] == "success" assert result["toolUseId"] == "test-123" @@ -182,6 +182,51 @@ def test_call_tool_sync_exception(mock_transport, mock_session): assert "Test exception" in result["content"][0]["text"] +def test_call_tool_sync_forwards_meta(mock_transport, mock_session): + """Test that call_tool_sync forwards meta to ClientSession.call_tool.""" + mock_content = MCPTextContent(type="text", text="Test message") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + meta = {"com.example/request_id": "abc-123"} + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync( + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta + ) + + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_call_tool_async_forwards_meta(mock_transport, mock_session): + """Test that call_tool_async forwards meta to ClientSession.call_tool.""" + mock_content = MCPTextContent(type="text", text="Test message") + mock_result = MCPCallToolResult(isError=False, content=[mock_content]) + mock_session.call_tool.return_value = mock_result + meta = {"com.example/request_id": "abc-123"} + + with MCPClient(mock_transport["transport_callable"]) as client: + with ( + patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe, + patch("asyncio.wrap_future") as mock_wrap_future, + ): + mock_future = MagicMock() + mock_run_coroutine_threadsafe.return_value = mock_future + + async def mock_awaitable(): + return mock_result + + mock_wrap_future.return_value = mock_awaitable() + + result = await client.call_tool_async( + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta + ) + + mock_run_coroutine_threadsafe.assert_called_once() + + assert result["status"] == "success" + + @pytest.mark.asyncio @pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")]) async def test_call_tool_async_status(mock_transport, mock_session, is_error, expected_status): @@ -218,6 +263,8 @@ async def mock_awaitable(): assert result["toolUseId"] == "test-123" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "Test message" + # isError mirrors the MCP server's explicit value; absent only for protocol/client exceptions + assert result.get("isError") is is_error @pytest.mark.asyncio @@ -339,7 +386,10 @@ def test_enter_with_initialization_exception(mock_transport): client = MCPClient(mock_transport["transport_callable"]) with patch.object(client, "stop") as mock_stop: - with pytest.raises(MCPClientInitializationError, match="the client initialization failed"): + with pytest.raises( + MCPClientInitializationError, + match="the client initialization failed: Transport initialization failed", + ): client.start() # Verify stop() was called for cleanup @@ -365,6 +415,15 @@ def test_mcp_tool_result_type(): assert result_with_structured["structuredContent"] == {"key": "value"} + # isError is optional — absent by default + assert "isError" not in result + + # isError can be set to flag tool-reported application errors + result_with_is_error = MCPToolResult( + status="error", toolUseId="test-789", content=[{"text": "Tool failed"}], isError=True + ) + assert result_with_is_error["isError"] is True + def test_call_tool_sync_without_structured_content(mock_transport, mock_session): """Test that call_tool_sync works correctly when no structured content is provided.""" @@ -524,6 +583,60 @@ def test_stop_with_background_thread_but_no_event_loop(): assert client._background_thread is None +def test_stop_closes_event_loop(): + """Test that stop() properly closes the event loop when it exists.""" + client = MCPClient(MagicMock()) + + # Mock a background thread with event loop + mock_thread = MagicMock() + mock_thread.join = MagicMock() + mock_event_loop = MagicMock() + mock_event_loop.close = MagicMock() + + client._background_thread = mock_thread + client._background_thread_event_loop = mock_event_loop + + # Should close the event loop and join the thread + client.stop(None, None, None) + + # Verify thread was joined + mock_thread.join.assert_called_once() + + # Verify event loop was closed + mock_event_loop.close.assert_called_once() + + # Verify cleanup occurred + assert client._background_thread is None + assert client._background_thread_event_loop is None + + +def test_stop_skips_cleanup_during_interpreter_finalization(): + """Test that stop() is a no-op when the interpreter is finalizing. + + On Python 3.14+, threading.Thread.join() raises PythonFinalizationError at + shutdown. The background thread is a daemon and is reclaimed automatically, + so stop() should skip join() and event loop cleanup to avoid noisy + tracebacks surfaced via Agent.__del__ during GC. See issue #2143. + """ + client = MCPClient(MagicMock()) + + mock_thread = MagicMock() + mock_event_loop = MagicMock() + client._background_thread = mock_thread + client._background_thread_event_loop = mock_event_loop + + with patch("strands.tools.mcp.mcp_client.sys.is_finalizing", return_value=True): + # Must not raise, and must not touch the thread or event loop. + client.stop(None, None, None) + + mock_thread.join.assert_not_called() + mock_event_loop.close.assert_not_called() + # State is intentionally left alone during finalization — the interpreter + # is going away and cleanup is unnecessary. + assert client._background_thread is mock_thread + assert client._background_thread_event_loop is mock_event_loop + + def test_mcp_client_state_reset_after_timeout(): """Test that all client state is properly reset after timeout.""" @@ -559,7 +672,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "inner text" @@ -584,7 +697,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == '{"k":"v"}' @@ -593,7 +706,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): """EmbeddedResource.resource (blob with image MIME) should map to image content.""" # Read yellow.png file - with open("tests_integ/yellow.png", "rb") as image_file: + with open("tests_integ/resources/yellow.png", "rb") as image_file: png_data = image_file.read() payload = base64.b64encode(png_data).decode() @@ -610,7 +723,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "image" in result["content"][0] @@ -635,7 +748,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Content should be dropped @@ -658,7 +771,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "key: value" in result["content"][0]["text"] @@ -685,7 +798,7 @@ def __init__(self): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Unknown resource type should be dropped @@ -723,3 +836,332 @@ async def test_handle_error_message_non_exception(): # This should not raise an exception await client._handle_error_message("normal message") + + +def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_session): + """Test that call_tool_sync correctly handles both meta and structuredContent fields.""" + mock_content = MCPTextContent(type="text", text="Test message") + metadata = {"tokenUsage": {"inputTokens": 100, "outputTokens": 50}} + structured_content = {"result": 42, "status": "completed"} + mock_session.call_tool.return_value = MCPCallToolResult( + isError=False, content=[mock_content], _meta=metadata, structuredContent=structured_content + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + + assert result["status"] == "success" + assert result["toolUseId"] == "test-123" + assert "metadata" in result + assert result["metadata"] == metadata + assert "structuredContent" in result + assert result["structuredContent"] == structured_content + + +# Resource Tests - Sync Methods + + +def test_list_resources_sync(mock_transport, mock_session): + """Test that list_resources_sync correctly retrieves resources.""" + mock_resource = Resource( + uri=AnyUrl("file://documents/test.txt"), name="test.txt", description="A test document", mimeType="text/plain" + ) + mock_session.list_resources.return_value = ListResourcesResult(resources=[mock_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_resources_sync() + + mock_session.list_resources.assert_called_once_with(cursor=None) + assert len(result.resources) == 1 + assert result.resources[0].name == "test.txt" + assert str(result.resources[0].uri) == "file://documents/test.txt" + assert result.nextCursor is None + + +def test_list_resources_sync_with_pagination_token(mock_transport, mock_session): + """Test that list_resources_sync correctly passes pagination token and returns next cursor.""" + mock_resource = Resource( + uri=AnyUrl("file://documents/test.txt"), name="test.txt", description="A test document", mimeType="text/plain" + ) + mock_session.list_resources.return_value = ListResourcesResult(resources=[mock_resource], nextCursor="next_page") + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_resources_sync(pagination_token="current_page") + + mock_session.list_resources.assert_called_once_with(cursor="current_page") + assert len(result.resources) == 1 + assert result.resources[0].name == "test.txt" + assert result.nextCursor == "next_page" + + +def test_list_resources_sync_session_not_active(): + """Test that list_resources_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.list_resources_sync() + + +def test_read_resource_sync(mock_transport, mock_session): + """Test that read_resource_sync correctly reads a resource.""" + mock_content = TextResourceContents( + uri=AnyUrl("file://documents/test.txt"), text="Resource content", mimeType="text/plain" + ) + mock_session.read_resource.return_value = ReadResourceResult(contents=[mock_content]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.read_resource_sync("file://documents/test.txt") + + # Verify the session method was called + mock_session.read_resource.assert_called_once() + # Check the URI argument (it will be wrapped as AnyUrl) + call_args = mock_session.read_resource.call_args[0] + assert str(call_args[0]) == "file://documents/test.txt" + + assert len(result.contents) == 1 + assert result.contents[0].text == "Resource content" + + +def test_read_resource_sync_with_anyurl(mock_transport, mock_session): + """Test that read_resource_sync correctly handles AnyUrl input.""" + mock_content = TextResourceContents( + uri=AnyUrl("file://documents/test.txt"), text="Resource content", mimeType="text/plain" + ) + mock_session.read_resource.return_value = ReadResourceResult(contents=[mock_content]) + + with MCPClient(mock_transport["transport_callable"]) as client: + uri = AnyUrl("file://documents/test.txt") + result = client.read_resource_sync(uri) + + mock_session.read_resource.assert_called_once() + call_args = mock_session.read_resource.call_args[0] + assert str(call_args[0]) == "file://documents/test.txt" + + assert len(result.contents) == 1 + assert result.contents[0].text == "Resource content" + + +def test_read_resource_sync_session_not_active(): + """Test that read_resource_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.read_resource_sync("file://documents/test.txt") + + +def test_list_resource_templates_sync(mock_transport, mock_session): + """Test that list_resource_templates_sync correctly retrieves resource templates.""" + mock_template = ResourceTemplate( + uriTemplate="file://documents/{name}", + name="document_template", + description="Template for documents", + mimeType="text/plain", + ) + mock_session.list_resource_templates.return_value = ListResourceTemplatesResult(resourceTemplates=[mock_template]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_resource_templates_sync() + + mock_session.list_resource_templates.assert_called_once_with(cursor=None) + assert len(result.resourceTemplates) == 1 + assert result.resourceTemplates[0].name == "document_template" + assert result.resourceTemplates[0].uriTemplate == "file://documents/{name}" + assert result.nextCursor is None + + +def test_list_resource_templates_sync_with_pagination_token(mock_transport, mock_session): + """Test that list_resource_templates_sync correctly passes pagination token and returns next cursor.""" + mock_template = ResourceTemplate( + uriTemplate="file://documents/{name}", + name="document_template", + description="Template for documents", + mimeType="text/plain", + ) + mock_session.list_resource_templates.return_value = ListResourceTemplatesResult( + resourceTemplates=[mock_template], nextCursor="next_page" + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_resource_templates_sync(pagination_token="current_page") + + mock_session.list_resource_templates.assert_called_once_with(cursor="current_page") + assert len(result.resourceTemplates) == 1 + assert result.resourceTemplates[0].name == "document_template" + assert result.nextCursor == "next_page" + + +def test_list_resource_templates_sync_session_not_active(): + """Test that list_resource_templates_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.list_resource_templates_sync() + + +@pytest.mark.asyncio +async def test_handle_error_message_with_percent_in_message(): + """Test that _handle_error_message handles messages containing % characters without string formatting errors. + + This is a regression test for issue #1244 where MCP error messages containing '%' characters + (e.g., from URLs like "https://example.com/path?param=value%20encoded") would cause a + TypeError: not all arguments converted during string formatting. + """ + client = MCPClient(MagicMock()) + + # Test with a message that contains % characters (like URL-encoded strings) + # This simulates the error that occurs when MCP servers return messages with % in them + error_with_percent = Exception("unknown request id: abc%20123%30def") + + # This should not raise TypeError and should not raise the exception (since it's non-fatal) + await client._handle_error_message(error_with_percent) + + +def test_call_tool_sync_elicitation_error(mock_transport, mock_session): + """Test that call_tool_sync correctly handles elicitation required errors.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ + ElicitRequestURLParams( + url="https://example.com/auth", message="Please authorize the application", elicitationId="elicit-123" + ) + ] + ) + + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert "MCP Elicitation required" in result["content"][0]["text"] + assert "https://example.com/auth" in result["content"][0]["text"] + assert "Please authorize the application" in result["content"][0]["text"] + assert "elicit-123" in result["content"][0]["text"] + + +def test_call_tool_sync_elicitation_error_multiple_urls(mock_transport, mock_session): + """Test that call_tool_sync correctly handles elicitation errors with multiple elicitations.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ + ElicitRequestURLParams( + url="https://example.com/auth1", message="First authorization", elicitationId="elicit-1" + ), + ElicitRequestURLParams( + url="https://example.com/auth2", message="Second authorization", elicitationId="elicit-2" + ), + ] + ) + + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert "MCP Elicitation required" in result["content"][0]["text"] + assert "https://example.com/auth1" in result["content"][0]["text"] + assert "https://example.com/auth2" in result["content"][0]["text"] + assert "First authorization" in result["content"][0]["text"] + assert "Second authorization" in result["content"][0]["text"] + assert "elicit-1" in result["content"][0]["text"] + assert "elicit-2" in result["content"][0]["text"] + + +def test_call_tool_sync_elicitation_error_no_urls(mock_transport, mock_session): + """Test that -32042 error with empty URL still returns generic elicitation result.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ElicitRequestURLParams(url="", message="No URL provided", elicitationId="elicit-1")] + ) + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + assert result["status"] == "error" + assert "MCP Elicitation required" in result["content"][0]["text"] + assert "elicit-1" in result["content"][0]["text"] + assert "No URL provided" in result["content"][0]["text"] + + +def test_call_tool_sync_other_mcp_error_code(mock_transport, mock_session): + """Test that non-32042 McpError falls through to generic error.""" + from mcp.shared.exceptions import McpError + + error = McpError(error=MagicMock(code=-32600, message="Invalid request")) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + assert result["status"] == "error" + assert "Tool execution failed" in result["content"][0]["text"] + + +def test_call_tool_sync_elicitation_error_malformed_data(mock_transport, mock_session): + """Test that -32042 with unparseable data falls through to generic error.""" + from mcp.shared.exceptions import McpError + + error = McpError(error=MagicMock(code=-32042, data={"garbage": True})) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + assert result["status"] == "error" + assert "Tool execution failed" in result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_call_tool_async_elicitation_error(mock_transport, mock_session): + """Test that call_tool_async correctly handles elicitation required errors.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ + ElicitRequestURLParams( + url="https://example.com/auth", message="Please authorize the application", elicitationId="elicit-123" + ) + ] + ) + + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + + with MCPClient(mock_transport["transport_callable"]) as client: + with ( + patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe, + patch("asyncio.wrap_future") as mock_wrap_future, + ): + mock_future = MagicMock() + mock_run_coroutine_threadsafe.return_value = mock_future + + async def mock_awaitable(): + raise error + + mock_wrap_future.return_value = mock_awaitable() + + result = await client.call_tool_async( + tool_use_id="test-123", name="test_tool", arguments={"param": "value"} + ) + + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert "MCP Elicitation required" in result["content"][0]["text"] + assert "https://example.com/auth" in result["content"][0]["text"] + assert "Please authorize the application" in result["content"][0]["text"] + assert "elicit-123" in result["content"][0]["text"] diff --git a/tests/strands/tools/mcp/test_mcp_client_contextvar.py b/tests/strands/tools/mcp/test_mcp_client_contextvar.py new file mode 100644 index 000000000..1770a050a --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_contextvar.py @@ -0,0 +1,126 @@ +"""Test for MCP client context variable propagation. + +This test verifies that context variables set in the main thread are +properly propagated to the MCP client's background thread. + +Related: https://github.com/strands-agents/sdk-python/issues/1440 +""" + +import contextvars +import threading +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from strands.tools.mcp import MCPClient + + +@pytest.fixture +def mock_transport(): + """Create mock MCP transport.""" + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_transport_cm = AsyncMock() + mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_transport_callable = MagicMock(return_value=mock_transport_cm) + + return { + "read_stream": mock_read_stream, + "write_stream": mock_write_stream, + "transport_cm": mock_transport_cm, + "transport_callable": mock_transport_callable, + } + + +@pytest.fixture +def mock_session(): + """Create mock MCP session.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + # get_server_capabilities is sync, not async + mock_session.get_server_capabilities = MagicMock(return_value=None) + + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__.return_value = mock_session + + with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + yield mock_session + + +# Context variable for testing +test_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar("test_contextvar", default="default_value") + + +def test_mcp_client_propagates_contextvars_to_background_thread(mock_transport, mock_session): + """Test that context variables are propagated to the MCP client background thread. + + This verifies the fix for https://github.com/strands-agents/sdk-python/issues/1440 + where context variables set in the main thread were not accessible in the + MCP client's background thread. + """ + # Store the value seen in the background thread + background_thread_value = {} + + # Patch _background_task to capture the contextvar value + original_background_task = MCPClient._background_task + + def capturing_background_task(self): + # Capture the contextvar value in the background thread + background_thread_value["contextvar"] = test_contextvar.get() + background_thread_value["thread_id"] = threading.current_thread().ident + # Call the original background task + return original_background_task(self) + + # Set a specific value in the main thread + test_contextvar.set("main_thread_value") + main_thread_id = threading.current_thread().ident + + with patch.object(MCPClient, "_background_task", capturing_background_task): + with MCPClient(mock_transport["transport_callable"]) as client: + # Verify the client started successfully + assert client._background_thread is not None + + # Verify context was propagated to background thread + assert "contextvar" in background_thread_value, "Background task should have run and captured contextvar" + assert background_thread_value["contextvar"] == "main_thread_value", ( + f"Context variable should be propagated to background thread. " + f"Expected 'main_thread_value', got '{background_thread_value['contextvar']}'" + ) + # Verify it was indeed a different thread + assert background_thread_value["thread_id"] != main_thread_id, "Background task should run in a different thread" + + +def test_mcp_client_clears_running_loop_in_background_thread(mock_transport, mock_session): + """Test that _background_task clears any leaked running event loop state. + + When OpenTelemetry's ThreadingInstrumentor is active, Thread.run() is wrapped to propagate + trace context, which can leak the parent thread's running event loop reference into child + threads. This causes "RuntimeError: Cannot run the event loop while another loop is running" + when the background thread calls run_until_complete(). + + This test simulates that scenario by setting a running loop before _background_task runs + and verifying it gets cleared. + """ + import asyncio + + cleared_running_loop = {} + + original_background_task = MCPClient._background_task + + def simulating_otel_leak_background_task(self): + # Simulate OTEL ThreadingInstrumentor leaking the parent's running loop + fake_loop = asyncio.new_event_loop() + asyncio._set_running_loop(fake_loop) # type: ignore[attr-defined] + + # Call the real _background_task — it should clear the leaked loop and succeed + try: + return original_background_task(self) + finally: + cleared_running_loop["success"] = True + fake_loop.close() + + with patch.object(MCPClient, "_background_task", simulating_otel_leak_background_task): + with MCPClient(mock_transport["transport_callable"]) as client: + assert client._background_thread is not None + + assert cleared_running_loop.get("success"), "_background_task should have run successfully despite leaked loop" diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..d566ac6f5 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -0,0 +1,286 @@ +"""Tests for MCP task-augmented execution support in MCPClient.""" + +import asyncio +from datetime import timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest +from mcp import ListToolsResult +from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import TextContent as MCPTextContent +from mcp.types import Tool as MCPTool +from mcp.types import ToolExecution + +from strands.tools.mcp import MCPClient, TasksConfig +from strands.tools.mcp.mcp_tasks import DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL + +from .conftest import create_server_capabilities + + +class TestTasksOptIn: + """Tests for task opt-in behavior via tasks config.""" + + @pytest.mark.parametrize( + "tasks_config,expected_enabled", + [ + (None, False), + ({}, True), + ], + ) + def test_tasks_enabled_state(self, mock_transport, mock_session, tasks_config, expected_enabled): + """Test _is_tasks_enabled based on tasks config.""" + with MCPClient(mock_transport["transport_callable"], tasks_config=tasks_config) as client: + assert client._is_tasks_enabled() is expected_enabled + + def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): + """Test that _should_use_task returns False without opt-in even with server/tool support.""" + with MCPClient(mock_transport["transport_callable"]) as client: + client._server_task_capable = True + assert client._should_use_task("test_tool") is False + + with MCPClient(mock_transport["transport_callable"], tasks_config={}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + assert client._should_use_task("test_tool") is True + + +class TestTaskConfiguration: + """Tests for task-related configuration options.""" + + @pytest.mark.parametrize( + "config,expected_ttl,expected_timeout", + [ + ({}, DEFAULT_TASK_TTL, DEFAULT_TASK_POLL_TIMEOUT), + ({"ttl": timedelta(seconds=120)}, timedelta(seconds=120), DEFAULT_TASK_POLL_TIMEOUT), + ({"poll_timeout": timedelta(seconds=60)}, DEFAULT_TASK_TTL, timedelta(seconds=60)), + ( + {"ttl": timedelta(seconds=120), "poll_timeout": timedelta(seconds=60)}, + timedelta(seconds=120), + timedelta(seconds=60), + ), + ], + ) + def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout): + """Test task configuration values with various configs.""" + with MCPClient(mock_transport["transport_callable"], tasks_config=config) as client: + config_actual = client._get_task_config() + assert config_actual.get("ttl") == expected_ttl + assert config_actual.get("poll_timeout") == expected_timeout + + def test_stop_resets_task_caches(self, mock_transport, mock_session): + """Test that stop() resets the task support caches.""" + with MCPClient(mock_transport["transport_callable"], tasks_config={}) as client: + client._server_task_capable = True + client._tool_task_support_cache["tool1"] = "required" + assert client._server_task_capable is None + assert client._tool_task_support_cache == {} + + +class TestTaskExecution: + """Tests for task execution and error handling.""" + + def _setup_task_tool(self, mock_session, tool_name: str) -> None: + """Helper to set up a mock task-enabled tool.""" + mock_session.get_server_capabilities = MagicMock(return_value=create_server_capabilities(True)) + mock_tool = MCPTool( + name=tool_name, + description="A test tool", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport="optional"), + ) + mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None)) + mock_create_result = MagicMock() + mock_create_result.task.taskId = "test-task-id" + mock_session.experimental = MagicMock() + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + @pytest.mark.parametrize( + "status,status_message,expected_text", + [ + ("failed", "Something went wrong", "Something went wrong"), + ("cancelled", None, "cancelled"), + ("unknown_status", None, "unexpected task status"), + ], + ) + def test_terminal_status_handling(self, mock_transport, mock_session, status, status_message, expected_text): + """Test handling of terminal task statuses.""" + mock_create_result = MagicMock() + mock_create_result.task.taskId = f"task-{status}" + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + async def mock_poll_task(task_id): + yield MagicMock(status=status, statusMessage=status_message) + + mock_session.experimental.poll_task = mock_poll_task + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) + assert result["status"] == "error" + assert expected_text.lower() in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_polling_timeout(self, mock_transport, mock_session): + """Test that task polling times out properly.""" + self._setup_task_tool(mock_session, "slow_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + with MCPClient( + mock_transport["transport_callable"], tasks_config=TasksConfig(poll_timeout=timedelta(seconds=0.1)) + ) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="slow_tool", arguments={}) + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_explicit_timeout_overrides_default(self, mock_transport, mock_session): + """Test that read_timeout_seconds overrides the default poll timeout.""" + self._setup_task_tool(mock_session, "timeout_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + with MCPClient( + mock_transport["transport_callable"], tasks_config=TasksConfig(poll_timeout=timedelta(minutes=5)) + ) as client: + client.list_tools_sync() + result = await client.call_tool_async( + tool_use_id="t", name="timeout_tool", arguments={}, read_timeout_seconds=timedelta(seconds=0.1) + ) + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_result_retrieval_failure(self, mock_transport, mock_session): + """Test that get_task_result failures are handled gracefully.""" + self._setup_task_tool(mock_session, "failing_tool") + + async def successful_poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = successful_poll + mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="failing_tool", arguments={}) + assert result["status"] == "error" + assert "result retrieval failed" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_empty_poll_result(self, mock_transport, mock_session): + """Test handling when poll_task yields nothing.""" + self._setup_task_tool(mock_session, "empty_poll_tool") + + async def empty_poll(task_id): + return + yield # noqa: B901 + + mock_session.experimental.poll_task = empty_poll + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) + assert result["status"] == "error" + assert "without status" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_successful_completion(self, mock_transport, mock_session): + """Test successful task completion.""" + self._setup_task_tool(mock_session, "success_tool") + + async def poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = poll + mock_session.experimental.get_task_result = AsyncMock( + return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False) + ) + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) + assert result["status"] == "success" + assert "Done" in result["content"][0].get("text", "") + + +class TestTaskMetaForwarding: + """Tests for meta parameter forwarding in task-augmented execution.""" + + def _setup_task_tool_with_meta(self, mock_session, tool_name: str) -> MagicMock: + """Helper to set up a mock task-enabled tool and return the experimental mock.""" + mock_session.get_server_capabilities = MagicMock(return_value=create_server_capabilities(True)) + mock_tool = MCPTool( + name=tool_name, + description="A test tool", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport="optional"), + ) + mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None)) + mock_create_result = MagicMock() + mock_create_result.task.taskId = "test-task-id" + mock_session.experimental = MagicMock() + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + async def successful_poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = successful_poll + mock_session.experimental.get_task_result = AsyncMock( + return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False) + ) + + return mock_session.experimental + + def test_call_tool_sync_forwards_meta_to_task(self, mock_transport, mock_session): + """Test that call_tool_sync forwards meta to call_tool_as_task.""" + experimental = self._setup_task_tool_with_meta(mock_session, "meta_tool") + meta = {"com.example/request_id": "abc-123"} + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + client.call_tool_sync(tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta) + + experimental.call_tool_as_task.assert_called_once() + call_kwargs = experimental.call_tool_as_task.call_args + assert call_kwargs.kwargs.get("meta") == meta + + @pytest.mark.asyncio + async def test_call_tool_async_forwards_meta_to_task(self, mock_transport, mock_session): + """Test that call_tool_async forwards meta to call_tool_as_task.""" + experimental = self._setup_task_tool_with_meta(mock_session, "meta_tool") + meta = {"com.example/trace_id": "xyz-456"} + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + await client.call_tool_async( + tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta + ) + + experimental.call_tool_as_task.assert_called_once() + call_kwargs = experimental.call_tool_as_task.call_args + assert call_kwargs.kwargs.get("meta") == meta + + def test_call_tool_sync_forwards_none_meta_to_task(self, mock_transport, mock_session): + """Test that call_tool_sync forwards None meta to call_tool_as_task when not provided.""" + experimental = self._setup_task_tool_with_meta(mock_session, "no_meta_tool") + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + client.call_tool_sync(tool_use_id="test-id", name="no_meta_tool", arguments={"param": "value"}) + + experimental.call_tool_as_task.assert_called_once() + call_kwargs = experimental.call_tool_as_task.call_args + assert call_kwargs.kwargs.get("meta") is None diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/tools/mcp/test_mcp_instrumentation.py index 85d533403..9d44bba0c 100644 --- a/tests/strands/tools/mcp/test_mcp_instrumentation.py +++ b/tests/strands/tools/mcp/test_mcp_instrumentation.py @@ -328,7 +328,7 @@ class MockPydanticParams: def __init__(self, **data): self._data = data - def model_dump(self): + def model_dump(self, by_alias=False): return self._data.copy() @classmethod @@ -431,6 +431,32 @@ def test_patch_mcp_client_injects_context_pydantic_model(self): # Verify the params object is still a MockPydanticParams (or dict if fallback occurred) assert hasattr(mock_request.root.params, "model_dump") or isinstance(mock_request.root.params, dict) + def test_patch_mcp_client_preserves_existing_meta_pydantic(self): + """Test that instrumentation preserves existing _meta values in Pydantic models.""" + mock_request = MagicMock() + mock_request.root.method = "tools/call" + + # Pydantic model with existing _meta (returned via by_alias=True) + mock_params = MockPydanticParams(_meta={"com.example/request_id": "abc-123"}, name="echo") + mock_request.root.params = mock_params + + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + mcp_instrumentation() + patch_function = mock_wrap.call_args_list[0][0][2] + + mock_wrapped = MagicMock() + + with patch.object(propagate, "get_global_textmap") as mock_textmap: + mock_textmap_instance = MagicMock() + mock_textmap.return_value = mock_textmap_instance + + patch_function(mock_wrapped, None, [mock_request], {}) + + # Verify the reconstructed params use the key "_meta" (alias) not "meta" (Python name) + validated_params = mock_request.root.params.model_dump(by_alias=True) + assert "_meta" in validated_params + assert validated_params["_meta"]["com.example/request_id"] == "abc-123" + def test_patch_mcp_client_injects_context_dict_params(self): """Test that the client patch injects OpenTelemetry context into dict params.""" # Create a mock request with tools/call method and dict params @@ -507,7 +533,7 @@ class FailingMockPydanticParams: def __init__(self, **data): self._data = data - def model_dump(self): + def model_dump(self, by_alias=False): return self._data.copy() def model_validate(self, data): diff --git a/tests/strands/tools/structured_output/test_structured_output_context.py b/tests/strands/tools/structured_output/test_structured_output_context.py index a7eb27ca5..6d75852d1 100644 --- a/tests/strands/tools/structured_output/test_structured_output_context.py +++ b/tests/strands/tools/structured_output/test_structured_output_context.py @@ -1,10 +1,11 @@ """Tests for StructuredOutputContext class.""" -from typing import Optional - from pydantic import BaseModel, Field -from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output._structured_output_context import ( + DEFAULT_STRUCTURED_OUTPUT_PROMPT, + StructuredOutputContext, +) from strands.tools.structured_output.structured_output_tool import StructuredOutputTool @@ -13,7 +14,7 @@ class SampleModel(BaseModel): name: str = Field(..., description="Name field") age: int = Field(..., description="Age field", ge=0) - email: Optional[str] = Field(None, description="Optional email field") + email: str | None = Field(None, description="Optional email field") class AnotherSampleModel(BaseModel): @@ -37,6 +38,7 @@ def test_initialization_with_structured_output_model(self): assert context.forced_mode is False assert context.tool_choice is None assert context.stop_loop is False + assert context.structured_output_prompt == DEFAULT_STRUCTURED_OUTPUT_PROMPT def test_initialization_without_structured_output_model(self): """Test initialization without a structured output model.""" @@ -49,6 +51,31 @@ def test_initialization_without_structured_output_model(self): assert context.forced_mode is False assert context.tool_choice is None assert context.stop_loop is False + assert context.structured_output_prompt == DEFAULT_STRUCTURED_OUTPUT_PROMPT + + def test_initialization_with_custom_prompt(self): + """Test initialization with a custom structured output prompt.""" + custom_prompt = "Please format your response using the output schema." + context = StructuredOutputContext( + structured_output_model=SampleModel, + structured_output_prompt=custom_prompt, + ) + + assert context.structured_output_model == SampleModel + assert context.structured_output_prompt == custom_prompt + + def test_initialization_with_none_prompt_uses_default(self): + """Test that None prompt falls back to default.""" + context = StructuredOutputContext( + structured_output_model=SampleModel, + structured_output_prompt=None, + ) + + assert context.structured_output_prompt == DEFAULT_STRUCTURED_OUTPUT_PROMPT + + def test_default_prompt_constant_value(self): + """Test the default prompt constant has expected value.""" + assert DEFAULT_STRUCTURED_OUTPUT_PROMPT == "You must format the previous response as structured output." def test_is_enabled_property(self): """Test the is_enabled property.""" diff --git a/tests/strands/tools/structured_output/test_structured_output_tool.py b/tests/strands/tools/structured_output/test_structured_output_tool.py index 66f1d465d..784a508bd 100644 --- a/tests/strands/tools/structured_output/test_structured_output_tool.py +++ b/tests/strands/tools/structured_output/test_structured_output_tool.py @@ -1,6 +1,5 @@ """Tests for StructuredOutputTool class.""" -from typing import List, Optional from unittest.mock import MagicMock import pytest @@ -23,8 +22,8 @@ class ComplexModel(BaseModel): title: str = Field(..., description="Title field") count: int = Field(..., ge=0, le=100, description="Count between 0 and 100") - tags: List[str] = Field(default_factory=list, description="List of tags") - metadata: Optional[dict] = Field(None, description="Optional metadata") + tags: list[str] = Field(default_factory=list, description="List of tags") + metadata: dict | None = Field(None, description="Optional metadata") class ValidationTestModel(BaseModel): diff --git a/tests/strands/tools/test_caller.py b/tests/strands/tools/test_caller.py index 18de6d3f0..2658af6b4 100644 --- a/tests/strands/tools/test_caller.py +++ b/tests/strands/tools/test_caller.py @@ -1,8 +1,11 @@ +import gc import unittest.mock +import weakref import pytest from strands import Agent, tool +from strands.tools.tool_provider import ToolProvider @pytest.fixture @@ -312,3 +315,122 @@ def test_agent_tool_caller_interrupt_activated(): exp_message = r"cannot directly call tool during interrupt" with pytest.raises(RuntimeError, match=exp_message): agent.tool.test_tool() + + +def test_agent_collected_without_cyclic_gc(): + """Verify that Agent is promptly collectable (no persistent reference cycle). + + This ensures that the weakref-based back-references in _ToolCaller and _PluginRegistry + do not create reference cycles that would delay cleanup until interpreter shutdown. + When cleanup is deferred to interpreter shutdown, MCPClient.stop() hangs because its + background thread cannot complete async cleanup at that point. + + Note: On some platforms/versions (e.g. Python 3.14 with deferred refcounting), del may + not immediately trigger collection. A single gc.collect() is allowed as a fallback since + it still proves no persistent cycle exists — the agent is collected promptly, not deferred + to interpreter shutdown. + """ + gc.disable() + try: + agent = Agent() + ref = weakref.ref(agent) + del agent + + if ref() is not None: + # Deferred refcounting (Python 3.14+) may not collect immediately on del; + # a single gc.collect() should still reclaim it since there are no cycles. + gc.collect() + + assert ref() is None, "Agent was not collected; a reference cycle likely exists" + finally: + gc.enable() + + +class _MockToolProvider(ToolProvider): + """Minimal ToolProvider that tracks cleanup calls, mimicking MCPClient lifecycle.""" + + def __init__(self): + self.consumers: set = set() + self.cleanup_called = False + + async def load_tools(self, **kwargs): + return [] + + def add_consumer(self, consumer_id, **kwargs): + self.consumers.add(consumer_id) + + def remove_consumer(self, consumer_id, **kwargs): + self.consumers.discard(consumer_id) + if not self.consumers: + self.cleanup_called = True + + +def test_agent_with_tool_provider_cleaned_up_when_function_returns(): + """Replicate the hang from issue #1732: Agent with MCPClient created inside a function. + + When an Agent using a managed MCPClient (as ToolProvider) is created inside a function, + the script used to hang on exit. The Agent went out of scope when the function returned, + but circular references (Agent → _ToolCaller._agent → Agent) prevented refcount-based + destruction. Cleanup was deferred to the cyclic GC during interpreter shutdown, where + MCPClient.stop() → thread.join() would hang. + + This test verifies that with the weakref fix, the Agent is destroyed immediately when + the function returns, and the tool provider's cleanup runs promptly. + """ + provider = _MockToolProvider() + + def get_agent(): + return Agent(tools=[provider]) + + def main(): + agent = get_agent() # noqa: F841 + + gc.disable() + try: + main() + + if not provider.cleanup_called: + # Deferred refcounting (Python 3.14+) may not collect immediately on scope exit; + # a single gc.collect() should still reclaim it since there are no cycles. + gc.collect() + + assert provider.cleanup_called, ( + "Tool provider was not cleaned up when the function returned; Agent likely leaked due to a reference cycle" + ) + finally: + gc.enable() + + +def test_agent_with_tool_provider_cleaned_up_on_del(): + """Replicate the working case from issue #1732: Agent at module scope, explicitly deleted. + + In the issue, an Agent created at module level did not hang because module-level variables + are cleared early during interpreter shutdown (while the runtime is still functional). + This test verifies the equivalent: explicitly deleting the agent triggers immediate cleanup. + """ + provider = _MockToolProvider() + + agent = Agent(tools=[provider]) + assert not provider.cleanup_called + + del agent + + if not provider.cleanup_called: + # Deferred refcounting (Python 3.14+) may not collect immediately on del; + # a single gc.collect() should still reclaim it since there are no cycles. + gc.collect() + + assert provider.cleanup_called, "Tool provider was not cleaned up after del agent" + + +def test_tool_caller_raises_reference_error_after_agent_collected(): + """Verify _ToolCaller raises ReferenceError when the Agent has been garbage collected.""" + agent = Agent() + caller = agent.tool_caller + # Clear the weak reference by replacing it directly + caller._agent_ref = weakref.ref(agent) + del agent + gc.collect() + + with pytest.raises(ReferenceError, match="Agent has been garbage collected"): + _ = caller._agent diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index a2a4c6213..cc1158983 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,7 +3,8 @@ """ from asyncio import Queue -from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator +from typing import Annotated, Any from unittest.mock import MagicMock import pytest @@ -135,7 +136,7 @@ def identity(a: int, agent: dict = None): tru_events = await alist(stream) exp_events = [ - ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}) + ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": '[2, {"state": 1}]'}]}) ] assert tru_events == exp_events @@ -267,7 +268,7 @@ async def test_tool_with_optional_params(alist): """Test tool decorator with optional parameters.""" @strands.tool - def test_tool(required: str, optional: Optional[int] = None) -> str: + def test_tool(required: str, optional: int | None = None) -> str: """Test with optional param. Args: @@ -594,12 +595,12 @@ def none_return_tool(param: str) -> None: assert result["tool_result"]["status"] == "success" assert result["tool_result"]["content"][0]["text"] == "Result: test" - # Test None return - should still create valid ToolResult with "None" text + # Test None return - should still create valid ToolResult with "null" stream = none_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert result["tool_result"]["content"][0]["text"] == "None" + assert result["tool_result"]["content"][0]["text"] == "null" @pytest.mark.asyncio @@ -860,11 +861,11 @@ def int_return_tool(param: str) -> int: result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert result["tool_result"]["content"][0]["text"] == "None" + assert result["tool_result"]["content"][0]["text"] == "null" # Define tool with Union return type @strands.tool - def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: + def union_return_tool(param: str) -> dict[str, Any] | str | None: """Tool with Union return type. Args: @@ -883,10 +884,7 @@ def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert ( - "{'key': 'value'}" in result["tool_result"]["content"][0]["text"] - or '{"key": "value"}' in result["tool_result"]["content"][0]["text"] - ) + assert result["tool_result"]["content"][0]["text"] == '{"key": "value"}' tool_use = {"toolUseId": "test-id", "input": {"param": "str"}} stream = union_return_tool.stream(tool_use, {}) @@ -900,7 +898,7 @@ def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert result["tool_result"]["content"][0]["text"] == "None" + assert result["tool_result"]["content"][0]["text"] == "null" @pytest.mark.asyncio @@ -936,7 +934,7 @@ async def test_complex_parameter_types(alist): """Test handling of complex parameter types like nested dictionaries.""" @strands.tool - def complex_type_tool(config: Dict[str, Any]) -> str: + def complex_type_tool(config: dict[str, Any]) -> str: """Tool with complex parameter type. Args: @@ -965,7 +963,7 @@ async def test_custom_tool_result_handling(alist): """Test that a function returning a properly formatted tool result dictionary is handled correctly.""" @strands.tool - def custom_result_tool(param: str) -> Dict[str, Any]: + def custom_result_tool(param: str) -> dict[str, Any]: """Tool that returns a custom tool result dictionary. Args: @@ -991,6 +989,132 @@ def custom_result_tool(param: str) -> Dict[str, Any]: assert result["tool_result"]["content"][1]["type"] == "markdown" +@pytest.mark.asyncio +async def test_tool_result_json_serialization_dict(alist): + """Test that dict results are serialized as JSON.""" + + @strands.tool + def dict_tool() -> dict: + """Returns a dict.""" + return {"key": "value", "number": 42} + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = dict_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == '{"key": "value", "number": 42}' + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_list(alist): + """Test that list results are serialized as JSON.""" + + @strands.tool + def list_tool() -> list: + """Returns a list.""" + return [1, "two", {"three": 3}] + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = list_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == '[1, "two", {"three": 3}]' + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_pydantic(alist): + """Test that Pydantic model results are serialized as JSON.""" + from pydantic import BaseModel + + class MyModel(BaseModel): + name: str + count: int + + @strands.tool + def pydantic_tool() -> MyModel: + """Returns a Pydantic model.""" + return MyModel(name="test", count=5) + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = pydantic_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == '{"name":"test","count":5}' + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_pydantic_non_serializable(alist): + """Test that Pydantic models with non-serializable fields fall back to str().""" + from pydantic import BaseModel + + class NonSerializable: + def __repr__(self): + return "NonSerializable()" + + class MyModel(BaseModel): + model_config = {"arbitrary_types_allowed": True} + data: NonSerializable + + @strands.tool + def pydantic_tool() -> MyModel: + """Returns a Pydantic model with non-serializable field.""" + return MyModel(data=NonSerializable()) + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = pydantic_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == "data=NonSerializable()" + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_non_serializable(alist): + """Test that non-JSON-serializable results fall back to str().""" + + class CustomClass: + def __str__(self): + return "custom_str_repr" + + @strands.tool + def custom_tool() -> Any: + """Returns a non-serializable object.""" + return CustomClass() + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = custom_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == "custom_str_repr" + + +@pytest.mark.asyncio +async def test_tool_result_string_not_json_encoded(alist): + """Test that string results are NOT JSON-encoded (no extra quotes).""" + + @strands.tool + def string_tool() -> str: + """Returns a string.""" + return "hello world" + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = string_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == "hello world" + + def test_docstring_parsing(): """Test that function docstring is correctly parsed into tool spec.""" @@ -1079,11 +1203,11 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: @pytest.mark.asyncio async def test_tool_complex_validation_edge_cases(alist): """Test validation of complex schema edge cases.""" - from typing import Any, Dict, Union + from typing import Any # Define a tool with a complex anyOf type that could trigger edge case handling @strands.tool - def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: + def edge_case_tool(param: dict[str, Any] | None) -> str: """Tool with complex anyOf structure. Args: @@ -1236,10 +1360,10 @@ def failing_tool(param: str) -> str: @pytest.mark.asyncio async def test_tool_with_complex_anyof_schema(alist): """Test handling of complex anyOf structures in the schema.""" - from typing import Any, Dict, List, Union + from typing import Any @strands.tool - def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]) -> str: + def complex_schema_tool(union_param: list[int] | dict[str, Any] | str | None) -> str: """Tool with a complex Union type that creates anyOf in schema. Args: @@ -1680,7 +1804,7 @@ def test_tool_decorator_annotated_optional_type(): @strands.tool def optional_annotated_tool( - required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None + required: Annotated[str, "Required parameter"], optional: Annotated[str | None, "Optional parameter"] = None ) -> str: """Tool with optional annotated parameter.""" return f"{required}, {optional}" @@ -1702,7 +1826,7 @@ def test_tool_decorator_annotated_complex_types(): @strands.tool def complex_annotated_tool( - tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"] + tags: Annotated[list[str], "List of tag strings"], config: Annotated[dict[str, Any], "Configuration dictionary"] ) -> str: """Tool with complex annotated types.""" return f"Tags: {len(tags)}, Config: {len(config)}" @@ -1822,3 +1946,158 @@ def test_tool_decorator_annotated_field_with_inner_default(): @strands.tool def inner_default_tool(name: str, level: Annotated[int, Field(description="A level value", default=10)]) -> str: return f"{name} is at level {level}" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_runtime_error(alist): + """Test that ToolResultEvent carries exception when tool raises RuntimeError.""" + + @strands.tool + def error_tool(): + """Tool that raises a RuntimeError.""" + raise RuntimeError("test runtime error") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, RuntimeError) + assert str(result_event.exception) == "test runtime error" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_value_error(alist): + """Test that ToolResultEvent carries exception when tool raises ValueError.""" + + @strands.tool + def validation_error_tool(): + """Tool that raises a ValueError.""" + raise ValueError("validation failed") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(validation_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, ValueError) + assert str(result_event.exception) == "validation failed" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_no_exception_on_success(alist): + """Test that ToolResultEvent.exception is None when tool succeeds.""" + + @strands.tool + def success_tool(): + """Tool that succeeds.""" + return "success" + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(success_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.exception is None + assert result_event.tool_result["status"] == "success" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_assertion_error(alist): + """Test that ToolResultEvent carries AssertionError for unexpected failures.""" + + @strands.tool + def assertion_error_tool(): + """Tool that raises an AssertionError.""" + raise AssertionError("unexpected assertion failure") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(assertion_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert isinstance(result_event.exception, AssertionError) + assert "unexpected assertion failure" in str(result_event.exception) + assert result_event.tool_result["status"] == "error" + + +def test_tool_nullable_required_field_preserves_anyof(): + """Test that a required nullable field preserves anyOf so the model can pass null. + + Regression test for https://github.com/strands-agents/sdk-python/issues/1525 + """ + from enum import Enum + + class Priority(str, Enum): + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + @strands.tool + def prioritized_task(description: str, priority: Priority | None) -> str: + """Create a task with optional priority. + + Args: + description: Task description + priority: Optional priority level + """ + return f"{description}: {priority}" + + spec = prioritized_task.tool_spec + schema = spec["inputSchema"]["json"] + + expected_schema = { + "$defs": { + "Priority": { + "enum": ["high", "medium", "low"], + "title": "Priority", + "type": "string", + }, + }, + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Task description", + }, + "priority": { + "anyOf": [ + {"$ref": "#/$defs/Priority"}, + {"type": "null"}, + ], + "description": "Optional priority level", + }, + }, + "required": ["description", "priority"], + } + + assert schema == expected_schema + + +def test_tool_nullable_optional_field_simplifies_anyof(): + """Test that a non-required nullable field still gets anyOf simplified.""" + + @strands.tool + def my_tool(name: str, tag: str | None = None) -> str: + """A tool. + + Args: + name: The name + tag: An optional tag + """ + return f"{name}: {tag}" + + spec = my_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # tag has a default, so it should NOT be required + assert "name" in schema["required"] + assert "tag" not in schema["required"] + + # Since tag is not required, anyOf should be simplified away + assert "anyOf" not in schema["properties"]["tag"] + assert schema["properties"]["tag"]["type"] == "string" diff --git a/tests/strands/tools/test_decorator_pep563.py b/tests/strands/tools/test_decorator_pep563.py new file mode 100644 index 000000000..44d9a626a --- /dev/null +++ b/tests/strands/tools/test_decorator_pep563.py @@ -0,0 +1,142 @@ +"""Tests for PEP 563 (from __future__ import annotations) compatibility. + +This module tests that the @tool decorator works correctly when modules use +`from __future__ import annotations` (PEP 563), which causes all annotations +to be stored as string literals rather than evaluated types. + +This is a regression test for issue #1208: +https://github.com/strands-agents/sdk-python/issues/1208 +""" + +from __future__ import annotations + +from typing import Any, Literal + +import pytest +from typing_extensions import TypedDict + +from strands import tool + +# Define types at module level (simulating nova-act pattern) +CLICK_TYPE = Literal["left", "right", "middle", "double"] +EXTRA_TYPE = Literal["extra"] + + +class ClickOptions(TypedDict): + """Options for click operation.""" + + blur_field: bool | None + + +@tool +def simple_literal_tool(click_type: CLICK_TYPE) -> dict[str, Any]: + return {"status": "success", "content": [{"text": f"Clicked: {click_type}"}]} + + +@tool +def complex_literal_tool( + box: str, + extra: EXTRA_TYPE, + click_type: CLICK_TYPE | None = None, + click_options: ClickOptions | None = None, +) -> Any: + return "Done" + + +@tool +def union_literal_tool(mode: Literal["fast", "slow"] | None = None) -> str: + return f"Mode: {mode}" + + +def test_simple_literal_type_tool_spec(): + """Test that simple Literal type parameters work with __future__ annotations.""" + spec = simple_literal_tool.tool_spec + assert spec["name"] == "simple_literal_tool" + + schema = spec["inputSchema"]["json"] + assert "click_type" in schema["properties"] + # Verify Literal values are present in schema + click_type_schema = schema["properties"]["click_type"] + assert "enum" in click_type_schema or "anyOf" in click_type_schema + + +def test_complex_literal_type_tool_spec(): + """Test that complex type hints with Literal work with __future__ annotations.""" + spec = complex_literal_tool.tool_spec + assert spec["name"] == "complex_literal_tool" + + schema = spec["inputSchema"]["json"] + # Ensure schema is correct and contains the expected shape + assert schema == { + "$defs": { + "ClickOptions": { + "description": "Options for click operation.", + "properties": {"blur_field": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "title": "Blur Field"}}, + "required": ["blur_field"], + "title": "ClickOptions", + "type": "object", + } + }, + "properties": { + "box": {"description": "Parameter box", "type": "string"}, + "click_options": { + "$ref": "#/$defs/ClickOptions", + "default": None, + "description": "Parameter click_options", + }, + "click_type": { + "default": None, + "description": "Parameter click_type", + "enum": ["left", "right", "middle", "double"], + "type": "string", + }, + "extra": {"const": "extra", "description": "Parameter extra", "type": "string"}, + }, + "required": ["box", "extra"], + "type": "object", + } + + +def test_union_literal_tool_spec(): + """Test that inline Literal in Union works with __future__ annotations.""" + spec = union_literal_tool.tool_spec + assert spec["name"] == "union_literal_tool" + + schema = spec["inputSchema"]["json"] + assert "mode" in schema["properties"] + + +def test_simple_literal_tool_invocation(): + """Test that tools with Literal types can be invoked.""" + result = simple_literal_tool(click_type="left") + assert result["status"] == "success" + assert "left" in result["content"][0]["text"] + + +def test_complex_literal_tool_invocation(): + """Test that tools with complex types can be invoked.""" + result = complex_literal_tool( + box="box1", + extra="extra", + click_type="double", + click_options={"blur_field": True}, + ) + assert result == "Done" + + +def test_tool_spec_no_pydantic_error(): + """Verify no PydanticUserError is raised when accessing tool_spec. + + This is the specific error from issue #1208: + PydanticUserError: `Agent_clickTool` is not fully defined; + you should define `EXTRA_TYPE`, then call `Agent_clickTool.model_rebuild()`. + """ + # This should not raise PydanticUserError + try: + _ = simple_literal_tool.tool_spec + _ = complex_literal_tool.tool_spec + _ = union_literal_tool.tool_spec + except Exception as e: + if "not fully defined" in str(e): + pytest.fail(f"PydanticUserError raised - PEP 563 compatibility broken: {e}") + raise diff --git a/tests/strands/tools/test_loader.py b/tests/strands/tools/test_loader.py index 1c665b42a..121ebed2d 100644 --- a/tests/strands/tools/test_loader.py +++ b/tests/strands/tools/test_loader.py @@ -10,6 +10,14 @@ from strands.tools.loader import _TOOL_MODULE_PREFIX, ToolLoader, load_tools_from_file_path from strands.tools.tools import PythonAgentTool +# Suppress deprecation warnings for deprecated ToolLoader methods being tested +pytestmark = pytest.mark.filterwarnings( + "ignore:ToolLoader.load_python_tool is deprecated:DeprecationWarning", + "ignore:ToolLoader.load_python_tools is deprecated:DeprecationWarning", + "ignore:ToolLoader.load_tool is deprecated:DeprecationWarning", + "ignore:ToolLoader.load_tools is deprecated:DeprecationWarning", +) + @pytest.fixture def tool_path(request, tmp_path, monkeypatch): diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index c700016f6..3723f381b 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -7,13 +7,13 @@ import pytest import strands -from strands.experimental.tools import ToolProvider -from strands.tools import PythonAgentTool +from strands.tools import PythonAgentTool, ToolProvider from strands.tools.decorator import DecoratedFunctionTool, tool from strands.tools.mcp import MCPClient from strands.tools.registry import ToolRegistry +@pytest.mark.filterwarnings("ignore:load_tool_from_filepath is deprecated:DeprecationWarning") def test_load_tool_from_filepath_failure(): """Test error handling when load_tool fails.""" tool_registry = ToolRegistry() @@ -389,3 +389,306 @@ async def track_load_tools(*args, **kwargs): # Verify add_consumer was called with the registry ID mock_provider.add_consumer.assert_called_once_with(registry._registry_id) + + +def test_validate_tool_spec_with_anyof_property(): + """Test that validate_tool_spec does not add type: 'string' to anyOf properties. + + This is important for MCP tools that use anyOf for optional/union types like + Optional[List[str]]. Adding type: 'string' causes models to return string-encoded + JSON instead of proper arrays/objects. + """ + tool_spec = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "regular_field": {}, # Should get type: "string" + "anyof_field": { + "anyOf": [ + {"type": "array", "items": {"type": "string"}}, + {"type": "null"}, + ] + }, + }, + } + }, + } + + registry = ToolRegistry() + registry.validate_tool_spec(tool_spec) + + props = tool_spec["inputSchema"]["json"]["properties"] + + # Regular field should get default type: "string" + assert props["regular_field"]["type"] == "string" + assert props["regular_field"]["description"] == "Property regular_field" + + # anyOf field should NOT get type: "string" added + assert "type" not in props["anyof_field"], "anyOf property should not have type added" + assert "anyOf" in props["anyof_field"], "anyOf should be preserved" + assert props["anyof_field"]["description"] == "Property anyof_field" + + +def test_validate_tool_spec_with_composition_keywords(): + """Test that validate_tool_spec does not add type: 'string' to composition keyword properties. + + JSON Schema composition keywords (anyOf, oneOf, allOf, not) define type constraints. + Properties using these should not get a default type added. + """ + tool_spec = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "regular_field": {}, # Should get type: "string" + "oneof_field": { + "oneOf": [ + {"type": "string"}, + {"type": "integer"}, + ] + }, + "allof_field": { + "allOf": [ + {"minimum": 0}, + {"maximum": 100}, + ] + }, + "not_field": {"not": {"type": "null"}}, + }, + } + }, + } + + registry = ToolRegistry() + registry.validate_tool_spec(tool_spec) + + props = tool_spec["inputSchema"]["json"]["properties"] + + # Regular field should get default type: "string" + assert props["regular_field"]["type"] == "string" + + # Composition keyword fields should NOT get type: "string" added + assert "type" not in props["oneof_field"], "oneOf property should not have type added" + assert "oneOf" in props["oneof_field"], "oneOf should be preserved" + + assert "type" not in props["allof_field"], "allOf property should not have type added" + assert "allOf" in props["allof_field"], "allOf should be preserved" + + assert "type" not in props["not_field"], "not property should not have type added" + assert "not" in props["not_field"], "not should be preserved" + + # All should have descriptions + for field in ["oneof_field", "allof_field", "not_field"]: + assert props[field]["description"] == f"Property {field}" + + +def test_validate_tool_spec_with_ref_property(): + """Test that validate_tool_spec does not modify $ref properties.""" + tool_spec = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "ref_field": {"$ref": "#/$defs/SomeType"}, + }, + } + }, + } + + registry = ToolRegistry() + registry.validate_tool_spec(tool_spec) + + props = tool_spec["inputSchema"]["json"]["properties"] + + # $ref field should not be modified + assert props["ref_field"] == {"$ref": "#/$defs/SomeType"} + assert "type" not in props["ref_field"] + assert "description" not in props["ref_field"] + + +def test_tool_registry_replace_existing_tool(): + """Test replacing an existing tool.""" + old_tool = MagicMock() + old_tool.tool_name = "my_tool" + old_tool.is_dynamic = False + old_tool.supports_hot_reload = False + + new_tool = MagicMock() + new_tool.tool_name = "my_tool" + new_tool.is_dynamic = False + + registry = ToolRegistry() + registry.register_tool(old_tool) + registry.replace(new_tool) + + assert registry.registry["my_tool"] == new_tool + + +def test_tool_registry_replace_nonexistent_tool(): + """Test replacing a tool that doesn't exist raises ValueError.""" + new_tool = MagicMock() + new_tool.tool_name = "my_tool" + + registry = ToolRegistry() + + with pytest.raises(ValueError, match="Cannot replace tool 'my_tool' - tool does not exist"): + registry.replace(new_tool) + + +def test_tool_registry_replace_dynamic_tool(): + """Test replacing a dynamic tool updates both registries.""" + old_tool = MagicMock() + old_tool.tool_name = "dynamic_tool" + old_tool.is_dynamic = True + old_tool.supports_hot_reload = True + + new_tool = MagicMock() + new_tool.tool_name = "dynamic_tool" + new_tool.is_dynamic = True + + registry = ToolRegistry() + registry.register_tool(old_tool) + registry.replace(new_tool) + + assert registry.registry["dynamic_tool"] == new_tool + assert registry.dynamic_tools["dynamic_tool"] == new_tool + + +def test_tool_registry_replace_dynamic_with_non_dynamic(): + """Test replacing a dynamic tool with non-dynamic tool removes from dynamic_tools.""" + old_tool = MagicMock() + old_tool.tool_name = "my_tool" + old_tool.is_dynamic = True + old_tool.supports_hot_reload = True + + new_tool = MagicMock() + new_tool.tool_name = "my_tool" + new_tool.is_dynamic = False + + registry = ToolRegistry() + registry.register_tool(old_tool) + + assert "my_tool" in registry.dynamic_tools + + registry.replace(new_tool) + + assert registry.registry["my_tool"] == new_tool + assert "my_tool" not in registry.dynamic_tools + + +def test_tool_registry_replace_non_dynamic_with_dynamic(): + """Test replacing a non-dynamic tool with dynamic tool adds to dynamic_tools.""" + old_tool = MagicMock() + old_tool.tool_name = "my_tool" + old_tool.is_dynamic = False + old_tool.supports_hot_reload = False + + new_tool = MagicMock() + new_tool.tool_name = "my_tool" + new_tool.is_dynamic = True + + registry = ToolRegistry() + registry.register_tool(old_tool) + + assert "my_tool" not in registry.dynamic_tools + + registry.replace(new_tool) + + assert registry.registry["my_tool"] == new_tool + assert registry.dynamic_tools["my_tool"] == new_tool + + +# --- Agent-as-tool sugar --- + + +def test_process_tools_with_agent_instance(): + """Test that passing an Agent instance in tools list auto-wraps it with as_tool().""" + from strands.agent.agent import Agent + + sub_agent = Agent(name="research_agent", description="Finds information", callback_handler=None) + + registry = ToolRegistry() + tool_names = registry.process_tools([sub_agent]) + + assert "research_agent" in tool_names + assert "research_agent" in registry.registry + assert registry.registry["research_agent"].tool_type == "agent" + + +def test_process_tools_with_agent_instance_uses_agent_name(): + """Test that the auto-wrapped tool uses the agent's name.""" + from strands.agent.agent import Agent + + sub_agent = Agent(name="my_custom_agent", callback_handler=None) + + registry = ToolRegistry() + registry.process_tools([sub_agent]) + + assert "my_custom_agent" in registry.registry + spec = registry.registry["my_custom_agent"].tool_spec + assert spec["name"] == "my_custom_agent" + + +def test_process_tools_with_agent_instance_uses_agent_description(): + """Test that the auto-wrapped tool uses the agent's description.""" + from strands.agent.agent import Agent + + sub_agent = Agent(name="helper", description="A helpful assistant", callback_handler=None) + + registry = ToolRegistry() + registry.process_tools([sub_agent]) + + spec = registry.registry["helper"].tool_spec + assert spec["description"] == "A helpful assistant" + + +def test_process_tools_with_agent_in_nested_list(): + """Test that Agent instances in nested iterables are auto-wrapped.""" + from strands.agent.agent import Agent + + agent_a = Agent(name="agent_a", callback_handler=None) + agent_b = Agent(name="agent_b", callback_handler=None) + + registry = ToolRegistry() + tool_names = sorted(registry.process_tools([[agent_a, agent_b]])) + + assert tool_names == ["agent_a", "agent_b"] + + +def test_process_tools_with_mixed_agents_and_tools(): + """Test that Agent instances can be mixed with regular tools.""" + from strands.agent.agent import Agent + + def function() -> str: + return "done" + + regular_tool = tool(name="regular_tool")(function) + sub_agent = Agent(name="sub_agent", callback_handler=None) + + registry = ToolRegistry() + tool_names = sorted(registry.process_tools([regular_tool, sub_agent])) + + assert tool_names == ["regular_tool", "sub_agent"] + assert registry.registry["sub_agent"].tool_type == "agent" + + +def test_process_tools_with_multiple_agents(): + """Test that multiple Agent instances can be passed.""" + from strands.agent.agent import Agent + + agent_1 = Agent(name="researcher", description="Does research", callback_handler=None) + agent_2 = Agent(name="writer", description="Writes content", callback_handler=None) + agent_3 = Agent(name="reviewer", description="Reviews work", callback_handler=None) + + registry = ToolRegistry() + tool_names = sorted(registry.process_tools([agent_1, agent_2, agent_3])) + + assert tool_names == ["researcher", "reviewer", "writer"] + assert all(registry.registry[name].tool_type == "agent" for name in tool_names) diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index fdf4abb0a..25a4edacb 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -4,7 +4,7 @@ import pytest -from strands.experimental.tools.tool_provider import ToolProvider +from strands.tools import ToolProvider from strands.tools.registry import ToolRegistry from tests.fixtures.mock_agent_tool import MockAgentTool diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index fe9b55334..72a53bfe6 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Optional +from typing import Literal, Optional import pytest from pydantic import BaseModel, Field @@ -27,7 +27,7 @@ class TwoUsersWithPlanet(BaseModel): """Two users model with planet.""" user1: UserWithPlanet = Field(description="The first user") - user2: Optional[UserWithPlanet] = Field(description="The second user", default=None) + user2: UserWithPlanet | None = Field(description="The second user", default=None) # Test model with list of same type fields @@ -250,8 +250,8 @@ class NodeWithCircularRef(BaseModel): def test_conversion_works_with_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 class Family(BaseModel): - ages: List[str] = Field(default_factory=list) - names: List[str] = Field(default_factory=list) + ages: list[str] = Field(default_factory=list) + names: list[str] = Field(default_factory=list) converted_output = convert_pydantic_to_tool_spec(Family) expected_output = { @@ -281,8 +281,8 @@ class Family(BaseModel): def test_marks_fields_as_optional_for_model_w_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 class Family(BaseModel): - ages: List[str] = Field(default_factory=list) - names: List[str] = Field(default_factory=list) + ages: list[str] = Field(default_factory=list) + names: list[str] = Field(default_factory=list) converted_output = convert_pydantic_to_tool_spec(Family) assert "null" in converted_output["inputSchema"]["json"]["properties"]["ages"]["type"] @@ -312,14 +312,14 @@ def test_convert_pydantic_with_items_refs(): """Test that no $refs exist after lists of different components.""" class Address(BaseModel): - postal_code: Optional[str] = None + postal_code: str | None = None class Person(BaseModel): """Complete person information.""" list_of_items: list[Address] - list_of_items_nullable: Optional[list[Address]] - list_of_item_or_nullable: list[Optional[Address]] + list_of_items_nullable: list[Address] | None + list_of_item_or_nullable: list[Address | None] tool_spec = convert_pydantic_to_tool_spec(Person) @@ -378,7 +378,7 @@ class Address(BaseModel): street: str city: str country: str - postal_code: Optional[str] = None + postal_code: str | None = None class Contact(BaseModel): address: Address diff --git a/tests/strands/tools/test_tool_spec_setter.py b/tests/strands/tools/test_tool_spec_setter.py new file mode 100644 index 000000000..842146c72 --- /dev/null +++ b/tests/strands/tools/test_tool_spec_setter.py @@ -0,0 +1,253 @@ +"""Tests for tool_spec setter on DecoratedFunctionTool and PythonAgentTool.""" + +import pytest + +from strands.tools.decorator import tool +from strands.tools.tools import PythonAgentTool +from strands.types.tools import ToolSpec + + +class TestDecoratedFunctionToolSpecSetter: + """Tests for DecoratedFunctionTool.tool_spec setter.""" + + def test_set_tool_spec_replaces_spec(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + new_spec: ToolSpec = { + "name": "my_tool", + "description": "Updated tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "The query"}, + "limit": {"type": "integer", "description": "Max results"}, + }, + "required": ["query"], + } + }, + } + my_tool.tool_spec = new_spec + assert my_tool.tool_spec is new_spec + assert "limit" in my_tool.tool_spec["inputSchema"]["json"]["properties"] + + def test_set_tool_spec_persists_across_reads(self): + @tool + def another_tool(x: int) -> int: + """Another test tool.""" + return x + + new_spec: ToolSpec = { + "name": "another_tool", + "description": "Modified", + "inputSchema": { + "json": { + "type": "object", + "properties": {"x": {"type": "integer"}, "y": {"type": "integer"}}, + "required": ["x"], + } + }, + } + another_tool.tool_spec = new_spec + assert another_tool.tool_spec["description"] == "Modified" + assert another_tool.tool_spec["description"] == "Modified" + + def test_add_property_via_setter(self): + @tool + def dynamic_tool(base: str) -> str: + """A dynamic tool.""" + return base + + spec = dynamic_tool.tool_spec.copy() + spec["inputSchema"] = dynamic_tool.tool_spec["inputSchema"].copy() + spec["inputSchema"]["json"] = dynamic_tool.tool_spec["inputSchema"]["json"].copy() + spec["inputSchema"]["json"]["properties"] = dynamic_tool.tool_spec["inputSchema"]["json"]["properties"].copy() + spec["inputSchema"]["json"]["properties"]["extra"] = { + "type": "string", + "description": "Extra param", + } + dynamic_tool.tool_spec = spec + assert "extra" in dynamic_tool.tool_spec["inputSchema"]["json"]["properties"] + + def test_set_tool_spec_rejects_name_change(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + bad_spec: ToolSpec = { + "name": "wrong_name", + "description": "Updated tool", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + with pytest.raises(ValueError, match="cannot change tool name via tool_spec"): + my_tool.tool_spec = bad_spec + + def test_set_tool_spec_rejects_missing_description(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + bad_spec: ToolSpec = { + "name": "my_tool", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + with pytest.raises(ValueError, match="tool_spec must contain 'description'"): + my_tool.tool_spec = bad_spec + + def test_set_tool_spec_rejects_missing_input_schema(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + bad_spec: ToolSpec = { + "name": "my_tool", + "description": "Updated tool", + } + with pytest.raises(ValueError, match="tool_spec must contain 'inputSchema'"): + my_tool.tool_spec = bad_spec + + def test_set_tool_spec_accepts_bare_input_schema(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + bare_spec: ToolSpec = { + "name": "my_tool", + "description": "Bare schema", + "inputSchema": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}, + } + my_tool.tool_spec = bare_spec + assert my_tool.tool_spec is bare_spec + + def test_set_tool_spec_accepts_valid_spec(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + valid_spec: ToolSpec = { + "name": "my_tool", + "description": "A valid updated spec", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + } + }, + } + my_tool.tool_spec = valid_spec + assert my_tool.tool_spec is valid_spec + + +class TestPythonAgentToolSpecSetter: + """Tests for PythonAgentTool.tool_spec setter.""" + + def _make_tool(self) -> PythonAgentTool: + def func(tool_use, **kwargs): + return {"status": "success", "content": [{"text": "ok"}], "toolUseId": tool_use["toolUseId"]} + + spec: ToolSpec = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"input": {"type": "string"}}, + "required": ["input"], + } + }, + } + return PythonAgentTool("test_tool", spec, func) + + def test_set_tool_spec(self): + t = self._make_tool() + new_spec: ToolSpec = { + "name": "test_tool", + "description": "Updated", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + "extra": {"type": "integer"}, + }, + "required": ["input"], + } + }, + } + t.tool_spec = new_spec + assert t.tool_spec is new_spec + assert "extra" in t.tool_spec["inputSchema"]["json"]["properties"] + + def test_set_tool_spec_persists(self): + t = self._make_tool() + new_spec: ToolSpec = { + "name": "test_tool", + "description": "Persisted", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + t.tool_spec = new_spec + assert t.tool_spec["description"] == "Persisted" + assert t.tool_spec["description"] == "Persisted" + + def test_set_tool_spec_rejects_name_change(self): + t = self._make_tool() + bad_spec: ToolSpec = { + "name": "wrong_name", + "description": "Updated", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + with pytest.raises(ValueError, match="cannot change tool name via tool_spec"): + t.tool_spec = bad_spec + + def test_set_tool_spec_rejects_missing_description(self): + t = self._make_tool() + bad_spec: ToolSpec = { + "name": "test_tool", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + with pytest.raises(ValueError, match="tool_spec must contain 'description'"): + t.tool_spec = bad_spec + + def test_set_tool_spec_rejects_missing_input_schema(self): + t = self._make_tool() + bad_spec: ToolSpec = { + "name": "test_tool", + "description": "Updated", + } + with pytest.raises(ValueError, match="tool_spec must contain 'inputSchema'"): + t.tool_spec = bad_spec + + def test_set_tool_spec_accepts_bare_input_schema(self): + t = self._make_tool() + bare_spec: ToolSpec = { + "name": "test_tool", + "description": "Bare schema", + "inputSchema": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]}, + } + t.tool_spec = bare_spec + assert t.tool_spec is bare_spec + + def test_set_tool_spec_accepts_valid_spec(self): + t = self._make_tool() + valid_spec: ToolSpec = { + "name": "test_tool", + "description": "A valid updated spec", + "inputSchema": { + "json": { + "type": "object", + "properties": {"input": {"type": "string"}}, + "required": ["input"], + } + }, + } + t.tool_spec = valid_spec + assert t.tool_spec is valid_spec diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 60460f464..e20274523 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -509,3 +509,18 @@ async def test_stream(identity_tool, alist): tru_events = await alist(stream) exp_events = [ToolResultEvent(({"tool_use": 1}, 2))] assert tru_events == exp_events + + +def test_normalize_schema_with_anyof(): + """Test that anyOf properties don't get default type.""" + schema = { + "type": "object", + "properties": { + "optional_field": {"anyOf": [{"items": {"type": "string"}, "type": "array"}, {"type": "null"}]}, + "regular_field": {}, + }, + } + normalized = normalize_schema(schema) + + assert "type" not in normalized["properties"]["optional_field"] + assert normalized["properties"]["regular_field"]["type"] == "string" diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index d64cabb83..48465e1f6 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -6,6 +6,7 @@ from strands.telemetry import EventLoopMetrics from strands.types._events import ( + AgentAsToolStreamEvent, AgentResultEvent, CitationStreamEvent, EventLoopStopEvent, @@ -195,8 +196,8 @@ def test_initialization(self): delta = Mock(spec=ContentBlockDelta) citation = Mock(spec=Citation) event = CitationStreamEvent(delta, citation) - assert event["callback"]["citation"] == citation - assert event["callback"]["delta"] == delta + assert event["citation"] == citation + assert event["delta"] == delta class TestReasoningTextStreamEvent: @@ -465,3 +466,39 @@ def test_event_inheritance(self): assert hasattr(event, "is_callback_event") assert hasattr(event, "as_dict") assert hasattr(event, "prepare") + + +class TestAgentAsToolStreamEvent: + """Tests for AgentAsToolStreamEvent.""" + + def test_initialization(self): + """Test AgentAsToolStreamEvent initialization with agent-tool reference.""" + tool_use: ToolUse = { + "toolUseId": "agent_tool_123", + "name": "researcher", + "input": {"input": "hello"}, + } + agent_event = {"data": "partial response"} + mock_agent_as_tool = MagicMock() + mock_agent_as_tool.tool_name = "researcher" + + event = AgentAsToolStreamEvent(tool_use, agent_event, mock_agent_as_tool) + + assert event["tool_stream_event"]["tool_use"] == tool_use + assert event["tool_stream_event"]["data"] == agent_event + assert event.agent_as_tool is mock_agent_as_tool + assert event.tool_use_id == "agent_tool_123" + + def test_is_tool_stream_event_subclass(self): + """Test that AgentAsToolStreamEvent is a ToolStreamEvent subclass.""" + tool_use: ToolUse = { + "toolUseId": "id_123", + "name": "tool", + "input": {}, + } + mock_agent_as_tool = MagicMock() + event = AgentAsToolStreamEvent(tool_use, {}, mock_agent_as_tool) + + assert isinstance(event, ToolStreamEvent) + assert isinstance(event, TypedEvent) + assert type(event) is AgentAsToolStreamEvent diff --git a/tests/strands/types/test_json_dict.py b/tests/strands/types/test_json_dict.py index caa010bac..ad4f4660d 100644 --- a/tests/strands/types/test_json_dict.py +++ b/tests/strands/types/test_json_dict.py @@ -109,3 +109,68 @@ def test_initial_state(): assert state.get("key1") == "value1" assert state.get("key2") == "value2" assert state.get() == initial + + +# ============================================================================ +# Version Tracking Tests +# ============================================================================ + + +def test_version_is_zero_after_initialization(): + """Test that _get_version() returns 0 after initialization.""" + state = JSONSerializableDict() + assert state._get_version() == 0 + + +def test_version_is_zero_after_initialization_with_initial_state(): + """Test that _get_version() returns 0 when initialized with initial_state.""" + state = JSONSerializableDict(initial_state={"key": "value"}) + assert state._get_version() == 0 + + +def test_version_increments_after_set(): + """Test that _get_version() increments after set() is called.""" + state = JSONSerializableDict() + assert state._get_version() == 0 + + state.set("key", "value") + assert state._get_version() == 1 + + state.set("key2", "value2") + assert state._get_version() == 2 + + +def test_version_increments_after_delete(): + """Test that _get_version() increments after delete() is called.""" + state = JSONSerializableDict(initial_state={"key": "value"}) + assert state._get_version() == 0 + + state.delete("key") + assert state._get_version() == 1 + + +def test_version_increments_after_delete_nonexistent_key(): + """Test that _get_version() increments after delete() on nonexistent key.""" + state = JSONSerializableDict() + assert state._get_version() == 0 + + state.delete("nonexistent") + assert state._get_version() == 1 + + +def test_version_increments_independently(): + """Test that version increments independently for each operation.""" + state = JSONSerializableDict() + initial_version = state._get_version() + + state.set("key1", "value1") + version_after_first_set = state._get_version() + assert version_after_first_set == initial_version + 1 + + state.set("key2", "value2") + version_after_second_set = state._get_version() + assert version_after_second_set == version_after_first_set + 1 + + state.delete("key1") + version_after_delete = state._get_version() + assert version_after_delete == version_after_second_set + 1 diff --git a/tests/strands/types/test_media.py b/tests/strands/types/test_media.py new file mode 100644 index 000000000..2fa8c3621 --- /dev/null +++ b/tests/strands/types/test_media.py @@ -0,0 +1,99 @@ +"""Tests for media type definitions.""" + +from strands.types.media import ( + DocumentSource, + ImageSource, + S3Location, + VideoSource, +) + + +class TestS3Location: + """Tests for S3Location TypedDict.""" + + def test_s3_location_with_uri_only(self): + """Test S3Location with only uri field.""" + s3_loc: S3Location = {"uri": "s3://my-bucket/path/to/file.pdf"} + + assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf" + assert "bucketOwner" not in s3_loc + + def test_s3_location_with_bucket_owner(self): + """Test S3Location with both uri and bucketOwner fields.""" + s3_loc: S3Location = { + "uri": "s3://my-bucket/path/to/file.pdf", + "bucketOwner": "123456789012", + } + + assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf" + assert s3_loc["bucketOwner"] == "123456789012" + + +class TestDocumentSource: + """Tests for DocumentSource TypedDict.""" + + def test_document_source_with_bytes(self): + """Test DocumentSource with bytes content.""" + doc_source: DocumentSource = {"bytes": b"document content"} + + assert doc_source["bytes"] == b"document content" + assert "s3Location" not in doc_source + + def test_document_source_with_s3_location(self): + """Test DocumentSource with s3Location.""" + doc_source: DocumentSource = { + "s3Location": { + "uri": "s3://my-bucket/docs/report.pdf", + "bucketOwner": "123456789012", + } + } + + assert "bytes" not in doc_source + assert doc_source["s3Location"]["uri"] == "s3://my-bucket/docs/report.pdf" + assert doc_source["s3Location"]["bucketOwner"] == "123456789012" + + +class TestImageSource: + """Tests for ImageSource TypedDict.""" + + def test_image_source_with_bytes(self): + """Test ImageSource with bytes content.""" + img_source: ImageSource = {"bytes": b"image content"} + + assert img_source["bytes"] == b"image content" + assert "s3Location" not in img_source + + def test_image_source_with_s3_location(self): + """Test ImageSource with s3Location.""" + img_source: ImageSource = { + "s3Location": { + "uri": "s3://my-bucket/images/photo.png", + } + } + + assert "bytes" not in img_source + assert img_source["s3Location"]["uri"] == "s3://my-bucket/images/photo.png" + + +class TestVideoSource: + """Tests for VideoSource TypedDict.""" + + def test_video_source_with_bytes(self): + """Test VideoSource with bytes content.""" + vid_source: VideoSource = {"bytes": b"video content"} + + assert vid_source["bytes"] == b"video content" + assert "s3Location" not in vid_source + + def test_video_source_with_s3_location(self): + """Test VideoSource with s3Location.""" + vid_source: VideoSource = { + "s3Location": { + "uri": "s3://my-bucket/videos/clip.mp4", + "bucketOwner": "987654321098", + } + } + + assert "bytes" not in vid_source + assert vid_source["s3Location"]["uri"] == "s3://my-bucket/videos/clip.mp4" + assert vid_source["s3Location"]["bucketOwner"] == "987654321098" diff --git a/tests/strands/types/test_message_metadata.py b/tests/strands/types/test_message_metadata.py new file mode 100644 index 000000000..a7f93f690 --- /dev/null +++ b/tests/strands/types/test_message_metadata.py @@ -0,0 +1,37 @@ +"""Tests for MessageMetadata and get_message_metadata.""" + +from strands.types.content import Message, MessageMetadata, get_message_metadata + + +def test_message_without_metadata(): + msg: Message = {"role": "assistant", "content": [{"text": "hello"}]} + assert get_message_metadata(msg) == {} + + +def test_message_with_metadata(): + meta: MessageMetadata = { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + msg: Message = {"role": "assistant", "content": [{"text": "hello"}], "metadata": meta} + assert get_message_metadata(msg) == meta + assert get_message_metadata(msg)["usage"]["inputTokens"] == 10 + + +def test_message_with_custom_metadata(): + meta: MessageMetadata = { + "custom": {"source": "summarization", "original_turns": [5, 6, 7]}, + } + msg: Message = {"role": "assistant", "content": [{"text": "summary"}], "metadata": meta} + result = get_message_metadata(msg) + assert result["custom"]["source"] == "summarization" + + +def test_metadata_does_not_affect_role_and_content(): + msg: Message = { + "role": "assistant", + "content": [{"text": "hello"}], + "metadata": {"usage": {"inputTokens": 1, "outputTokens": 1, "totalTokens": 2}}, + } + assert msg["role"] == "assistant" + assert msg["content"] == [{"text": "hello"}] diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py index 3e5360742..b456f2404 100644 --- a/tests/strands/types/test_session.py +++ b/tests/strands/types/test_session.py @@ -102,13 +102,17 @@ def test_session_agent_from_agent(): agent.conversation_manager = unittest.mock.Mock(get_state=lambda: {"test": "conversation"}) agent.state = AgentState({"test": "state"}) agent._interrupt_state = _InterruptState(interrupts={}, context={}, activated=False) + agent._model_state = {} tru_session_agent = SessionAgent.from_agent(agent) exp_session_agent = SessionAgent( agent_id="a1", conversation_manager_state={"test": "conversation"}, state={"test": "state"}, - _internal_state={"interrupt_state": {"interrupts": {}, "context": {}, "activated": False}}, + _internal_state={ + "interrupt_state": {"interrupts": {}, "context": {}, "activated": False}, + "model_state": {}, + }, created_at=unittest.mock.ANY, updated_at=unittest.mock.ANY, ) @@ -121,7 +125,10 @@ def test_session_agent_initialize_internal_state(): agent_id="a1", conversation_manager_state={}, state={}, - _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, + _internal_state={ + "interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}, + "model_state": {"response_id": "resp_abc"}, + }, ) session_agent.initialize_internal_state(agent) @@ -129,3 +136,7 @@ def test_session_agent_initialize_internal_state(): tru_interrupt_state = agent._interrupt_state exp_interrupt_state = _InterruptState(interrupts={}, context={"test": "init"}, activated=False) assert tru_interrupt_state == exp_interrupt_state + + tru_model_state = agent._model_state + exp_model_state = {"response_id": "resp_abc"} + assert tru_model_state == exp_model_state diff --git a/tests/strands/experimental/steering/context_providers/__init__.py b/tests/strands/vended_plugins/__init__.py similarity index 100% rename from tests/strands/experimental/steering/context_providers/__init__.py rename to tests/strands/vended_plugins/__init__.py diff --git a/tests/strands/experimental/steering/core/__init__.py b/tests/strands/vended_plugins/context_offloader/__init__.py similarity index 100% rename from tests/strands/experimental/steering/core/__init__.py rename to tests/strands/vended_plugins/context_offloader/__init__.py diff --git a/tests/strands/vended_plugins/context_offloader/test_plugin.py b/tests/strands/vended_plugins/context_offloader/test_plugin.py new file mode 100644 index 000000000..fb9471dbf --- /dev/null +++ b/tests/strands/vended_plugins/context_offloader/test_plugin.py @@ -0,0 +1,591 @@ +"""Tests for the ContextOffloader plugin.""" + +import json +import logging +import math +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from strands.hooks.events import AfterToolCallEvent +from strands.types.tools import ToolContext, ToolUse +from strands.vended_plugins.context_offloader import ( + ContextOffloader, + FileStorage, + InMemoryStorage, +) + + +@pytest.fixture +def storage(): + return InMemoryStorage() + + +@pytest.fixture +def plugin(storage): + return ContextOffloader( + storage=storage, + max_result_tokens=25, + preview_tokens=10, + include_retrieval_tool=False, + ) + + +@pytest.fixture +def mock_agent(): + agent = MagicMock() + agent.model = MagicMock() + agent.model.count_tokens = AsyncMock(side_effect=_heuristic_count_tokens) + return agent + + +async def _heuristic_count_tokens(messages, **kwargs): + """Heuristic token counter for tests: chars / 4.""" + total = 0 + for msg in messages: + for block in msg.get("content", []): + if "toolResult" in block: + for content in block["toolResult"].get("content", []): + if "text" in content: + total += math.ceil(len(content["text"]) / 4) + elif "json" in content: + total += math.ceil(len(json.dumps(content["json"])) / 4) + elif "text" in block: + total += math.ceil(len(block["text"]) / 4) + return total + + +def _make_event(agent, text_content, status="success", tool_use_id="tool_123", cancel_message=None): + """Helper to create an AfterToolCallEvent with content.""" + if isinstance(text_content, str): + content = [{"text": text_content}] + else: + content = text_content + + result = { + "toolUseId": tool_use_id, + "status": status, + "content": content, + } + tool_use = {"toolUseId": tool_use_id, "name": "test_tool", "input": {}} + + return AfterToolCallEvent( + agent=agent, + selected_tool=None, + tool_use=tool_use, + invocation_state={}, + result=result, + cancel_message=cancel_message, + ) + + +class TestContextOffloader: + def test_plugin_name(self, plugin): + assert plugin.name == "context_offloader" + + def test_hooks_auto_discovered(self, plugin): + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "_handle_tool_result" + + def test_raises_on_non_positive_max_result_tokens(self): + with pytest.raises(ValueError, match="max_result_tokens must be positive"): + ContextOffloader(storage=InMemoryStorage(), max_result_tokens=0) + with pytest.raises(ValueError, match="max_result_tokens must be positive"): + ContextOffloader(storage=InMemoryStorage(), max_result_tokens=-1) + + def test_raises_on_negative_preview_tokens(self): + with pytest.raises(ValueError, match="preview_tokens must be non-negative"): + ContextOffloader(storage=InMemoryStorage(), preview_tokens=-1) + + def test_raises_on_preview_tokens_gte_max_result_tokens(self): + with pytest.raises(ValueError, match="preview_tokens must be less than max_result_tokens"): + ContextOffloader(storage=InMemoryStorage(), max_result_tokens=100, preview_tokens=100) + with pytest.raises(ValueError, match="preview_tokens must be less than max_result_tokens"): + ContextOffloader(storage=InMemoryStorage(), max_result_tokens=100, preview_tokens=200) + + @pytest.mark.asyncio + async def test_offloads_oversized_text(self, plugin, storage, mock_agent): + large_text = "a" * 200 + event = _make_event(mock_agent, large_text) + + await plugin._handle_tool_result(event) + + result_text = event.result["content"][0]["text"] + assert "[Offloaded:" in result_text + # Preview should be shorter than the full text + assert len(result_text) < len(large_text) + 500 # preview + metadata < original + overhead + + # Verify stored content + assert len(storage._store) == 1 + ref = list(storage._store.keys())[0] + content, content_type = storage.retrieve(ref) + assert content == large_text.encode("utf-8") + assert content_type == "text/plain" + + @pytest.mark.asyncio + async def test_preserves_status_and_tool_use_id(self, plugin, mock_agent): + event = _make_event(mock_agent, "x" * 200, status="error", tool_use_id="my_tool_456") + + await plugin._handle_tool_result(event) + + assert event.result["status"] == "error" + assert event.result["toolUseId"] == "my_tool_456" + + @pytest.mark.asyncio + async def test_under_threshold_passes_through(self, plugin, mock_agent): + small_text = "x" * 50 # 12.5 tokens, under 25 + event = _make_event(mock_agent, small_text) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_at_threshold_passes_through(self, plugin, mock_agent): + exact_text = "x" * 100 # exactly 25 tokens + event = _make_event(mock_agent, exact_text) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_skips_cancelled_tool_calls(self, plugin, mock_agent): + large_text = "x" * 200 + event = _make_event(mock_agent, large_text, cancel_message="tool cancelled by user") + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_skips_retrieve_tool_results_when_enabled(self, storage, mock_agent): + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=True) + large_text = "x" * 200 + result = {"toolUseId": "tool_123", "status": "success", "content": [{"text": large_text}]} + tool_use = {"toolUseId": "tool_123", "name": plugin.retrieve_offloaded_content.tool_name, "input": {}} + event = AfterToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use=tool_use, + invocation_state={}, + result=result, + ) + await plugin._handle_tool_result(event) + + assert event.result["content"][0]["text"] == large_text + + @pytest.mark.asyncio + async def test_does_not_skip_retrieve_tool_when_disabled(self, plugin, storage, mock_agent): + large_text = "x" * 200 + result = {"toolUseId": "tool_123", "status": "success", "content": [{"text": large_text}]} + tool_use = {"toolUseId": "tool_123", "name": "retrieve_offloaded_content", "input": {}} + event = AfterToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use=tool_use, + invocation_state={}, + result=result, + ) + await plugin._handle_tool_result(event) + + # Tool is disabled, so the result should be offloaded normally + assert "[Offloaded:" in event.result["content"][0]["text"] + + @pytest.mark.asyncio + async def test_image_only_content_passes_through(self, plugin, mock_agent): + content = [{"image": {"format": "png", "source": {"bytes": b"fake"}}}] + event = _make_event(mock_agent, content) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_image_stored_and_placeholder_has_ref(self, plugin, storage, mock_agent): + img_bytes = b"\x89PNG" + b"\x00" * 100 + content = [ + {"text": "x" * 200}, + {"image": {"format": "png", "source": {"bytes": img_bytes}}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Should have preview + image placeholder + assert len(event.result["content"]) == 2 + placeholder = event.result["content"][1]["text"] + assert "[image: png, 104 bytes" in placeholder + assert "ref:" in placeholder + + # Verify image was stored + assert len(storage._store) == 2 # text + image + img_ref = placeholder.split("ref: ")[1].rstrip("]") + img_content, img_type = storage.retrieve(img_ref) + assert img_content == img_bytes + assert img_type == "image/png" + + @pytest.mark.asyncio + async def test_document_stored_and_placeholder_has_ref(self, plugin, storage, mock_agent): + doc_bytes = b"%PDF-1.4" + b"\x00" * 100 + content = [ + {"text": "x" * 200}, + {"document": {"format": "pdf", "name": "report.pdf", "source": {"bytes": doc_bytes}}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + assert len(event.result["content"]) == 2 + placeholder = event.result["content"][1]["text"] + assert "[document: pdf, report.pdf, 108 bytes" in placeholder + assert "ref:" in placeholder + + # Verify document was stored + doc_ref = placeholder.split("ref: ")[1].rstrip("]") + doc_content, doc_type = storage.retrieve(doc_ref) + assert doc_content == doc_bytes + assert doc_type == "application/pdf" + + @pytest.mark.asyncio + async def test_multiple_text_blocks_stored_separately(self, plugin, storage, mock_agent): + content = [ + {"text": "a" * 60}, + {"text": "b" * 60}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Two text blocks stored separately + assert len(storage._store) == 2 + refs = list(storage._store.keys()) + assert storage.retrieve(refs[0]) == (b"a" * 60, "text/plain") + assert storage.retrieve(refs[1]) == (b"b" * 60, "text/plain") + + @pytest.mark.asyncio + async def test_json_content_stored_as_json(self, plugin, storage, mock_agent): + large_json = {"data": [{"id": i, "value": "x" * 20} for i in range(10)]} + content = [{"json": large_json}] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + assert len(storage._store) == 1 + ref = list(storage._store.keys())[0] + stored_content, content_type = storage.retrieve(ref) + assert content_type == "application/json" + assert json.loads(stored_content) == large_json + + @pytest.mark.asyncio + async def test_mixed_text_and_json(self, plugin, storage, mock_agent): + content = [ + {"text": "a" * 60}, + {"json": {"key": "b" * 60}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Both stored separately with correct types + assert len(storage._store) == 2 + refs = list(storage._store.keys()) + assert storage.retrieve(refs[0])[1] == "text/plain" + assert storage.retrieve(refs[1])[1] == "application/json" + + @pytest.mark.asyncio + async def test_small_json_passes_through(self, plugin, mock_agent): + content = [{"json": {"key": "value"}}] + event = _make_event(mock_agent, content) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_error_status_still_offloaded(self, plugin, mock_agent): + large_text = "x" * 200 + event = _make_event(mock_agent, large_text, status="error") + + await plugin._handle_tool_result(event) + + assert "[Offloaded:" in event.result["content"][0]["text"] + assert event.result["status"] == "error" + + @pytest.mark.asyncio + async def test_storage_failure_keeps_original(self, mock_agent, caplog): + failing_storage = MagicMock() + failing_storage.store.side_effect = RuntimeError("disk full") + + plugin = ContextOffloader( + storage=failing_storage, + max_result_tokens=25, + preview_tokens=10, + ) + + large_text = "x" * 200 + event = _make_event(mock_agent, large_text) + + with caplog.at_level(logging.WARNING): + await plugin._handle_tool_result(event) + + assert event.result["content"][0]["text"] == large_text + assert "failed to offload" in caplog.text + + @pytest.mark.asyncio + async def test_partial_storage_failure_keeps_original(self, mock_agent, caplog): + storage = MagicMock() + call_count = 0 + + def store_then_fail(key, content, content_type="text/plain"): + nonlocal call_count + call_count += 1 + if call_count > 1: + raise RuntimeError("disk full on second block") + return f"ref_{call_count}" + + storage.store.side_effect = store_then_fail + + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + + content = [ + {"text": "a" * 60}, + {"text": "b" * 60}, + ] + event = _make_event(mock_agent, content) + + with caplog.at_level(logging.WARNING): + await plugin._handle_tool_result(event) + + assert event.result["content"][0]["text"] == "a" * 60 + assert event.result["content"][1]["text"] == "b" * 60 + assert "failed to offload" in caplog.text + + @pytest.mark.asyncio + async def test_empty_text_blocks_not_stored(self, plugin, storage, mock_agent): + content = [ + {"text": ""}, + {"text": "x" * 200}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Empty text block is not in text_preview_parts but still iterated for storage + # The non-empty block triggers offloading + assert "[Offloaded:" in event.result["content"][0]["text"] + + @pytest.mark.asyncio + async def test_document_only_content_passes_through(self, plugin, mock_agent): + content = [{"document": {"format": "pdf", "name": "report.pdf", "source": {"bytes": b"pdf"}}}] + event = _make_event(mock_agent, content) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_unknown_content_type_passed_through(self, plugin, mock_agent): + unknown_block = {"custom_type": {"data": "something"}} + content = [ + {"text": "x" * 200}, + unknown_block, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Unknown block should be passed through + assert event.result["content"][-1] is unknown_block + + @pytest.mark.asyncio + async def test_all_content_types_mixed(self, plugin, storage, mock_agent): + large_json = {"rows": [{"id": i} for i in range(20)]} + img_bytes = b"\x89PNG" + b"\x00" * 100 + doc_bytes = b"%PDF" + b"\x00" * 200 + content = [ + {"text": "a" * 60}, + {"json": large_json}, + {"image": {"format": "png", "source": {"bytes": img_bytes}}}, + {"document": {"format": "pdf", "name": "report.pdf", "source": {"bytes": doc_bytes}}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + result_content = event.result["content"] + # Preview + image placeholder + document placeholder = 3 blocks + assert len(result_content) == 3 + assert "[Offloaded:" in result_content[0]["text"] + assert "[image: png" in result_content[1]["text"] + assert "[document: pdf, report.pdf" in result_content[2]["text"] + + # All 4 blocks stored + assert len(storage._store) == 4 + + @pytest.mark.asyncio + async def test_image_without_bytes_not_stored(self, plugin, storage, mock_agent): + content = [ + {"text": "x" * 200}, + {"image": {"format": "png", "source": {}}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Only text stored, not the empty image + assert len(storage._store) == 1 + placeholder = event.result["content"][1]["text"] + assert "0 bytes" in placeholder + assert "ref:" not in placeholder + + +class TestRetrievalTool: + @pytest.fixture + def storage(self): + return InMemoryStorage() + + @pytest.fixture + def plugin(self, storage): + return ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=True) + + @pytest.fixture + def mock_agent(self): + return MagicMock() + + @pytest.fixture + def tool_context(self, mock_agent): + tool_use = ToolUse(toolUseId="retrieve_1", name="retrieve_offloaded_content", input={}) + return ToolContext(tool_use=tool_use, agent=mock_agent, invocation_state={}) + + def test_retrieval_tool_registered_when_enabled(self, plugin): + tool_names = [t.tool_name for t in plugin.tools] + assert "retrieve_offloaded_content" in tool_names + + def test_retrieval_tool_registered_by_default(self): + plugin = ContextOffloader(storage=InMemoryStorage()) + plugin.init_agent(MagicMock()) + tool_names = [t.tool_name for t in plugin.tools] + assert "retrieve_offloaded_content" in tool_names + + def test_retrieval_tool_not_registered_when_disabled(self): + plugin = ContextOffloader(storage=InMemoryStorage(), include_retrieval_tool=False) + plugin.init_agent(MagicMock()) + tool_names = [t.tool_name for t in plugin.tools] + assert "retrieve_offloaded_content" not in tool_names + + def test_retrieve_text_content(self, plugin, storage, tool_context): + ref = storage.store("key_1", b"hello world", "text/plain") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result == "hello world" + + def test_retrieve_json_content(self, plugin, storage, tool_context): + ref = storage.store("key_1", b'{"key": "value"}', "application/json") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result["content"][0]["json"] == {"key": "value"} + + def test_retrieve_large_text_returns_full_content(self, plugin, storage, tool_context): + large_text = "a" * 50_000 + ref = storage.store("key_1", large_text.encode("utf-8"), "text/plain") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result == large_text + + def test_retrieve_missing_reference(self, plugin, tool_context): + result = plugin.retrieve_offloaded_content(reference="nonexistent", tool_context=tool_context) + assert "Error: reference not found" in result + + def test_retrieve_image_content(self, plugin, storage, tool_context): + img_bytes = b"\x89PNG\x00\x00" + ref = storage.store("key_1", img_bytes, "image/png") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result["status"] == "success" + assert result["content"][0]["image"]["format"] == "png" + assert result["content"][0]["image"]["source"]["bytes"] == img_bytes + + def test_retrieve_document_content(self, plugin, storage, tool_context): + doc_bytes = b"%PDF-1.4 content" + ref = storage.store("key_1", doc_bytes, "application/pdf") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result["status"] == "success" + assert result["content"][0]["document"]["format"] == "pdf" + assert result["content"][0]["document"]["source"]["bytes"] == doc_bytes + + +class TestInlineGuidance: + @pytest.fixture + def storage(self): + return InMemoryStorage() + + @pytest.fixture + def mock_agent(self): + agent = MagicMock() + agent.model = MagicMock() + agent.model.count_tokens = AsyncMock(side_effect=_heuristic_count_tokens) + return agent + + @pytest.mark.asyncio + async def test_guidance_mentions_retrieval_tool_when_enabled(self, storage, mock_agent): + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=True) + event = _make_event(mock_agent, "x" * 200) + await plugin._handle_tool_result(event) + result_text = event.result["content"][0]["text"] + assert "retrieve_offloaded_content" in result_text + + @pytest.mark.asyncio + async def test_guidance_does_not_mention_retrieval_tool_when_disabled(self, storage, mock_agent): + plugin = ContextOffloader( + storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=False + ) + event = _make_event(mock_agent, "x" * 200) + await plugin._handle_tool_result(event) + result_text = event.result["content"][0]["text"] + assert "retrieve_offloaded_content" not in result_text + assert "available tools" in result_text + + +class TestActionableReferences: + """Tests that storage-specific references appear in the offloaded preview.""" + + @pytest.mark.asyncio + async def test_file_storage_path_in_preview(self, tmp_path, mock_agent): + storage = FileStorage(artifact_dir=str(tmp_path / "artifacts")) + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + event = _make_event(mock_agent, "a" * 200) + + await plugin._handle_tool_result(event) + + result_text = event.result["content"][0]["text"] + assert str(tmp_path / "artifacts") in result_text + + @pytest.mark.asyncio + async def test_file_storage_image_placeholder_has_path(self, tmp_path, mock_agent): + storage = FileStorage(artifact_dir=str(tmp_path / "artifacts")) + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + img_bytes = b"\x89PNG" + b"\x00" * 100 + content = [ + {"text": "x" * 200}, + {"image": {"format": "png", "source": {"bytes": img_bytes}}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + placeholder = event.result["content"][1]["text"] + assert str(tmp_path / "artifacts") in placeholder + + @pytest.mark.asyncio + async def test_inmemory_storage_opaque_reference_in_preview(self, mock_agent): + storage = InMemoryStorage() + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + event = _make_event(mock_agent, "a" * 200) + + await plugin._handle_tool_result(event) + + result_text = event.result["content"][0]["text"] + assert "mem_" in result_text diff --git a/tests/strands/vended_plugins/context_offloader/test_storage.py b/tests/strands/vended_plugins/context_offloader/test_storage.py new file mode 100644 index 000000000..898dd5f86 --- /dev/null +++ b/tests/strands/vended_plugins/context_offloader/test_storage.py @@ -0,0 +1,301 @@ +"""Tests for offload storage backends.""" + +import threading +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from botocore.exceptions import ClientError + +from strands.vended_plugins.context_offloader import ( + FileStorage, + InMemoryStorage, + S3Storage, +) + + +class TestInMemoryStorage: + def test_round_trip(self): + storage = InMemoryStorage() + ref = storage.store("key_1", b"hello world") + content, content_type = storage.retrieve(ref) + assert content == b"hello world" + assert content_type == "text/plain" + + def test_preserves_content_type(self): + storage = InMemoryStorage() + ref = storage.store("key_1", b'{"a": 1}', "application/json") + content, content_type = storage.retrieve(ref) + assert content == b'{"a": 1}' + assert content_type == "application/json" + + def test_stores_binary_content(self): + storage = InMemoryStorage() + img_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + ref = storage.store("key_1", img_bytes, "image/png") + content, content_type = storage.retrieve(ref) + assert content == img_bytes + assert content_type == "image/png" + + def test_retrieve_missing_raises_key_error(self): + storage = InMemoryStorage() + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("nonexistent_ref") + + def test_unique_references(self): + storage = InMemoryStorage() + ref1 = storage.store("key_1", b"content a") + ref2 = storage.store("key_1", b"content b") + assert ref1 != ref2 + assert storage.retrieve(ref1)[0] == b"content a" + assert storage.retrieve(ref2)[0] == b"content b" + + def test_reference_format(self): + storage = InMemoryStorage() + ref = storage.store("tool_abc", b"content") + assert ref.startswith("mem_") + assert "tool_abc" in ref + + def test_thread_safety(self): + storage = InMemoryStorage() + refs: list[str] = [] + errors: list[Exception] = [] + + def store_item(i: int): + try: + ref = storage.store(f"key_{i}", f"content_{i}".encode()) + refs.append(ref) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=store_item, args=(i,)) for i in range(50)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + assert len(set(refs)) == 50 + + def test_stores_empty_content(self): + storage = InMemoryStorage() + ref = storage.store("key_1", b"") + assert storage.retrieve(ref) == (b"", "text/plain") + + def test_clear(self): + storage = InMemoryStorage() + ref = storage.store("key_1", b"content") + storage.clear() + with pytest.raises(KeyError): + storage.retrieve(ref) + + def test_clear_empty_storage(self): + storage = InMemoryStorage() + storage.clear() + + +class TestFileStorage: + def test_round_trip(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path / "artifacts")) + ref = storage.store("key_1", b"hello world") + content, content_type = storage.retrieve(ref) + assert content == b"hello world" + assert content_type == "text/plain" + + def test_preserves_content_type(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref = storage.store("key_1", b'{"a": 1}', "application/json") + content, content_type = storage.retrieve(ref) + assert content == b'{"a": 1}' + assert content_type == "application/json" + + def test_stores_binary_content(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + img_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + ref = storage.store("key_1", img_bytes, "image/png") + content, content_type = storage.retrieve(ref) + assert content == img_bytes + assert content_type == "image/png" + + def test_extension_from_content_type(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + assert storage.store("k", b"text", "text/plain").endswith(".txt") + assert storage.store("k", b"{}", "application/json").endswith(".json") + assert storage.store("k", b"img", "image/png").endswith(".png") + assert storage.store("k", b"pdf", "application/pdf").endswith(".pdf") + + def test_auto_creates_directory(self, tmp_path): + artifact_dir = tmp_path / "nested" / "dir" / "artifacts" + assert not artifact_dir.exists() + storage = FileStorage(artifact_dir=str(artifact_dir)) + storage.store("key_1", b"content") + assert artifact_dir.exists() + + def test_retrieve_missing_raises_key_error(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("nonexistent.txt") + + def test_unique_references(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref1 = storage.store("key_1", b"content a") + ref2 = storage.store("key_1", b"content b") + assert ref1 != ref2 + assert storage.retrieve(ref1)[0] == b"content a" + assert storage.retrieve(ref2)[0] == b"content b" + + def test_sanitizes_path_traversal(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref = storage.store("../../etc/passwd", b"content") + assert ".." not in ref + assert "/" not in Path(ref).name + + def test_reference_includes_artifact_dir(self, tmp_path): + artifact_dir = str(tmp_path / "artifacts") + storage = FileStorage(artifact_dir=artifact_dir) + ref = storage.store("key_1", b"content") + assert Path(ref).parent == Path(artifact_dir) + + def test_relative_artifact_dir_gives_relative_reference(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + storage = FileStorage(artifact_dir="./artifacts") + ref = storage.store("key_1", b"content") + assert Path(ref).parent == Path("artifacts") + content, content_type = storage.retrieve(ref) + assert content == b"content" + assert content_type == "text/plain" + + def test_retrieve_accepts_bare_filename(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref = storage.store("key_1", b"hello world") + filename = Path(ref).name + content, content_type = storage.retrieve(filename) + assert content == b"hello world" + assert content_type == "text/plain" + + def test_metadata_survives_across_instances(self, tmp_path): + artifact_dir = str(tmp_path / "artifacts") + storage1 = FileStorage(artifact_dir=artifact_dir) + ref = storage1.store("key_1", b"hello", "image/png") + + storage2 = FileStorage(artifact_dir=artifact_dir) + content, content_type = storage2.retrieve(ref) + assert content == b"hello" + assert content_type == "image/png" + + def test_corrupt_metadata_fallback(self, tmp_path): + (tmp_path / ".metadata.json").write_text("not valid json", encoding="utf-8") + storage = FileStorage(artifact_dir=str(tmp_path)) + assert storage._content_types == {} + + def test_missing_metadata_fallback(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref = storage.store("key_1", b"content", "image/png") + + storage._content_types.clear() + _, content_type = storage.retrieve(ref) + assert content_type == "application/octet-stream" + + def test_retrieve_rejects_path_traversal(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("../../etc/passwd") + + +class TestS3Storage: + @pytest.fixture + def mock_s3_client(self): + """Create a mock S3 client that stores objects in memory.""" + client = MagicMock() + objects: dict[str, tuple[bytes, str]] = {} + + def put_object(Bucket, Key, Body, ContentType="application/octet-stream", **kwargs): + objects[f"{Bucket}/{Key}"] = (Body, ContentType) + + def get_object(Bucket, Key, **kwargs): + full_key = f"{Bucket}/{Key}" + if full_key not in objects: + error_response = {"Error": {"Code": "NoSuchKey", "Message": "Not found"}} + raise ClientError(error_response, "GetObject") + body_bytes, ct = objects[full_key] + body = MagicMock() + body.read.return_value = body_bytes + return {"Body": body, "ContentType": ct} + + client.put_object.side_effect = put_object + client.get_object.side_effect = get_object + return client + + @pytest.fixture + def storage(self, mock_s3_client): + with patch("boto3.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session.client.return_value = mock_s3_client + mock_session_cls.return_value = mock_session + return S3Storage(bucket="test-bucket", prefix="artifacts") + + def test_round_trip(self, storage): + ref = storage.store("key_1", b"hello world") + content, content_type = storage.retrieve(ref) + assert content == b"hello world" + assert content_type == "text/plain" + + def test_preserves_content_type(self, storage): + ref = storage.store("key_1", b"img", "image/png") + content, content_type = storage.retrieve(ref) + assert content == b"img" + assert content_type == "image/png" + + def test_retrieve_missing_raises_key_error(self, storage): + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("nonexistent_key") + + def test_unique_references(self, storage): + ref1 = storage.store("key_1", b"content a") + ref2 = storage.store("key_1", b"content b") + assert ref1 != ref2 + assert storage.retrieve(ref1)[0] == b"content a" + assert storage.retrieve(ref2)[0] == b"content b" + + def test_reference_is_s3_uri(self, storage): + ref = storage.store("tool_abc", b"content") + assert ref.startswith("s3://test-bucket/artifacts/") + + def test_empty_prefix(self, mock_s3_client): + with patch("boto3.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session.client.return_value = mock_s3_client + mock_session_cls.return_value = mock_session + storage = S3Storage(bucket="test-bucket", prefix="") + + ref = storage.store("tool_abc", b"content") + assert ref.startswith("s3://test-bucket/") + assert storage.retrieve(ref)[0] == b"content" + + def test_retrieve_accepts_raw_key(self, storage, mock_s3_client): + ref = storage.store("key_1", b"hello world") + raw_key = ref.removeprefix("s3://test-bucket/") + content, content_type = storage.retrieve(raw_key) + assert content == b"hello world" + assert content_type == "text/plain" + + def test_retrieve_rejects_wrong_bucket_uri(self, storage): + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("s3://wrong-bucket/artifacts/some_key") + + def test_put_object_called_with_correct_params(self, storage, mock_s3_client): + storage.store("key_1", b"test content", "application/json") + + mock_s3_client.put_object.assert_called_once() + call_kwargs = mock_s3_client.put_object.call_args[1] + assert call_kwargs["Bucket"] == "test-bucket" + assert call_kwargs["Key"].startswith("artifacts/") + assert call_kwargs["Body"] == b"test content" + assert call_kwargs["ContentType"] == "application/json" + + def test_non_nosuchkey_error_propagates(self, storage, mock_s3_client): + error_response = {"Error": {"Code": "AccessDenied", "Message": "Forbidden"}} + mock_s3_client.get_object.side_effect = ClientError(error_response, "GetObject") + + with pytest.raises(ClientError, match="Forbidden"): + storage.retrieve("some_key") diff --git a/tests/strands/experimental/steering/handlers/__init__.py b/tests/strands/vended_plugins/skills/__init__.py similarity index 100% rename from tests/strands/experimental/steering/handlers/__init__.py rename to tests/strands/vended_plugins/skills/__init__.py diff --git a/tests/strands/vended_plugins/skills/test_agent_skills.py b/tests/strands/vended_plugins/skills/test_agent_skills.py new file mode 100644 index 000000000..03f43ef2c --- /dev/null +++ b/tests/strands/vended_plugins/skills/test_agent_skills.py @@ -0,0 +1,854 @@ +"""Tests for the AgentSkills plugin.""" + +import logging +from pathlib import Path +from unittest.mock import MagicMock + +from strands.hooks.events import BeforeInvocationEvent +from strands.hooks.registry import HookRegistry +from strands.plugins.registry import _PluginRegistry +from strands.types.tools import ToolContext +from strands.vended_plugins.skills.agent_skills import AgentSkills +from strands.vended_plugins.skills.skill import Skill + + +def _make_skill(name: str = "test-skill", description: str = "A test skill", instructions: str = "Do the thing."): + """Helper to create a Skill instance.""" + return Skill(name=name, description=description, instructions=instructions) + + +def _make_skill_dir(parent: Path, name: str, description: str = "A test skill") -> Path: + """Helper to create a skill directory with SKILL.md.""" + skill_dir = parent / name + skill_dir.mkdir(parents=True, exist_ok=True) + content = f"---\nname: {name}\ndescription: {description}\n---\n# Instructions for {name}\n" + (skill_dir / "SKILL.md").write_text(content) + return skill_dir + + +def _mock_agent(): + """Create a mock agent for testing.""" + agent = MagicMock() + agent._system_prompt = "You are an agent." + agent._system_prompt_content = [{"text": "You are an agent."}] + + # Make system_prompt and system_prompt_content properties behave like the real Agent + type(agent).system_prompt = property( + lambda self: self._system_prompt, + lambda self, value: _set_system_prompt(self, value), + ) + type(agent).system_prompt_content = property(lambda self: self._system_prompt_content) + + agent.hooks = HookRegistry() + agent.add_hook = MagicMock( + side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback) + ) + agent.tool_registry = MagicMock() + agent.tool_registry.process_tools = MagicMock(return_value=["skills"]) + + # Use a real dict-backed state so get/set work correctly + state_store: dict[str, object] = {} + agent.state = MagicMock() + agent.state.get = MagicMock(side_effect=lambda key: state_store.get(key)) + agent.state.set = MagicMock(side_effect=lambda key, value: state_store.__setitem__(key, value)) + return agent + + +def _mock_tool_context(agent: MagicMock) -> ToolContext: + """Create a mock ToolContext with the given agent.""" + tool_use = {"toolUseId": "test-id", "name": "skills", "input": {}} + return ToolContext(tool_use=tool_use, agent=agent, invocation_state={"agent": agent}) + + +def _set_system_prompt(agent: MagicMock, value: str | list | None) -> None: + """Simulate the Agent.system_prompt setter.""" + if isinstance(value, str): + agent._system_prompt = value + agent._system_prompt_content = [{"text": value}] + elif isinstance(value, list): + text_parts = [block["text"] for block in value if "text" in block] + agent._system_prompt = "\n".join(text_parts) if text_parts else None + agent._system_prompt_content = value + elif value is None: + agent._system_prompt = None + agent._system_prompt_content = None + + +class TestSkillsPluginInit: + """Tests for AgentSkills initialization.""" + + def test_init_with_skill_instances(self): + """Test initialization with Skill instances.""" + skill = _make_skill() + plugin = AgentSkills(skills=[skill]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "test-skill" + + def test_init_with_filesystem_paths(self, tmp_path): + """Test initialization with filesystem paths.""" + _make_skill_dir(tmp_path, "fs-skill") + plugin = AgentSkills(skills=[str(tmp_path / "fs-skill")]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "fs-skill" + + def test_init_with_parent_directory(self, tmp_path): + """Test initialization with a parent directory containing skills.""" + _make_skill_dir(tmp_path, "skill-a") + _make_skill_dir(tmp_path, "skill-b") + plugin = AgentSkills(skills=[tmp_path]) + + assert len(plugin.get_available_skills()) == 2 + + def test_init_with_mixed_sources(self, tmp_path): + """Test initialization with mixed skill sources.""" + _make_skill_dir(tmp_path, "fs-skill") + direct_skill = _make_skill(name="direct-skill", description="Direct") + plugin = AgentSkills(skills=[str(tmp_path / "fs-skill"), direct_skill]) + + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} + assert names == {"fs-skill", "direct-skill"} + + def test_init_skips_nonexistent_paths(self, tmp_path): + """Test that nonexistent paths are skipped gracefully.""" + plugin = AgentSkills(skills=[str(tmp_path / "nonexistent")]) + assert len(plugin.get_available_skills()) == 0 + + def test_init_empty_skills(self): + """Test initialization with empty skills list.""" + plugin = AgentSkills(skills=[]) + assert plugin.get_available_skills() == [] + + def test_name_attribute(self): + """Test that the plugin has the correct name.""" + plugin = AgentSkills(skills=[]) + assert plugin.name == "agent_skills" + + def test_custom_state_key(self): + """Test initialization with a custom state key.""" + plugin = AgentSkills(skills=[], state_key="custom_key") + assert plugin._state_key == "custom_key" + + def test_custom_max_resource_files(self): + """Test initialization with a custom max resource files limit.""" + plugin = AgentSkills(skills=[], max_resource_files=50) + assert plugin._max_resource_files == 50 + + +class TestSkillsPluginInitAgent: + """Tests for the init_agent method and plugin registry integration.""" + + def test_registers_tool(self): + """Test that the plugin registry registers the skills tool.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + registry = _PluginRegistry(agent) + registry.add_and_init(plugin) + + agent.tool_registry.process_tools.assert_called_once() + + def test_registers_hooks(self): + """Test that the plugin registry registers hook callbacks.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + registry = _PluginRegistry(agent) + registry.add_and_init(plugin) + + assert agent.hooks.has_callbacks() + + def test_does_not_store_agent_reference(self): + """Test that init_agent does not store the agent on the plugin.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + plugin.init_agent(agent) + + assert not hasattr(plugin, "_agent") + + +class TestSkillsPluginProperties: + """Tests for AgentSkills properties.""" + + def test_available_skills_getter_returns_copy(self): + """Test that get_available_skills returns a copy of the list.""" + skill = _make_skill() + plugin = AgentSkills(skills=[skill]) + + skills_list = plugin.get_available_skills() + skills_list.append(_make_skill(name="another-skill", description="Another")) + + assert len(plugin.get_available_skills()) == 1 + + def test_available_skills_setter(self): + """Test setting skills via set_available_skills.""" + plugin = AgentSkills(skills=[_make_skill()]) + + new_skill = _make_skill(name="new-skill", description="New") + plugin.set_available_skills([new_skill]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "new-skill" + + def test_set_available_skills_with_paths(self, tmp_path): + """Test setting skills via set_available_skills with filesystem paths.""" + plugin = AgentSkills(skills=[_make_skill()]) + _make_skill_dir(tmp_path, "fs-skill") + + plugin.set_available_skills([str(tmp_path / "fs-skill")]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "fs-skill" + + def test_set_available_skills_with_mixed_sources(self, tmp_path): + """Test setting skills via set_available_skills with mixed sources.""" + plugin = AgentSkills(skills=[]) + _make_skill_dir(tmp_path, "fs-skill") + direct = _make_skill(name="direct", description="Direct") + + plugin.set_available_skills([str(tmp_path / "fs-skill"), direct]) + + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} + assert names == {"fs-skill", "direct"} + + +class TestSkillsTool: + """Tests for the skills tool method.""" + + def test_activate_skill(self): + """Test activating a skill returns its instructions.""" + skill = _make_skill(instructions="Full instructions here.") + plugin = AgentSkills(skills=[skill]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + result = plugin.skills(skill_name="test-skill", tool_context=tool_context) + + assert "Full instructions here." in result + + def test_activate_nonexistent_skill(self): + """Test activating a nonexistent skill returns error message.""" + skill = _make_skill() + plugin = AgentSkills(skills=[skill]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + result = plugin.skills(skill_name="nonexistent", tool_context=tool_context) + + assert "not found" in result + assert "test-skill" in result + + def test_activate_replaces_previous(self): + """Test that activating a new skill replaces the previous one.""" + skill1 = _make_skill(name="skill-a", description="A", instructions="A instructions") + skill2 = _make_skill(name="skill-b", description="B", instructions="B instructions") + plugin = AgentSkills(skills=[skill1, skill2]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + result_a = plugin.skills(skill_name="skill-a", tool_context=tool_context) + assert "A instructions" in result_a + + result_b = plugin.skills(skill_name="skill-b", tool_context=tool_context) + assert "B instructions" in result_b + + def test_activate_without_name(self): + """Test activating without a skill name returns error.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + result = plugin.skills(skill_name="", tool_context=tool_context) + + assert "required" in result.lower() + + def test_activate_tracks_in_agent_state(self): + """Test that activating a skill records it in agent state.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="test-skill", tool_context=tool_context) + + assert plugin.get_activated_skills(agent) == ["test-skill"] + + def test_activate_multiple_tracks_order(self): + """Test that multiple activations are tracked in order.""" + skill_a = _make_skill(name="skill-a", description="A", instructions="A") + skill_b = _make_skill(name="skill-b", description="B", instructions="B") + plugin = AgentSkills(skills=[skill_a, skill_b]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="skill-a", tool_context=tool_context) + plugin.skills(skill_name="skill-b", tool_context=tool_context) + + assert plugin.get_activated_skills(agent) == ["skill-a", "skill-b"] + + def test_activate_same_skill_twice_deduplicates(self): + """Test that re-activating a skill moves it to the end without duplicates.""" + skill_a = _make_skill(name="skill-a", description="A", instructions="A") + skill_b = _make_skill(name="skill-b", description="B", instructions="B") + plugin = AgentSkills(skills=[skill_a, skill_b]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="skill-a", tool_context=tool_context) + plugin.skills(skill_name="skill-b", tool_context=tool_context) + plugin.skills(skill_name="skill-a", tool_context=tool_context) + + assert plugin.get_activated_skills(agent) == ["skill-b", "skill-a"] + + def test_get_activated_skills_empty_by_default(self): + """Test that get_activated_skills returns empty list when nothing activated.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + assert plugin.get_activated_skills(agent) == [] + + def test_get_activated_skills_returns_copy(self): + """Test that get_activated_skills returns a copy, not a reference.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="test-skill", tool_context=tool_context) + result = plugin.get_activated_skills(agent) + result.append("injected") + + assert plugin.get_activated_skills(agent) == ["test-skill"] + + +class TestSystemPromptInjection: + """Tests for system prompt injection via hooks.""" + + def test_before_invocation_appends_skills_xml(self): + """Test that before_invocation appends skills XML to system prompt.""" + skill = _make_skill() + plugin = AgentSkills(skills=[skill]) + agent = _mock_agent() + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert "" in agent.system_prompt + assert "test-skill" in agent.system_prompt + assert "A test skill" in agent.system_prompt + + def test_before_invocation_preserves_existing_prompt(self): + """Test that existing system prompt content is preserved.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert agent.system_prompt.startswith("Original prompt.") + assert "" in agent.system_prompt + + def test_repeated_invocations_do_not_accumulate(self): + """Test that repeated invocations rebuild from current prompt without accumulation.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + first_prompt = agent.system_prompt + + plugin._on_before_invocation(event) + second_prompt = agent.system_prompt + + assert first_prompt == second_prompt + + def test_no_skills_injects_empty_message(self): + """Test that a 'no skills available' message is injected when no skills are loaded.""" + plugin = AgentSkills(skills=[]) + agent = _mock_agent() + original_prompt = "Original prompt." + agent._system_prompt = original_prompt + agent._system_prompt_content = [{"text": original_prompt}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert "No skills are currently available" in agent.system_prompt + assert agent.system_prompt.startswith("Original prompt.") + + def test_none_system_prompt_handled(self): + """Test handling when system prompt is None.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = None + agent._system_prompt_content = None + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert "" in agent.system_prompt + + def test_preserves_other_plugin_modifications(self): + """Test that modifications by other plugins/hooks are preserved.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # Simulate another plugin modifying the prompt + agent.system_prompt = agent.system_prompt + "\n\nExtra context from another plugin." + + plugin._on_before_invocation(event) + + assert "Extra context from another plugin." in agent.system_prompt + assert "" in agent.system_prompt + + def test_uses_public_system_prompt_setter(self): + """Test that the hook uses the public system_prompt setter.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original." + agent._system_prompt_content = [{"text": "Original."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # The public setter should have been used via the content-block path: + # original block is preserved and the skills XML is appended as a new block. + assert len(agent.system_prompt_content) == 2 + assert agent.system_prompt_content[0] == {"text": "Original."} + assert "" in agent.system_prompt_content[1]["text"] + + def test_preserves_cache_points_in_system_prompt(self): + """Test that cachePoint blocks in the system prompt are preserved after injection.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Base instructions." + agent._system_prompt_content = [ + {"text": "Base instructions."}, + {"cachePoint": {"type": "default"}}, + ] + + expected_skills_xml = plugin._generate_skills_xml() + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # Exact block structure: original text, cachePoint, skills XML + assert agent.system_prompt_content == [ + {"text": "Base instructions."}, + {"cachePoint": {"type": "default"}}, + {"text": expected_skills_xml}, + ] + + # Repeated invocation: identical result, no accumulation + plugin._on_before_invocation(event) + assert agent.system_prompt_content == [ + {"text": "Base instructions."}, + {"cachePoint": {"type": "default"}}, + {"text": expected_skills_xml}, + ] + + def test_warns_when_previous_xml_not_found(self, caplog): + """Test that a warning is logged when the previously injected XML is missing from the prompt.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # Completely replace the system prompt, removing the injected XML + agent.system_prompt = "Totally new prompt." + + with caplog.at_level(logging.WARNING): + plugin._on_before_invocation(event) + + assert "unable to find previously injected skills XML in system prompt" in caplog.text + assert "" in agent.system_prompt + + +class TestStringPathInjection: + """Tests for the string-path branch of _on_before_invocation (system_prompt_content is None).""" + + def test_string_path_replaces_previous_xml(self): + """Test that old injected XML is replaced when found in the string prompt.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + old_xml = "\n\nxml" + agent._system_prompt = f"Base prompt.{old_xml}" + agent._system_prompt_content = None + agent.state.set(plugin._state_key, {"last_injected_xml": old_xml}) + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert "xml" not in agent.system_prompt + assert "" in agent.system_prompt + assert agent.system_prompt.startswith("Base prompt.") + + def test_string_path_warns_when_previous_xml_not_found(self, caplog): + """Test that a warning is logged when old XML is missing from the string prompt.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + agent._system_prompt = "Totally new prompt." + agent._system_prompt_content = None + agent.state.set(plugin._state_key, {"last_injected_xml": "\n\nxml"}) + + event = BeforeInvocationEvent(agent=agent) + with caplog.at_level(logging.WARNING): + plugin._on_before_invocation(event) + + assert "unable to find previously injected skills XML in system prompt" in caplog.text + assert "" in agent.system_prompt + + +class TestSkillsXmlGeneration: + """Tests for _generate_skills_xml.""" + + def test_single_skill(self): + """Test XML generation with a single skill.""" + plugin = AgentSkills(skills=[_make_skill()]) + xml = plugin._generate_skills_xml() + + assert "" in xml + assert "" in xml + assert "test-skill" in xml + assert "A test skill" in xml + + def test_multiple_skills(self): + """Test XML generation with multiple skills.""" + skills = [ + _make_skill(name="skill-a", description="Skill A"), + _make_skill(name="skill-b", description="Skill B"), + ] + plugin = AgentSkills(skills=skills) + xml = plugin._generate_skills_xml() + + assert "skill-a" in xml + assert "skill-b" in xml + + def test_empty_skills(self): + """Test XML generation with no skills includes 'no skills available' message.""" + plugin = AgentSkills(skills=[]) + xml = plugin._generate_skills_xml() + + assert "" in xml + assert "No skills are currently available" in xml + assert "" in xml + + def test_location_included_when_path_set(self, tmp_path): + """Test that location element is included when skill has a path.""" + skill = _make_skill() + skill.path = tmp_path / "test-skill" + plugin = AgentSkills(skills=[skill]) + xml = plugin._generate_skills_xml() + + assert f"{tmp_path / 'test-skill' / 'SKILL.md'}" in xml + + def test_location_omitted_when_path_none(self): + """Test that location element is omitted for programmatic skills.""" + skill = _make_skill() + assert skill.path is None + plugin = AgentSkills(skills=[skill]) + xml = plugin._generate_skills_xml() + + assert "" not in xml + + def test_escapes_xml_special_characters(self): + """Test that XML special characters in names and descriptions are escaped.""" + skill = _make_skill(name="a&c", description="Use & more") + plugin = AgentSkills(skills=[skill]) + xml = plugin._generate_skills_xml() + + assert "a<b>&c" in xml + assert "Use <tools> & more" in xml + + +class TestSkillResponseFormat: + """Tests for _format_skill_response.""" + + def test_instructions_only(self): + """Test response with just instructions.""" + skill = _make_skill(instructions="Do the thing.") + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert result == "Do the thing." + + def test_no_instructions(self): + """Test response when skill has no instructions.""" + skill = _make_skill(instructions="") + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "no instructions available" in result.lower() + + def test_includes_allowed_tools(self): + """Test response includes allowed tools when set.""" + skill = _make_skill(instructions="Do the thing.") + skill.allowed_tools = ["Bash", "Read"] + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Do the thing." in result + assert "Allowed tools: Bash, Read" in result + + def test_includes_compatibility(self): + """Test response includes compatibility when set.""" + skill = _make_skill(instructions="Do the thing.") + skill.compatibility = "Requires docker" + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Compatibility: Requires docker" in result + + def test_includes_location(self, tmp_path): + """Test response includes location when path is set.""" + skill = _make_skill(instructions="Do the thing.") + skill.path = tmp_path / "test-skill" + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert f"Location: {tmp_path / 'test-skill' / 'SKILL.md'}" in result + + def test_all_metadata(self, tmp_path): + """Test response with all metadata fields.""" + skill = _make_skill(instructions="Do the thing.") + skill.allowed_tools = ["Bash"] + skill.compatibility = "Requires git" + skill.path = tmp_path / "test-skill" + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Do the thing." in result + assert "---" in result + assert "Allowed tools: Bash" in result + assert "Compatibility: Requires git" in result + assert "Location:" in result + + def test_includes_resource_listing(self, tmp_path): + """Test response includes resource files from optional directories.""" + skill_dir = tmp_path / "test-skill" + skill_dir.mkdir() + (skill_dir / "scripts").mkdir() + (skill_dir / "scripts" / "extract.py").write_text("# extract") + (skill_dir / "references").mkdir() + (skill_dir / "references" / "REFERENCE.md").write_text("# ref") + + skill = _make_skill(instructions="Do the thing.") + skill.path = skill_dir + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" in result + assert "scripts/extract.py" in result + assert "references/REFERENCE.md" in result + + def test_no_resources_when_no_path(self): + """Test that resources section is omitted for programmatic skills.""" + skill = _make_skill(instructions="Do the thing.") + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" not in result + + def test_no_resources_when_dirs_empty(self, tmp_path): + """Test that resources section is omitted when optional dirs don't exist.""" + skill_dir = tmp_path / "test-skill" + skill_dir.mkdir() + + skill = _make_skill(instructions="Do the thing.") + skill.path = skill_dir + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" not in result + + def test_resource_listing_truncated(self, tmp_path): + """Test that resource listing is truncated at the max file limit.""" + skill_dir = tmp_path / "test-skill" + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir(parents=True) + for i in range(55): + (scripts_dir / f"script_{i:03d}.py").write_text(f"# script {i}") + + skill = _make_skill(instructions="Do the thing.") + skill.path = skill_dir + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" in result + assert "truncated at 20 files" in result + + +class TestResolveSkills: + """Tests for _resolve_skills.""" + + def test_resolve_skill_instances(self): + """Test resolving Skill instances (pass-through).""" + skill = _make_skill() + plugin = AgentSkills(skills=[skill]) + + assert len(plugin._skills) == 1 + assert plugin._skills["test-skill"] is skill + + def test_resolve_skill_directory_path(self, tmp_path): + """Test resolving a path to a skill directory.""" + _make_skill_dir(tmp_path, "path-skill") + plugin = AgentSkills(skills=[tmp_path / "path-skill"]) + + assert len(plugin._skills) == 1 + assert "path-skill" in plugin._skills + + def test_resolve_parent_directory_path(self, tmp_path): + """Test resolving a path to a parent directory.""" + _make_skill_dir(tmp_path, "child-a") + _make_skill_dir(tmp_path, "child-b") + plugin = AgentSkills(skills=[tmp_path]) + + assert len(plugin._skills) == 2 + + def test_resolve_skill_md_file_path(self, tmp_path): + """Test resolving a path to a SKILL.md file.""" + skill_dir = _make_skill_dir(tmp_path, "file-skill") + plugin = AgentSkills(skills=[skill_dir / "SKILL.md"]) + + assert len(plugin._skills) == 1 + assert "file-skill" in plugin._skills + + def test_resolve_nonexistent_path(self, tmp_path): + """Test that nonexistent paths are skipped.""" + plugin = AgentSkills(skills=[str(tmp_path / "ghost")]) + assert len(plugin._skills) == 0 + + +class TestResolveUrlSkills: + """Tests for _resolve_skills with URL sources.""" + + _SKILL_MODULE = "strands.vended_plugins.skills.skill" + _SAMPLE_CONTENT = "---\nname: url-skill\ndescription: A URL skill\n---\n# Instructions\n" + + def _mock_urlopen(self, content): + """Create a mock urlopen context manager returning the given content.""" + mock_response = MagicMock() + mock_response.read.return_value = content.encode("utf-8") + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + return mock_response + + def test_resolve_url_source(self): + """Test resolving a URL string as a skill source.""" + from unittest.mock import patch + + with patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=self._mock_urlopen(self._SAMPLE_CONTENT) + ): + plugin = AgentSkills(skills=["https://example.com/SKILL.md"]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "url-skill" + + def test_resolve_mixed_url_and_local(self, tmp_path): + """Test resolving a mix of URL and local filesystem sources.""" + from unittest.mock import patch + + _make_skill_dir(tmp_path, "local-skill") + + with patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=self._mock_urlopen(self._SAMPLE_CONTENT) + ): + plugin = AgentSkills( + skills=[ + "https://example.com/SKILL.md", + str(tmp_path / "local-skill"), + ] + ) + + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} + assert names == {"url-skill", "local-skill"} + + def test_resolve_url_failure_skips_gracefully(self, caplog): + """Test that a failed URL fetch is skipped with a warning.""" + import logging + import urllib.error + from unittest.mock import patch + + with ( + patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", + side_effect=urllib.error.HTTPError( + url="https://example.com", code=404, msg="Not Found", hdrs=None, fp=None + ), + ), + caplog.at_level(logging.WARNING), + ): + plugin = AgentSkills(skills=["https://example.com/broken/SKILL.md"]) + + assert len(plugin.get_available_skills()) == 0 + assert "failed to load skill from URL" in caplog.text + + def test_resolve_duplicate_url_skills_warns(self, caplog): + """Test that duplicate skill names from URLs log a warning.""" + import logging + from unittest.mock import patch + + with ( + patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", + return_value=self._mock_urlopen(self._SAMPLE_CONTENT), + ), + caplog.at_level(logging.WARNING), + ): + plugin = AgentSkills( + skills=[ + "https://example.com/a/SKILL.md", + "https://example.com/b/SKILL.md", + ] + ) + + assert len(plugin.get_available_skills()) == 1 + assert "duplicate skill name" in caplog.text + + +class TestImports: + """Tests for module imports.""" + + def test_import_skill_from_strands(self): + """Test importing Skill from top-level strands package.""" + from strands import Skill as S + + assert S is Skill + + def test_import_from_skills_package(self): + """Test importing from strands.vended_plugins.skills package.""" + from strands.vended_plugins.skills import AgentSkills, Skill + + assert Skill is not None + assert AgentSkills is not None + + def test_skills_plugin_is_plugin_subclass(self): + """Test that AgentSkills is a subclass of the Plugin ABC.""" + from strands.plugins import Plugin + + assert issubclass(AgentSkills, Plugin) + + def test_skills_plugin_isinstance_check(self): + """Test that AgentSkills instances pass isinstance check against Plugin.""" + from strands.plugins import Plugin + + plugin = AgentSkills(skills=[]) + assert isinstance(plugin, Plugin) diff --git a/tests/strands/vended_plugins/skills/test_skill.py b/tests/strands/vended_plugins/skills/test_skill.py new file mode 100644 index 000000000..cb67d71a2 --- /dev/null +++ b/tests/strands/vended_plugins/skills/test_skill.py @@ -0,0 +1,649 @@ +"""Tests for the Skill dataclass and loading utilities.""" + +import logging +from pathlib import Path + +import pytest + +from strands.vended_plugins.skills.skill import ( + Skill, + _find_skill_md, + _fix_yaml_colons, + _parse_frontmatter, + _validate_skill_name, +) + + +class TestSkillDataclass: + """Tests for the Skill dataclass creation and properties.""" + + def test_skill_minimal(self): + """Test creating a Skill with only required fields.""" + skill = Skill(name="test-skill", description="A test skill") + + assert skill.name == "test-skill" + assert skill.description == "A test skill" + assert skill.instructions == "" + assert skill.path is None + assert skill.allowed_tools is None + assert skill.metadata == {} + assert skill.license is None + assert skill.compatibility is None + + def test_skill_full(self): + """Test creating a Skill with all fields.""" + skill = Skill( + name="full-skill", + description="A fully specified skill", + instructions="# Full Instructions\nDo the thing.", + path=Path("/tmp/skills/full-skill"), + allowed_tools=["tool1", "tool2"], + metadata={"author": "test-org"}, + license="Apache-2.0", + compatibility="strands>=1.0", + ) + + assert skill.name == "full-skill" + assert skill.description == "A fully specified skill" + assert skill.instructions == "# Full Instructions\nDo the thing." + assert skill.path == Path("/tmp/skills/full-skill") + assert skill.allowed_tools == ["tool1", "tool2"] + assert skill.metadata == {"author": "test-org"} + assert skill.license == "Apache-2.0" + assert skill.compatibility == "strands>=1.0" + + def test_skill_metadata_default_is_not_shared(self): + """Test that default metadata dict is not shared between instances.""" + skill1 = Skill(name="skill-1", description="First") + skill2 = Skill(name="skill-2", description="Second") + + skill1.metadata["key"] = "value" + assert "key" not in skill2.metadata + + +class TestFindSkillMd: + """Tests for _find_skill_md.""" + + def test_finds_uppercase_skill_md(self, tmp_path): + """Test finding SKILL.md (uppercase).""" + (tmp_path / "SKILL.md").write_text("test") + result = _find_skill_md(tmp_path) + assert result.name == "SKILL.md" + + def test_finds_lowercase_skill_md(self, tmp_path): + """Test finding skill.md (lowercase).""" + (tmp_path / "skill.md").write_text("test") + result = _find_skill_md(tmp_path) + assert result.name.lower() == "skill.md" + + def test_prefers_uppercase(self, tmp_path): + """Test that SKILL.md is preferred over skill.md.""" + (tmp_path / "SKILL.md").write_text("uppercase") + (tmp_path / "skill.md").write_text("lowercase") + result = _find_skill_md(tmp_path) + assert result.name == "SKILL.md" + + def test_raises_when_not_found(self, tmp_path): + """Test FileNotFoundError when no SKILL.md exists.""" + with pytest.raises(FileNotFoundError, match="no SKILL.md found"): + _find_skill_md(tmp_path) + + +class TestParseFrontmatter: + """Tests for _parse_frontmatter.""" + + def test_valid_frontmatter(self): + """Test parsing valid frontmatter.""" + content = "---\nname: test-skill\ndescription: A test\n---\n# Instructions\nDo things." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert frontmatter["description"] == "A test" + assert "# Instructions" in body + assert "Do things." in body + + def test_missing_opening_delimiter(self): + """Test error when opening --- is missing.""" + with pytest.raises(ValueError, match="must start with ---"): + _parse_frontmatter("name: test\n---\n") + + def test_missing_closing_delimiter(self): + """Test error when closing --- is missing.""" + with pytest.raises(ValueError, match="missing closing ---"): + _parse_frontmatter("---\nname: test\n") + + def test_empty_body(self): + """Test frontmatter with empty body.""" + content = "---\nname: test-skill\ndescription: test\n---\n" + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert body == "" + + def test_frontmatter_with_metadata(self): + """Test frontmatter with nested metadata.""" + content = "---\nname: test-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody here." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert isinstance(frontmatter["metadata"], dict) + assert frontmatter["metadata"]["author"] == "acme" + assert body == "Body here." + + def test_frontmatter_with_dashes_in_yaml_value(self): + """Test that --- inside a YAML value does not break parsing.""" + content = "---\nname: test-skill\ndescription: has --- inside\n---\nBody here." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert frontmatter["description"] == "has --- inside" + assert body == "Body here." + + +class TestValidateSkillName: + """Tests for _validate_skill_name (lenient validation).""" + + def test_valid_names(self): + """Test that valid names pass validation without warnings.""" + valid_names = ["a", "test", "my-skill", "skill-123", "a1b2c3"] + for name in valid_names: + _validate_skill_name(name) # Should not raise + + def test_empty_name(self): + """Test that empty name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + _validate_skill_name("") + + def test_too_long_name_warns(self, caplog): + """Test that names exceeding 64 chars warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("a" * 65) + assert "exceeds" in caplog.text + + def test_uppercase_warns(self, caplog): + """Test that uppercase characters warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("MySkill") + assert "lowercase alphanumeric" in caplog.text + + def test_starts_with_hyphen_warns(self, caplog): + """Test that names starting with hyphen warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("-skill") + assert "lowercase alphanumeric" in caplog.text + + def test_ends_with_hyphen_warns(self, caplog): + """Test that names ending with hyphen warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("skill-") + assert "lowercase alphanumeric" in caplog.text + + def test_consecutive_hyphens_warns(self, caplog): + """Test that consecutive hyphens warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("my--skill") + assert "consecutive hyphens" in caplog.text + + def test_special_characters_warns(self, caplog): + """Test that special characters warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("my_skill") + assert "lowercase alphanumeric" in caplog.text + + def test_directory_name_mismatch_warns(self, tmp_path, caplog): + """Test that skill name not matching directory name warns but does not raise.""" + skill_dir = tmp_path / "wrong-name" + skill_dir.mkdir() + with caplog.at_level(logging.WARNING): + _validate_skill_name("my-skill", skill_dir) + assert "does not match parent directory name" in caplog.text + + def test_directory_name_match(self, tmp_path): + """Test that matching directory name passes.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + _validate_skill_name("my-skill", skill_dir) # Should not raise or warn + + +class TestValidateSkillNameStrict: + """Tests for _validate_skill_name with strict=True.""" + + def test_strict_valid_name(self): + """Test that valid names pass strict validation.""" + _validate_skill_name("my-skill", strict=True) # Should not raise + + def test_strict_empty_name(self): + """Test that empty name raises in strict mode.""" + with pytest.raises(ValueError, match="cannot be empty"): + _validate_skill_name("", strict=True) + + def test_strict_too_long_name(self): + """Test that names exceeding 64 chars raise in strict mode.""" + with pytest.raises(ValueError, match="exceeds 64 character limit"): + _validate_skill_name("a" * 65, strict=True) + + def test_strict_uppercase_rejected(self): + """Test that uppercase characters raise in strict mode.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("MySkill", strict=True) + + def test_strict_starts_with_hyphen(self): + """Test that names starting with hyphen raise in strict mode.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("-skill", strict=True) + + def test_strict_consecutive_hyphens(self): + """Test that consecutive hyphens raise in strict mode.""" + with pytest.raises(ValueError, match="consecutive hyphens"): + _validate_skill_name("my--skill", strict=True) + + def test_strict_directory_mismatch(self, tmp_path): + """Test that directory name mismatch raises in strict mode.""" + skill_dir = tmp_path / "wrong-name" + skill_dir.mkdir() + with pytest.raises(ValueError, match="does not match parent directory name"): + _validate_skill_name("my-skill", skill_dir, strict=True) + + +class TestFixYamlColons: + """Tests for _fix_yaml_colons.""" + + def test_fixes_unquoted_colon_in_value(self): + """Test that an unquoted colon in a value gets quoted.""" + raw = "description: Use this skill when: the user asks about PDFs" + fixed = _fix_yaml_colons(raw) + assert fixed == 'description: "Use this skill when: the user asks about PDFs"' + + def test_leaves_already_double_quoted_value(self): + """Test that already double-quoted values are not re-quoted.""" + raw = 'description: "already: quoted"' + assert _fix_yaml_colons(raw) == raw + + def test_leaves_already_single_quoted_value(self): + """Test that already single-quoted values are not re-quoted.""" + raw = "description: 'already: quoted'" + assert _fix_yaml_colons(raw) == raw + + def test_leaves_value_without_colon(self): + """Test that values without colons are unchanged.""" + raw = "name: my-skill" + assert _fix_yaml_colons(raw) == raw + + def test_multiline_mixed(self): + """Test fixing only the lines that need it in a multi-line string.""" + raw = "name: my-skill\ndescription: Use when: needed\nversion: 1.0" + fixed = _fix_yaml_colons(raw) + assert fixed == 'name: my-skill\ndescription: "Use when: needed"\nversion: 1.0' + + def test_empty_string(self): + """Test that an empty string is returned unchanged.""" + assert _fix_yaml_colons("") == "" + + def test_preserves_indented_lines_without_colons(self): + """Test that indented lines without key-value patterns are preserved.""" + raw = " - item one\n - item two" + assert _fix_yaml_colons(raw) == raw + + +class TestParseFrontmatterYamlFallback: + """Tests for YAML colon-quoting fallback in _parse_frontmatter.""" + + def test_fallback_on_unquoted_colon(self): + """Test that frontmatter with unquoted colons in values is parsed via fallback.""" + content = "---\nname: my-skill\ndescription: Use when: the user asks\n---\nBody." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "my-skill" + assert "Use when" in frontmatter["description"] + assert body == "Body." + + def test_fallback_preserves_valid_yaml(self): + """Test that valid YAML is parsed normally without triggering fallback.""" + content = "---\nname: my-skill\ndescription: A simple description\n---\nBody." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "my-skill" + assert frontmatter["description"] == "A simple description" + + +def _make_skill_dir(parent: Path, name: str, description: str = "A test skill", body: str = "Instructions.") -> Path: + """Helper to create a skill directory with SKILL.md.""" + skill_dir = parent / name + skill_dir.mkdir(parents=True, exist_ok=True) + content = f"---\nname: {name}\ndescription: {description}\n---\n{body}\n" + (skill_dir / "SKILL.md").write_text(content) + return skill_dir + + +class TestSkillFromFile: + """Tests for Skill.from_file.""" + + def test_load_from_directory(self, tmp_path): + """Test loading a skill from a directory path.""" + skill_dir = _make_skill_dir(tmp_path, "my-skill", "My description", "# Hello\nWorld.") + skill = Skill.from_file(skill_dir) + + assert skill.name == "my-skill" + assert skill.description == "My description" + assert "# Hello" in skill.instructions + assert "World." in skill.instructions + assert skill.path == skill_dir.resolve() + + def test_load_from_skill_md_file(self, tmp_path): + """Test loading a skill by pointing directly to SKILL.md.""" + skill_dir = _make_skill_dir(tmp_path, "direct-skill") + skill = Skill.from_file(skill_dir / "SKILL.md") + + assert skill.name == "direct-skill" + + def test_load_with_allowed_tools(self, tmp_path): + """Test loading a skill with allowed-tools field as space-delimited string.""" + skill_dir = tmp_path / "tool-skill" + skill_dir.mkdir() + content = "---\nname: tool-skill\ndescription: test\nallowed-tools: read write execute\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.allowed_tools == ["read", "write", "execute"] + + def test_load_with_allowed_tools_yaml_list(self, tmp_path): + """Test loading a skill with allowed-tools as a YAML list.""" + skill_dir = tmp_path / "list-skill" + skill_dir.mkdir() + content = "---\nname: list-skill\ndescription: test\nallowed-tools:\n - read\n - write\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.allowed_tools == ["read", "write"] + + def test_load_with_metadata(self, tmp_path): + """Test loading a skill with nested metadata.""" + skill_dir = tmp_path / "meta-skill" + skill_dir.mkdir() + content = "---\nname: meta-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.metadata == {"author": "acme"} + + def test_load_with_license_and_compatibility(self, tmp_path): + """Test loading a skill with license and compatibility fields.""" + skill_dir = tmp_path / "licensed-skill" + skill_dir.mkdir() + content = "---\nname: licensed-skill\ndescription: test\nlicense: MIT\ncompatibility: v1\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.license == "MIT" + assert skill.compatibility == "v1" + + def test_load_missing_name(self, tmp_path): + """Test error when SKILL.md is missing name field.""" + skill_dir = tmp_path / "no-name" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\ndescription: test\n---\nBody.") + + with pytest.raises(ValueError, match="must have a 'name' field"): + Skill.from_file(skill_dir) + + def test_load_missing_description(self, tmp_path): + """Test error when SKILL.md is missing description field.""" + skill_dir = tmp_path / "no-desc" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\nBody.") + + with pytest.raises(ValueError, match="must have a 'description' field"): + Skill.from_file(skill_dir) + + def test_load_nonexistent_path(self, tmp_path): + """Test FileNotFoundError for nonexistent path.""" + with pytest.raises(FileNotFoundError): + Skill.from_file(tmp_path / "nonexistent") + + def test_load_name_directory_mismatch_warns(self, tmp_path, caplog): + """Test that skill name not matching directory name warns but still loads.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") + + with caplog.at_level(logging.WARNING): + skill = Skill.from_file(skill_dir) + + assert skill.name == "right-name" + assert "does not match parent directory name" in caplog.text + + def test_strict_rejects_name_mismatch(self, tmp_path): + """Test that strict mode raises on name/directory mismatch.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") + + with pytest.raises(ValueError, match="does not match parent directory name"): + Skill.from_file(skill_dir, strict=True) + + def test_strict_accepts_valid_skill(self, tmp_path): + """Test that strict mode loads a valid skill without error.""" + _make_skill_dir(tmp_path, "valid-skill") + skill = Skill.from_file(tmp_path / "valid-skill", strict=True) + assert skill.name == "valid-skill" + + +class TestSkillFromDirectory: + """Tests for Skill.from_directory.""" + + def test_load_multiple_skills(self, tmp_path): + """Test loading multiple skills from a parent directory.""" + _make_skill_dir(tmp_path, "skill-a", "Skill A") + _make_skill_dir(tmp_path, "skill-b", "Skill B") + + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 2 + names = {s.name for s in skills} + assert names == {"skill-a", "skill-b"} + + def test_skips_directories_without_skill_md(self, tmp_path): + """Test that directories without SKILL.md are silently skipped.""" + _make_skill_dir(tmp_path, "valid-skill") + (tmp_path / "no-skill-here").mkdir() + + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 1 + assert skills[0].name == "valid-skill" + + def test_skips_files_in_parent(self, tmp_path): + """Test that files in the parent directory are ignored.""" + _make_skill_dir(tmp_path, "real-skill") + (tmp_path / "readme.txt").write_text("not a skill") + + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 1 + + def test_empty_directory(self, tmp_path): + """Test loading from an empty directory.""" + skills = Skill.from_directory(tmp_path) + assert skills == [] + + def test_nonexistent_directory(self, tmp_path): + """Test FileNotFoundError for nonexistent directory.""" + with pytest.raises(FileNotFoundError): + Skill.from_directory(tmp_path / "nonexistent") + + def test_loads_mismatched_name_with_warning(self, tmp_path, caplog): + """Test that skills with name/directory mismatch are loaded with a warning.""" + _make_skill_dir(tmp_path, "good-skill") + + bad_dir = tmp_path / "bad-dir" + bad_dir.mkdir() + (bad_dir / "SKILL.md").write_text("---\nname: wrong-name\ndescription: test\n---\nBody.") + + with caplog.at_level(logging.WARNING): + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 2 + names = {s.name for s in skills} + assert names == {"good-skill", "wrong-name"} + assert "does not match parent directory name" in caplog.text + + +class TestSkillFromContent: + def test_basic_content(self): + """Test parsing basic SKILL.md content.""" + content = "---\nname: my-skill\ndescription: A useful skill\n---\n# Instructions\nDo the thing." + skill = Skill.from_content(content) + + assert skill.name == "my-skill" + assert skill.description == "A useful skill" + assert "Do the thing." in skill.instructions + assert skill.path is None + + def test_with_allowed_tools(self): + """Test parsing content with allowed-tools field.""" + content = "---\nname: my-skill\ndescription: A skill\nallowed-tools: Bash Read\n---\nInstructions." + skill = Skill.from_content(content) + + assert skill.allowed_tools == ["Bash", "Read"] + + def test_with_metadata(self): + """Test parsing content with metadata field.""" + content = "---\nname: my-skill\ndescription: A skill\nmetadata:\n key: value\n---\nInstructions." + skill = Skill.from_content(content) + + assert skill.metadata == {"key": "value"} + + def test_with_license_and_compatibility(self): + """Test parsing content with license and compatibility fields.""" + content = ( + "---\nname: my-skill\ndescription: A skill\n" + "license: Apache-2.0\ncompatibility: Requires docker\n---\nInstructions." + ) + skill = Skill.from_content(content) + + assert skill.license == "Apache-2.0" + assert skill.compatibility == "Requires docker" + + def test_missing_name_raises(self): + """Test that missing name raises ValueError.""" + content = "---\ndescription: A skill\n---\nInstructions." + with pytest.raises(ValueError, match="name"): + Skill.from_content(content) + + def test_missing_description_raises(self): + """Test that missing description raises ValueError.""" + content = "---\nname: my-skill\n---\nInstructions." + with pytest.raises(ValueError, match="description"): + Skill.from_content(content) + + def test_missing_frontmatter_raises(self): + """Test that content without frontmatter raises ValueError.""" + content = "# Just markdown\nNo frontmatter here." + with pytest.raises(ValueError, match="frontmatter"): + Skill.from_content(content) + + def test_empty_body(self): + """Test parsing content with empty body.""" + content = "---\nname: my-skill\ndescription: A skill\n---\n" + skill = Skill.from_content(content) + + assert skill.name == "my-skill" + assert skill.instructions == "" + + def test_strict_mode(self): + """Test Skill.from_content with strict=True raises on validation issues.""" + content = "---\nname: BAD_NAME\ndescription: Bad\n---\nBody." + with pytest.raises(ValueError): + Skill.from_content(content, strict=True) + + +class TestSkillFromUrl: + """Tests for Skill.from_url.""" + + _SKILL_MODULE = "strands.vended_plugins.skills.skill" + _SAMPLE_CONTENT = "---\nname: my-skill\ndescription: A remote skill\n---\nRemote instructions.\n" + + def _mock_urlopen(self, content): + """Create a mock urlopen context manager returning the given content.""" + from unittest.mock import MagicMock + + mock_response = MagicMock() + mock_response.read.return_value = content.encode("utf-8") + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + return mock_response + + def test_from_url_returns_skill(self): + """Test loading a skill from a URL returns a single Skill.""" + from unittest.mock import patch + + mock_response = self._mock_urlopen(self._SAMPLE_CONTENT) + with patch(f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=mock_response): + skill = Skill.from_url("https://raw.githubusercontent.com/org/repo/main/SKILL.md") + + assert isinstance(skill, Skill) + assert skill.name == "my-skill" + assert skill.description == "A remote skill" + assert "Remote instructions." in skill.instructions + assert skill.path is None + + def test_from_url_invalid_url_raises(self): + """Test that a non-HTTPS URL raises ValueError.""" + with pytest.raises(ValueError, match="not a valid HTTPS URL"): + Skill.from_url("./local-path") + + def test_from_url_http_rejected(self): + """Test that http:// URLs are rejected.""" + with pytest.raises(ValueError, match="not a valid HTTPS URL"): + Skill.from_url("http://example.com/SKILL.md") + + def test_from_url_http_error_raises(self): + """Test that HTTP errors propagate as RuntimeError.""" + import urllib.error + from unittest.mock import patch + + with patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", + side_effect=urllib.error.HTTPError( + url="https://example.com", code=404, msg="Not Found", hdrs=None, fp=None + ), + ): + with pytest.raises(RuntimeError, match="HTTP 404"): + Skill.from_url("https://example.com/SKILL.md") + + def test_from_url_network_error_raises(self): + """Test that network errors propagate as RuntimeError.""" + import urllib.error + from unittest.mock import patch + + with patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", + side_effect=urllib.error.URLError("Connection refused"), + ): + with pytest.raises(RuntimeError, match="failed to fetch"): + Skill.from_url("https://example.com/SKILL.md") + + def test_from_url_strict_mode(self): + """Test that strict mode is forwarded to from_content.""" + from unittest.mock import patch + + bad_content = "---\nname: BAD_NAME\ndescription: Bad\n---\nBody." + + with patch(f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=self._mock_urlopen(bad_content)): + with pytest.raises(ValueError): + Skill.from_url("https://example.com/SKILL.md", strict=True) + + def test_from_url_invalid_content_raises(self): + """Test that non-SKILL.md content (e.g. HTML page) raises ValueError.""" + from unittest.mock import patch + + html_content = "Not a SKILL.md" + + with patch(f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=self._mock_urlopen(html_content)): + with pytest.raises(ValueError, match="frontmatter"): + Skill.from_url("https://example.com/SKILL.md") + + +class TestSkillClassmethods: + """Tests for Skill classmethod existence.""" + + def test_skill_classmethods_exist(self): + """Test that Skill has from_file, from_content, from_directory, and from_url classmethods.""" + assert callable(getattr(Skill, "from_file", None)) + assert callable(getattr(Skill, "from_content", None)) + assert callable(getattr(Skill, "from_directory", None)) + assert callable(getattr(Skill, "from_url", None)) diff --git a/tests/strands/experimental/steering/handlers/llm/__init__.py b/tests/strands/vended_plugins/steering/__init__.py similarity index 100% rename from tests/strands/experimental/steering/handlers/llm/__init__.py rename to tests/strands/vended_plugins/steering/__init__.py diff --git a/tests/strands/vended_plugins/steering/context_providers/__init__.py b/tests/strands/vended_plugins/steering/context_providers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/vended_plugins/steering/context_providers/test_ledger_provider.py b/tests/strands/vended_plugins/steering/context_providers/test_ledger_provider.py new file mode 100644 index 000000000..dda718f31 --- /dev/null +++ b/tests/strands/vended_plugins/steering/context_providers/test_ledger_provider.py @@ -0,0 +1,324 @@ +"""Unit tests for ledger context providers.""" + +from unittest.mock import Mock, patch + +from strands.hooks.events import AfterToolCallEvent, BeforeToolCallEvent +from strands.vended_plugins.steering.context_providers.ledger_provider import ( + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, +) +from strands.vended_plugins.steering.core.context import SteeringContext + + +def test_context_providers_method(): + """Test context_providers method returns correct callbacks.""" + provider = LedgerProvider() + + callbacks = provider.context_providers() + + assert len(callbacks) == 2 + assert isinstance(callbacks[0], LedgerBeforeToolCall) + assert isinstance(callbacks[1], LedgerAfterToolCall) + + +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") +def test_ledger_before_tool_call_new_ledger(mock_datetime): + """Test LedgerBeforeToolCall with new ledger.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + tool_use = {"name": "test_tool", "input": {"param": "value"}} + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = tool_use + + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger is not None + assert "session_start" in ledger + assert "tool_calls" in ledger + assert len(ledger["tool_calls"]) == 1 + + tool_call = ledger["tool_calls"][0] + assert tool_call["tool_name"] == "test_tool" + assert tool_call["tool_args"] == {"param": "value"} + assert tool_call["status"] == "pending" + + +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") +def test_ledger_before_tool_call_existing_ledger(mock_datetime): + """Test LedgerBeforeToolCall with existing ledger.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + # Set up existing ledger + existing_ledger = { + "session_start": "2024-01-01T10:00:00", + "tool_calls": [{"name": "previous_tool"}], + "conversation_history": [], + "session_metadata": {}, + } + steering_context.data.set("ledger", existing_ledger) + + tool_use = {"name": "new_tool", "input": {"param": "value"}} + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = tool_use + + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert len(ledger["tool_calls"]) == 2 + assert ledger["tool_calls"][0]["name"] == "previous_tool" + assert ledger["tool_calls"][1]["tool_name"] == "new_tool" + + +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") +def test_ledger_after_tool_call_success(mock_datetime): + """Test LedgerAfterToolCall with successful completion.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:05:00" + + callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Set up existing ledger with pending call + existing_ledger = { + "tool_calls": [ + { + "tool_use_id": "test-id", + "tool_name": "test_tool", + "status": "pending", + "timestamp": "2024-01-01T12:00:00", + } + ] + } + steering_context.data.set("ledger", existing_ledger) + + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "test-id"} + event.result = {"status": "success", "content": ["success_result"]} + event.exception = None + + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + tool_call = ledger["tool_calls"][0] + assert tool_call["status"] == "success" + assert tool_call["result"] == ["success_result"] + assert tool_call["error"] is None + assert tool_call["completion_timestamp"] == "2024-01-01T12:05:00" + + +def test_ledger_after_tool_call_no_calls(): + """Test LedgerAfterToolCall when no tool calls exist.""" + callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Set up ledger with no tool calls + existing_ledger = {"tool_calls": []} + steering_context.data.set("ledger", existing_ledger) + + event = Mock(spec=AfterToolCallEvent) + event.result = {"status": "success", "content": ["test"]} + event.exception = None + + # Should not crash when no tool calls exist + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"] == [] + + +def test_session_start_persistence(): + """Test that session_start is set during initialization and persists.""" + with patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") as mock_datetime: + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T10:00:00" + + callback = LedgerBeforeToolCall() + + assert callback.session_start == "2024-01-01T10:00:00" + + +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_all_pending(mock_datetime): + """Test multiple tool calls added as pending before any execute.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + # Add three tool calls in sequence (simulating parallel proposal) + for i, tool_name in enumerate(["tool_a", "tool_b", "tool_c"]): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": tool_name, "input": {}} + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert len(ledger["tool_calls"]) == 3 + assert all(call["status"] == "pending" for call in ledger["tool_calls"]) + assert [call["tool_name"] for call in ledger["tool_calls"]] == ["tool_a", "tool_b", "tool_c"] + + +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_complete_by_id(mock_datetime): + """Test tool calls complete in any order by matching toolUseId.""" + # Need timestamps for: session_start + 3 tool calls + 1 completion + mock_datetime.now.return_value.isoformat.side_effect = [ + "2024-01-01T11:00:00", # session_start + "2024-01-01T12:00:00", # tool_a + "2024-01-01T12:01:00", # tool_b + "2024-01-01T12:02:00", # tool_c + "2024-01-01T12:03:00", # completion + ] + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add three pending tool calls + for i, tool_name in enumerate(["tool_a", "tool_b", "tool_c"]): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": tool_name, "input": {}} + before_callback(event, steering_context) + + # Complete middle tool first (out of order) + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_1"} + event.result = {"status": "success", "content": ["result_b"]} + event.exception = None + after_callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["status"] == "pending" + assert ledger["tool_calls"][1]["status"] == "success" + assert ledger["tool_calls"][1]["result"] == ["result_b"] + assert ledger["tool_calls"][2]["status"] == "pending" + + +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_complete_all_out_of_order(mock_datetime): + """Test all parallel tool calls complete in reverse order.""" + # Need timestamps for: session_start + 3 tool calls + 3 completions + mock_datetime.now.return_value.isoformat.side_effect = [ + "2024-01-01T11:00:00", # session_start + "2024-01-01T12:00:00", # tool_0 + "2024-01-01T12:01:00", # tool_1 + "2024-01-01T12:02:00", # tool_2 + "2024-01-01T12:03:00", # completion tool_2 + "2024-01-01T12:04:00", # completion tool_1 + "2024-01-01T12:05:00", # completion tool_0 + ] + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add three pending tool calls + for i in range(3): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": f"tool_{i}", "input": {}} + before_callback(event, steering_context) + + # Complete in reverse order: 2, 1, 0 + for i in [2, 1, 0]: + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}"} + event.result = {"status": "success", "content": [f"result_{i}"]} + event.exception = None + after_callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert all(call["status"] == "success" for call in ledger["tool_calls"]) + assert ledger["tool_calls"][0]["result"] == ["result_0"] + assert ledger["tool_calls"][1]["result"] == ["result_1"] + assert ledger["tool_calls"][2]["result"] == ["result_2"] + + +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_with_failure(mock_datetime): + """Test parallel tool calls where one fails.""" + # Need timestamps for: session_start + 2 tool calls + 2 completions + mock_datetime.now.return_value.isoformat.side_effect = [ + "2024-01-01T11:00:00", # session_start + "2024-01-01T12:00:00", # tool_0 + "2024-01-01T12:01:00", # tool_1 + "2024-01-01T12:02:00", # completion tool_0 + "2024-01-01T12:03:00", # completion tool_1 + ] + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add two pending tool calls + for i in range(2): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": f"tool_{i}", "input": {}} + before_callback(event, steering_context) + + # First succeeds + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_0"} + event.result = {"status": "success", "content": ["result_0"]} + event.exception = None + after_callback(event, steering_context) + + # Second fails + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_1"} + event.result = {"status": "error", "content": []} + event.exception = ValueError("test error") + after_callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["status"] == "success" + assert ledger["tool_calls"][0]["error"] is None + assert ledger["tool_calls"][1]["status"] == "error" + assert ledger["tool_calls"][1]["error"] == "test error" + + +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") +def test_after_tool_call_no_matching_id(mock_datetime): + """Test AfterToolCallEvent when tool_use_id doesn't match any pending call.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add a pending tool call + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": "id_1", "name": "tool_1", "input": {}} + before_callback(event, steering_context) + + # Try to complete a different tool_use_id that doesn't exist + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_999"} + event.result = {"status": "success", "content": ["result"]} + event.exception = None + after_callback(event, steering_context) + + # Original tool should still be pending (no match found) + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["status"] == "pending" + assert "completion_timestamp" not in ledger["tool_calls"][0] + + +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") +def test_tool_use_id_stored_in_ledger(mock_datetime): + """Test that toolUseId is stored in ledger entries.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": "test-id-123", "name": "test_tool", "input": {}} + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["tool_use_id"] == "test-id-123" diff --git a/tests/strands/vended_plugins/steering/core/__init__.py b/tests/strands/vended_plugins/steering/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/vended_plugins/steering/core/test_handler.py b/tests/strands/vended_plugins/steering/core/test_handler.py new file mode 100644 index 000000000..dc3b0dacc --- /dev/null +++ b/tests/strands/vended_plugins/steering/core/test_handler.py @@ -0,0 +1,544 @@ +"""Unit tests for steering handler base class.""" + +import inspect +from unittest.mock import AsyncMock, Mock + +import pytest + +from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent +from strands.hooks.registry import HookRegistry +from strands.plugins import Plugin +from strands.vended_plugins.steering.core.action import Guide, Interrupt, Proceed +from strands.vended_plugins.steering.core.context import ( + SteeringContext, + SteeringContextCallback, + SteeringContextProvider, +) +from strands.vended_plugins.steering.core.handler import SteeringHandler + + +class TestSteeringHandler(SteeringHandler): + """Test implementation of SteeringHandler.""" + + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + return Proceed(reason="Test proceed") + + +def test_steering_handler_initialization(): + """Test SteeringHandler initialization.""" + handler = TestSteeringHandler() + assert handler is not None + + +def test_steering_handler_has_name_attribute(): + """Test SteeringHandler has name attribute for Plugin.""" + handler = TestSteeringHandler() + assert hasattr(handler, "name") + assert handler.name == "steering" + + +def test_steering_handler_is_plugin(): + """Test SteeringHandler implements Plugin.""" + handler = TestSteeringHandler() + assert isinstance(handler, Plugin) + + +def test_init_agent(): + """Test init_agent with plugin registry registers hooks on agent.""" + from strands.plugins.registry import _PluginRegistry + + handler = TestSteeringHandler() + agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) + + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) + + # Verify hooks were registered (tool and model steering hooks via @hook decorator) + assert agent.add_hook.call_count >= 2 + # Check that the decorated hook methods were registered + assert BeforeToolCallEvent in agent.hooks._registered_callbacks + assert AfterModelCallEvent in agent.hooks._registered_callbacks + + +def test_steering_context_initialization(): + """Test steering context is initialized.""" + handler = TestSteeringHandler() + + assert handler.steering_context is not None + assert isinstance(handler.steering_context, SteeringContext) + + +def test_steering_context_persistence(): + """Test steering context persists across calls.""" + handler = TestSteeringHandler() + + handler.steering_context.data.set("test", "value") + assert handler.steering_context.data.get("test") == "value" + + +def test_steering_context_access(): + """Test steering context can be accessed and modified.""" + handler = TestSteeringHandler() + + handler.steering_context.data.set("key", "value") + assert handler.steering_context.data.get("key") == "value" + + +@pytest.mark.asyncio +async def test_proceed_action_flow(): + """Test complete flow with Proceed action.""" + + class ProceedHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + return Proceed(reason="Test proceed") + + handler = ProceedHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + await handler.provide_tool_steering_guidance(event) + + # Should not modify event for Proceed + assert not event.cancel_tool + + +@pytest.mark.asyncio +async def test_guide_action_flow(): + """Test complete flow with Guide action.""" + + class GuideHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + return Guide(reason="Test guidance") + + handler = GuideHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + await handler.provide_tool_steering_guidance(event) + + # Should set cancel_tool with guidance message + expected_message = "Tool call cancelled. Test guidance You MUST follow this guidance immediately." + assert event.cancel_tool == expected_message + + +@pytest.mark.asyncio +async def test_interrupt_action_approved_flow(): + """Test complete flow with Interrupt action when approved.""" + + class InterruptHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + return Interrupt(reason="Need approval") + + handler = InterruptHandler() + tool_use = {"name": "test_tool"} + event = Mock() + event.tool_use = tool_use + event.interrupt = Mock(return_value=True) # Approved + + await handler.provide_tool_steering_guidance(event) + + event.interrupt.assert_called_once() + + +@pytest.mark.asyncio +async def test_interrupt_action_denied_flow(): + """Test complete flow with Interrupt action when denied.""" + + class InterruptHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + return Interrupt(reason="Need approval") + + handler = InterruptHandler() + tool_use = {"name": "test_tool"} + event = Mock() + event.tool_use = tool_use + event.interrupt = Mock(return_value=False) # Denied + + await handler.provide_tool_steering_guidance(event) + + event.interrupt.assert_called_once() + assert event.cancel_tool.startswith("Manual approval denied:") + + +@pytest.mark.asyncio +async def test_unknown_action_flow(): + """Test complete flow with unknown action type raises error.""" + + class UnknownActionHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + return Mock() # Not a valid SteeringAction + + handler = UnknownActionHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + with pytest.raises(ValueError, match="Unknown steering action type"): + await handler.provide_tool_steering_guidance(event) + + +def test_init_agent_override(): + """Test that init_agent can be overridden.""" + + class CustomHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + return Proceed(reason="Custom") + + def init_agent(self, agent): + # Custom hook registration - don't call parent + pass + + handler = CustomHandler() + agent = Mock() + + handler.init_agent(agent) + + # Should not register any hooks + assert agent.add_hook.call_count == 0 + + +# Integration tests with context providers +class MockContextCallback(SteeringContextCallback[BeforeToolCallEvent]): + """Mock context callback for testing.""" + + def __call__(self, event: BeforeToolCallEvent, steering_context, **kwargs) -> None: + steering_context.data.set("test_key", "test_value") + + +class MockContextProvider(SteeringContextProvider): + """Mock context provider for testing.""" + + def __init__(self, callbacks): + self.callbacks = callbacks + + def context_providers(self): + return self.callbacks + + +class TestSteeringHandlerWithProvider(SteeringHandler): + """Test implementation with context callbacks.""" + + def __init__(self, context_callbacks=None): + providers = [MockContextProvider(context_callbacks)] if context_callbacks else None + super().__init__(context_providers=providers) + + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + return Proceed(reason="Test proceed") + + +def test_handler_registers_context_provider_hooks(): + """Test that handler registers hooks from context callbacks via registry.""" + from strands.plugins.registry import _PluginRegistry + + mock_callback = MockContextCallback() + handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) + agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) + + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) + + # Should register hooks for context callback (via init_agent) and steering guidance (via @hook) + # init_agent registers context callbacks manually, @hook decorated methods are auto-registered + assert agent.add_hook.call_count >= 2 + + # Check that BeforeToolCallEvent was registered (both context callback and steering guidance) + assert BeforeToolCallEvent in agent.hooks._registered_callbacks + + +@pytest.mark.asyncio +async def test_context_callbacks_receive_steering_context(): + """Test that context callbacks receive the handler's steering context.""" + from strands.plugins.registry import _PluginRegistry + + mock_callback = MockContextCallback() + handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) + agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) + + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) + + # Get the registered callbacks for BeforeToolCallEvent + callbacks = agent.hooks._registered_callbacks.get(BeforeToolCallEvent, []) + assert len(callbacks) > 0 + + # The context callback is wrapped in a lambda, so we just call all callbacks + # and check if the steering context was updated + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"name": "test_tool", "input": {}} + + # Call all callbacks, handling both sync and async + for cb in callbacks: + try: + result = await cb(event) + if inspect.iscoroutine(result): + await result + except Exception: + pass # Some callbacks might be async or have other requirements + + # Verify the steering context was updated by at least one callback + assert handler.steering_context.data.get("test_key") == "test_value" + + +def test_multiple_context_callbacks_registered(): + """Test that multiple context callbacks are registered via registry.""" + from strands.plugins.registry import _PluginRegistry + + callback1 = MockContextCallback() + callback2 = MockContextCallback() + + handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) + agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) + + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) + + # Should register: + # - 2 callbacks for context providers (via init_agent manual registration) + # - 2 for steering guidance (via @hook decorator auto-registration) + expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) + assert agent.add_hook.call_count >= expected_calls + + +def test_handler_initialization_with_callbacks(): + """Test handler initialization stores context callbacks.""" + callback1 = MockContextCallback() + callback2 = MockContextCallback() + + handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) + + # Should have stored the callbacks + assert len(handler._context_callbacks) == 2 + assert callback1 in handler._context_callbacks + assert callback2 in handler._context_callbacks + + +# Model steering tests +@pytest.mark.asyncio +async def test_model_steering_proceed_action_flow(): + """Test model steering with Proceed action.""" + + class ModelProceedHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Proceed(reason="Model response accepted") + + handler = ModelProceedHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + await handler.provide_model_steering_guidance(event) + + # Should not set retry for Proceed + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_steering_guide_action_flow(): + """Test model steering with Guide action sets retry and adds message.""" + + class ModelGuideHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Guide(reason="Please improve your response") + + handler = ModelGuideHandler() + agent = AsyncMock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + await handler.provide_model_steering_guidance(event) + + # Should set retry flag + assert event.retry is True + # Should add guidance message to conversation + agent._append_messages.assert_called_once() + call_args = agent._append_messages.call_args[0][0] + assert call_args["role"] == "user" + assert "Please improve your response" in call_args["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_model_steering_skips_when_no_stop_response(): + """Test model steering skips when stop_response is None.""" + + class ModelProceedHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.steer_called = False + + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + self.steer_called = True + return Proceed(reason="Should not be called") + + handler = ModelProceedHandler() + event = Mock(spec=AfterModelCallEvent) + event.stop_response = None + + await handler.provide_model_steering_guidance(event) + + # steer_after_model should not have been called + assert handler.steer_called is False + + +@pytest.mark.asyncio +async def test_model_steering_unknown_action_raises_error(): + """Test model steering with unknown action type raises error.""" + + class UnknownModelActionHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Mock() # Not a valid ModelSteeringAction + + handler = UnknownModelActionHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + + with pytest.raises(ValueError, match="Unknown steering action type for model response"): + await handler.provide_model_steering_guidance(event) + + +@pytest.mark.asyncio +async def test_model_steering_interrupt_raises_error(): + """Test model steering with Interrupt action raises error (not supported for model steering).""" + + class InterruptModelHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Interrupt(reason="Should not be allowed") + + handler = InterruptModelHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + + with pytest.raises(ValueError, match="Unknown steering action type for model response"): + await handler.provide_model_steering_guidance(event) + + +@pytest.mark.asyncio +async def test_model_steering_exception_handling(): + """Test model steering handles exceptions gracefully.""" + + class ExceptionModelHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + raise RuntimeError("Test exception") + + handler = ExceptionModelHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + # Should not raise, just return early + await handler.provide_model_steering_guidance(event) + + # retry should not be set since exception occurred + assert event.retry is False + + +@pytest.mark.asyncio +async def test_tool_steering_exception_handling(): + """Test tool steering handles exceptions gracefully.""" + + class ExceptionToolHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + raise RuntimeError("Test exception") + + handler = ExceptionToolHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + # Should not raise, just return early + await handler.provide_tool_steering_guidance(event) + + # cancel_tool should not be set since exception occurred + assert not event.cancel_tool + + +# Default implementation tests +@pytest.mark.asyncio +async def test_default_steer_before_tool_returns_proceed(): + """Test default steer_before_tool returns Proceed.""" + handler = TestSteeringHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + + # Call the parent's default implementation + result = await SteeringHandler.steer_before_tool(handler, agent=agent, tool_use=tool_use) + + assert isinstance(result, Proceed) + assert "Default implementation" in result.reason + + +@pytest.mark.asyncio +async def test_default_steer_after_model_returns_proceed(): + """Test default steer_after_model returns Proceed.""" + handler = TestSteeringHandler() + agent = Mock() + message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_reason = "end_turn" + + # Call the parent's default implementation + result = await SteeringHandler.steer_after_model(handler, agent=agent, message=message, stop_reason=stop_reason) + + assert isinstance(result, Proceed) + assert "Default implementation" in result.reason + + +def test_init_agent_registers_model_steering(): + """Test that model steering hook is registered via plugin registry.""" + from strands.plugins.registry import _PluginRegistry + + handler = TestSteeringHandler() + agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) + + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) + + # Verify model steering hook was registered via @hook decorator + assert AfterModelCallEvent in agent.hooks._registered_callbacks + callbacks = agent.hooks._registered_callbacks[AfterModelCallEvent] + assert len(callbacks) == 1 diff --git a/tests/strands/vended_plugins/steering/handlers/__init__.py b/tests/strands/vended_plugins/steering/handlers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/vended_plugins/steering/handlers/llm/__init__.py b/tests/strands/vended_plugins/steering/handlers/llm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py b/tests/strands/vended_plugins/steering/handlers/llm/test_llm_handler.py similarity index 89% rename from tests/strands/experimental/steering/handlers/llm/test_llm_handler.py rename to tests/strands/vended_plugins/steering/handlers/llm/test_llm_handler.py index f780088b5..776124d25 100644 --- a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py +++ b/tests/strands/vended_plugins/steering/handlers/llm/test_llm_handler.py @@ -4,9 +4,9 @@ import pytest -from strands.experimental.steering.core.action import Guide, Interrupt, Proceed -from strands.experimental.steering.handlers.llm.llm_handler import LLMSteeringHandler, _LLMSteering -from strands.experimental.steering.handlers.llm.mappers import DefaultPromptMapper +from strands.vended_plugins.steering.core.action import Guide, Interrupt, Proceed +from strands.vended_plugins.steering.handlers.llm.llm_handler import LLMSteeringHandler, _LLMSteering +from strands.vended_plugins.steering.handlers.llm.mappers import DefaultPromptMapper def test_llm_steering_handler_initialization(): @@ -59,7 +59,7 @@ async def test_steer_proceed_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Proceed) assert result.reason == "Tool call is safe" @@ -82,7 +82,7 @@ async def test_steer_guide_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Guide) assert result.reason == "Consider security implications" @@ -105,7 +105,7 @@ async def test_steer_interrupt_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Interrupt) assert result.reason == "Human approval required" @@ -133,7 +133,7 @@ async def test_steer_unknown_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Proceed) assert "Unknown LLM decision, defaulting to proceed" in result.reason @@ -158,7 +158,7 @@ async def test_steer_uses_custom_model(mock_agent_class): agent.model = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - await handler.steer(agent, tool_use) + await handler.steer_before_tool(agent=agent, tool_use=tool_use) mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=custom_model, callback_handler=None) @@ -181,7 +181,7 @@ async def test_steer_uses_agent_model_when_no_custom_model(mock_agent_class): agent.model = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - await handler.steer(agent, tool_use) + await handler.steer_before_tool(agent=agent, tool_use=tool_use) mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=agent.model, callback_handler=None) diff --git a/tests/strands/experimental/steering/handlers/llm/test_mappers.py b/tests/strands/vended_plugins/steering/handlers/llm/test_mappers.py similarity index 95% rename from tests/strands/experimental/steering/handlers/llm/test_mappers.py rename to tests/strands/vended_plugins/steering/handlers/llm/test_mappers.py index 511671d3a..3f87f030a 100644 --- a/tests/strands/experimental/steering/handlers/llm/test_mappers.py +++ b/tests/strands/vended_plugins/steering/handlers/llm/test_mappers.py @@ -1,7 +1,7 @@ """Unit tests for LLM steering prompt mappers.""" -from strands.experimental.steering.core.context import SteeringContext -from strands.experimental.steering.handlers.llm.mappers import _STEERING_PROMPT_TEMPLATE, DefaultPromptMapper +from strands.vended_plugins.steering.core.context import SteeringContext +from strands.vended_plugins.steering.handlers.llm.mappers import _STEERING_PROMPT_TEMPLATE, DefaultPromptMapper def test_create_steering_prompt_with_tool_use(): diff --git a/tests_integ/a2a/__init__.py b/tests_integ/a2a/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/a2a/a2a_server.py b/tests_integ/a2a/a2a_server.py new file mode 100644 index 000000000..047edc3ba --- /dev/null +++ b/tests_integ/a2a/a2a_server.py @@ -0,0 +1,15 @@ +from strands import Agent +from strands.multiagent.a2a import A2AServer + +# Create an agent and serve it over A2A +agent = Agent( + name="Test agent", + description="Test description here", + callback_handler=None, +) +a2a_server = A2AServer( + agent=agent, + host="localhost", + port=9000, +) +a2a_server.serve() diff --git a/tests_integ/a2a/test_multiagent_a2a.py b/tests_integ/a2a/test_multiagent_a2a.py new file mode 100644 index 000000000..8b0186bc5 --- /dev/null +++ b/tests_integ/a2a/test_multiagent_a2a.py @@ -0,0 +1,104 @@ +import os +import subprocess +import time + +import httpx +import pytest +from a2a.client import ClientConfig, ClientFactory + +from strands import Agent +from strands.agent.a2a_agent import A2AAgent +from strands.multiagent.graph import GraphBuilder, Status + + +@pytest.fixture +def a2a_server(): + """Start A2A server as subprocess fixture.""" + server_path = os.path.join(os.path.dirname(__file__), "a2a_server.py") + process = subprocess.Popen(["python", server_path]) + time.sleep(5) # Wait for A2A server to start + + yield "http://localhost:9000" + + # Cleanup + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + + +def test_a2a_agent_invoke_sync(a2a_server): + """Test synchronous invocation via __call__.""" + a2a_agent = A2AAgent(endpoint=a2a_server) + result = a2a_agent("Hello there!") + assert result.stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_a2a_agent_invoke_async(a2a_server): + """Test async invocation.""" + a2a_agent = A2AAgent(endpoint=a2a_server) + result = await a2a_agent.invoke_async("Hello there!") + assert result.stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_a2a_agent_stream_async(a2a_server): + """Test async streaming.""" + a2a_agent = A2AAgent(endpoint=a2a_server) + + events = [] + async for event in a2a_agent.stream_async("Hello there!"): + events.append(event) + + # Should have at least one A2A stream event and one final result event + assert len(events) >= 2 + assert events[0]["type"] == "a2a_stream" + assert "result" in events[-1] + assert events[-1]["result"].stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_a2a_agent_with_non_streaming_client_config(a2a_server): + """Test with streaming=False client configuration (non-default).""" + httpx_client = httpx.AsyncClient(timeout=300) + config = ClientConfig(httpx_client=httpx_client, streaming=False) + factory = ClientFactory(config) + + try: + a2a_agent = A2AAgent(endpoint=a2a_server, a2a_client_factory=factory) + result = await a2a_agent.invoke_async("Hello there!") + assert result.stop_reason == "end_turn" + finally: + await httpx_client.aclose() + + +@pytest.mark.asyncio +async def test_graph_with_a2a_agent_and_regular_agent(a2a_server): + """Test Graph execution with both A2AAgent and regular Agent nodes.""" + # Create A2AAgent pointing to the test server + a2a_agent = A2AAgent(endpoint=a2a_server, name="remote_agent") + + # Create a regular Agent + regular_agent = Agent( + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a summarizer. Summarize the input briefly.", + name="summarizer", + ) + + # Build graph with both agent types + builder = GraphBuilder() + builder.add_node(a2a_agent, "remote") + builder.add_node(regular_agent, "summarizer") + builder.add_edge("remote", "summarizer") + builder.set_entry_point("remote") + graph = builder.build() + + # Execute the graph + result = await graph.invoke_async("Say hello in one sentence") + + assert result.status == Status.COMPLETED + assert result.completed_nodes == 2 + assert "remote" in result.results + assert "summarizer" in result.results diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py index 61cf78723..243db46ac 100644 --- a/tests_integ/bidi/test_bidirectional_agent.py +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -55,11 +55,18 @@ def calculator(operation: str, x: float, y: float) -> float: PROVIDER_CONFIGS = { "nova_sonic": { "model_class": BidiNovaSonicModel, - "model_kwargs": {"region": "us-east-1"}, + "model_kwargs": {"region": "us-east-1"}, # Uses v2 by default "silence_duration": 2.5, # Nova Sonic needs 2+ seconds of silence "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], "skip_reason": "AWS credentials not available", }, + "nova_sonic_v1": { + "model_class": BidiNovaSonicModel, + "model_kwargs": {"model_id": "amazon.nova-sonic-v1:0", "region": "us-east-1"}, + "silence_duration": 2.5, # Nova Sonic v1 needs 2+ seconds of silence + "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + "skip_reason": "AWS credentials not available", + }, "openai": { "model_class": BidiOpenAIRealtimeModel, "model_kwargs": { diff --git a/tests_integ/bidi/tools/test_direct.py b/tests_integ/bidi/tools/test_direct.py index 30320e786..1694d64b6 100644 --- a/tests_integ/bidi/tools/test_direct.py +++ b/tests_integ/bidi/tools/test_direct.py @@ -28,15 +28,14 @@ def test_bidi_agent_tool_direct_call(agent): "toolUseId": unittest.mock.ANY, } assert tru_result == exp_result - + tru_messages = agent.messages exp_messages = [ { "content": [ { "text": ( - "agent.tool.weather_tool direct tool call.\n" - 'Input parameters: {"city_name": "new york"}\n' + 'agent.tool.weather_tool direct tool call.\nInput parameters: {"city_name": "new york"}\n' ), }, ], diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index 26453e1f7..c696fb65d 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -1,13 +1,129 @@ +import functools import json import logging import os +from collections.abc import Callable, Sequence import boto3 import pytest +from tenacity import RetryCallState, RetryError, Retrying, stop_after_attempt, wait_exponential logger = logging.getLogger(__name__) +# Type alias for retry conditions +RetryCondition = type[BaseException] | Callable[[BaseException], bool] | str + + +def _should_retry_exception(exc: BaseException, conditions: Sequence[RetryCondition]) -> bool: + """Check if exception matches any of the given retry conditions. + + Args: + exc: The exception to check + conditions: Sequence of conditions, each can be: + - Exception type: retry if isinstance(exc, condition) + - Callable: retry if condition(exc) returns True + - str: retry if string is in str(exc) + """ + for condition in conditions: + if isinstance(condition, type) and issubclass(condition, BaseException): + if isinstance(exc, condition): + return True + elif callable(condition): + if condition(exc): + return True + elif isinstance(condition, str): + if condition in str(exc): + return True + return False + + +_RETRY_ON_ANY: Sequence[RetryCondition] = (lambda _: True,) + + +def retry_on_flaky( + reason: str, + *, + max_attempts: int = 3, + wait_multiplier: float = 1, + wait_max: float = 10, + retry_on: Sequence[RetryCondition] = _RETRY_ON_ANY, +) -> Callable: + """Decorator to retry flaky integration tests that fail due to external factors. + + WHEN TO USE: + - External service instability (API rate limits, transient network errors) + - Non-deterministic LLM responses that occasionally fail assertions + - Resource contention in shared test environments + - Known intermittent issues with third-party dependencies + + WHEN NOT TO USE: + - Actual bugs in the code under test (fix the bug instead) + - Deterministic failures (these indicate real problems) + - Unit tests (flakiness in unit tests usually indicates a design issue) + - To mask consistently failing tests (investigate root cause first) + + Prefer using specific retry_on conditions over retrying on any exception + to avoid masking real bugs. + + Args: + reason: Required explanation of why this test is flaky and needs retries. + This should describe the source of non-determinism (e.g., "LLM responses + may vary" or "External API has intermittent rate limits"). + max_attempts: Maximum number of retry attempts (default: 3) + wait_multiplier: Multiplier for exponential backoff in seconds (default: 1) + wait_max: Maximum wait time between retries in seconds (default: 10) + retry_on: Conditions for when to retry. Defaults to retrying on any exception. + Each condition can be: + - Exception type: e.g., ValueError, TimeoutError + - Callable: e.g., lambda e: "timeout" in str(e).lower() + - str: substring to match in exception message + + Usage: + # Retry on any failure + @retry_on_flaky("LLM responses are non-deterministic") + def test_something(): + ... + + # Retry only on specific exception types + @retry_on_flaky("Network calls may fail transiently", retry_on=[TimeoutError, ConnectionError]) + def test_network_call(): + ... + + # Retry on string patterns in exception message + @retry_on_flaky("Service has intermittent availability", retry_on=["Service unavailable", "Status 503"]) + def test_service_call(): + ... + """ + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + def should_retry(retry_state: RetryCallState) -> bool: + if retry_state.outcome is None or not retry_state.outcome.failed: + return False + exc = retry_state.outcome.exception() + if exc is None: + return False + return _should_retry_exception(exc, retry_on) + + try: + for attempt in Retrying( + stop=stop_after_attempt(max_attempts), + wait=wait_exponential(multiplier=wait_multiplier, max=wait_max), + retry=should_retry, + reraise=True, + ): + with attempt: + return func(*args, **kwargs) + except RetryError: + raise + + return wrapper + + return decorator + + def pytest_sessionstart(session): _load_api_keys_from_secrets_manager() @@ -17,14 +133,21 @@ def pytest_sessionstart(session): @pytest.fixture def yellow_img(pytestconfig): - path = pytestconfig.rootdir / "tests_integ/yellow.png" + path = pytestconfig.rootdir / "tests_integ/resources/yellow.png" with open(path, "rb") as fp: return fp.read() @pytest.fixture def letter_pdf(pytestconfig): - path = pytestconfig.rootdir / "tests_integ/letter.pdf" + path = pytestconfig.rootdir / "tests_integ/resources/letter.pdf" + with open(path, "rb") as fp: + return fp.read() + + +@pytest.fixture +def blue_video(pytestconfig): + path = pytestconfig.rootdir / "tests_integ/resources/blue.mp4" with open(path, "rb") as fp: return fp.read() @@ -79,8 +202,8 @@ def _load_api_keys_from_secrets_manager(): required_providers = { "ANTHROPIC_API_KEY", - "COHERE_API_KEY", - "MISTRAL_API_KEY", + "GOOGLE_API_KEY", + # "MISTRAL_API_KEY", # will add back once we get a card on file for this. "OPENAI_API_KEY", "WRITER_API_KEY", } diff --git a/tests_integ/hooks/multiagent/test_cancel.py b/tests_integ/hooks/multiagent/test_cancel.py index 9267330b7..ae3008861 100644 --- a/tests_integ/hooks/multiagent/test_cancel.py +++ b/tests_integ/hooks/multiagent/test_cancel.py @@ -1,8 +1,7 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.types._events import MultiAgentNodeCancelEvent diff --git a/tests_integ/hooks/multiagent/test_events.py b/tests_integ/hooks/multiagent/test_events.py index e8039444f..3a10b74c1 100644 --- a/tests_integ/hooks/multiagent/test_events.py +++ b/tests_integ/hooks/multiagent/test_events.py @@ -1,14 +1,14 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, + HookProvider, MultiAgentInitializedEvent, ) -from strands.hooks import HookProvider from strands.multiagent import GraphBuilder, Swarm diff --git a/tests_integ/interrupts/multiagent/test_agent.py b/tests_integ/interrupts/multiagent/test_agent.py deleted file mode 100644 index 36fcfef27..000000000 --- a/tests_integ/interrupts/multiagent/test_agent.py +++ /dev/null @@ -1,67 +0,0 @@ -import json -from unittest.mock import ANY - -import pytest - -from strands import Agent, tool -from strands.interrupt import Interrupt -from strands.multiagent import Swarm -from strands.multiagent.base import Status -from strands.types.tools import ToolContext - - -@pytest.fixture -def weather_tool(): - @tool(name="weather_tool", context=True) - def func(tool_context: ToolContext) -> str: - response = tool_context.interrupt("test_interrupt", reason="need weather") - return response - - return func - - -@pytest.fixture -def swarm(weather_tool): - weather_agent = Agent(name="weather", tools=[weather_tool]) - - return Swarm([weather_agent]) - - -def test_swarm_interrupt_agent(swarm): - multiagent_result = swarm("What is the weather?") - - tru_status = multiagent_result.status - exp_status = Status.INTERRUPTED - assert tru_status == exp_status - - tru_interrupts = multiagent_result.interrupts - exp_interrupts = [ - Interrupt( - id=ANY, - name="test_interrupt", - reason="need weather", - ), - ] - assert tru_interrupts == exp_interrupts - - interrupt = multiagent_result.interrupts[0] - - responses = [ - { - "interruptResponse": { - "interruptId": interrupt.id, - "response": "sunny", - }, - }, - ] - multiagent_result = swarm(responses) - - tru_status = multiagent_result.status - exp_status = Status.COMPLETED - assert tru_status == exp_status - - assert len(multiagent_result.results) == 1 - weather_result = multiagent_result.results["weather"] - - weather_message = json.dumps(weather_result.result.message).lower() - assert "sunny" in weather_message diff --git a/tests_integ/interrupts/multiagent/test_hook.py b/tests_integ/interrupts/multiagent/test_hook.py index be7682082..53305b4e8 100644 --- a/tests_integ/interrupts/multiagent/test_hook.py +++ b/tests_integ/interrupts/multiagent/test_hook.py @@ -4,10 +4,9 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status @@ -18,16 +17,34 @@ def register_hooks(self, registry): registry.add_callback(BeforeNodeCallEvent, self.interrupt) def interrupt(self, event): - if event.node_id == "info": + if event.node_id == "info" or event.node_id == "time": return - response = event.interrupt("test_interrupt", reason="need approval") + response = event.interrupt(f"{event.node_id}_interrupt", reason="need approval") if response != "APPROVE": event.cancel_node = "node rejected" return Hook() +@pytest.fixture +def day_tool(): + @tool(name="day_tool") + def func(): + return "monday" + + return func + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:01" + + return func + + @pytest.fixture def weather_tool(): @tool(name="weather_tool") @@ -38,13 +55,49 @@ def func(): @pytest.fixture -def swarm(interrupt_hook, weather_tool): - info_agent = Agent(name="info") - weather_agent = Agent(name="weather", tools=[weather_tool]) +def info_agent(): + return Agent(name="info") + +@pytest.fixture +def day_agent(day_tool): + return Agent(name="day", tools=[day_tool]) + + +@pytest.fixture +def time_agent(time_tool): + return Agent(name="time", tools=[time_tool]) + + +@pytest.fixture +def weather_agent(weather_tool): + return Agent(name="weather", tools=[weather_tool]) + + +@pytest.fixture +def swarm(interrupt_hook, info_agent, weather_agent): return Swarm([info_agent, weather_agent], hooks=[interrupt_hook]) +@pytest.fixture +def graph(interrupt_hook, info_agent, day_agent, time_agent, weather_agent): + builder = GraphBuilder() + + builder.add_node(info_agent, "info") + builder.add_node(day_agent, "day") + builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") + + builder.add_edge("info", "day") + builder.add_edge("info", "time") + builder.add_edge("info", "weather") + + builder.set_entry_point("info") + builder.set_hook_providers([interrupt_hook]) + + return builder.build() + + def test_swarm_interrupt(swarm): multiagent_result = swarm("What is the weather?") @@ -56,7 +109,7 @@ def test_swarm_interrupt(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need approval", ), ] @@ -97,7 +150,7 @@ async def test_swarm_interrupt_reject(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need approval", ), ] @@ -131,3 +184,120 @@ async def test_swarm_interrupt_reject(swarm): tru_node_id = multiagent_result.node_history[0].node_id exp_node_id = "info" assert tru_node_id == exp_node_id + + +def test_graph_interrupt(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = sorted([node.node_id for node in graph.state.interrupted_nodes]) + exp_node_ids = ["day", "weather"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need approval", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + } + for interrupt in multiagent_result.interrupts + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 4 + + day_message = json.dumps(multiagent_result.results["day"].result.message).lower() + time_message = json.dumps(multiagent_result.results["time"].result.message).lower() + weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower() + assert "monday" in day_message + assert "12:01" in time_message + assert "sunny" in weather_message + + +@pytest.mark.asyncio +async def test_graph_interrupt_reject(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need approval", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": tru_interrupts[0].id, + "response": "APPROVE", + }, + }, + { + "interruptResponse": { + "interruptId": tru_interrupts[1].id, + "response": "REJECT", + }, + }, + ] + + try: + async for event in graph.stream_async(responses): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_id = event["node_id"] + + except RuntimeError as e: + assert "node rejected" in str(e) + + exp_cancel_id = "weather" + assert tru_cancel_id == exp_cancel_id + + tru_state_status = graph.state.status + exp_state_status = Status.FAILED + assert tru_state_status == exp_state_status diff --git a/tests_integ/interrupts/multiagent/test_node.py b/tests_integ/interrupts/multiagent/test_node.py new file mode 100644 index 000000000..23e7a62bc --- /dev/null +++ b/tests_integ/interrupts/multiagent/test_node.py @@ -0,0 +1,188 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.interrupt import Interrupt +from strands.multiagent import GraphBuilder, Swarm +from strands.multiagent.base import Status +from strands.types.tools import ToolContext + + +@pytest.fixture +def day_tool(): + @tool(name="day_tool", context=True) + def func(tool_context: ToolContext) -> str: + response = tool_context.interrupt("day_interrupt", reason="need day") + return response + + return func + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:01" + + return func + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool", context=True) + def func(tool_context: ToolContext) -> str: + response = tool_context.interrupt("weather_interrupt", reason="need weather") + return response + + return func + + +@pytest.fixture +def info_agent(): + return Agent(name="info") + + +@pytest.fixture +def day_agent(day_tool): + return Agent(name="day", tools=[day_tool]) + + +@pytest.fixture +def time_agent(time_tool): + return Agent(name="time", tools=[time_tool]) + + +@pytest.fixture +def weather_agent(weather_tool): + return Agent(name="weather", tools=[weather_tool]) + + +@pytest.fixture +def swarm(weather_agent): + return Swarm([weather_agent]) + + +@pytest.fixture +def graph(info_agent, day_agent, time_agent, swarm): + builder = GraphBuilder() + + builder.add_node(info_agent, "info") + builder.add_node(day_agent, "day") + builder.add_node(time_agent, "time") + builder.add_node(swarm, "weather") + + builder.add_edge("info", "day") + builder.add_edge("info", "time") + builder.add_edge("info", "weather") + + builder.set_entry_point("info") + + return builder.build() + + +def test_swarm_interrupt_node(swarm): + multiagent_result = swarm("What is the weather?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "sunny", + }, + }, + ] + multiagent_result = swarm(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + assert len(multiagent_result.results) == 1 + weather_result = multiagent_result.results["weather"] + + weather_message = json.dumps(weather_result.result.message).lower() + assert "sunny" in weather_message + + +def test_graph_interrupt_node(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = sorted([node.node_id for node in graph.state.interrupted_nodes]) + exp_node_ids = ["day", "weather"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need day", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": tru_interrupts[0].id, + "response": "monday", + }, + }, + { + "interruptResponse": { + "interruptId": tru_interrupts[1].id, + "response": "sunny", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 4 + + day_message = json.dumps(multiagent_result.results["day"].result.message).lower() + time_message = json.dumps(multiagent_result.results["time"].result.message).lower() + assert "monday" in day_message + assert "12:01" in time_message + + nested_multiagent_result = multiagent_result.results["weather"].result + weather_message = json.dumps(nested_multiagent_result.results["weather"].result.message).lower() + assert "sunny" in weather_message diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index d6e8cdbf8..8a5979d63 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -5,7 +5,7 @@ from strands import Agent, tool from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.session import FileSessionManager from strands.types.tools import ToolContext @@ -21,12 +21,6 @@ def func(tool_context: ToolContext) -> str: return func -@pytest.fixture -def swarm(weather_tool): - weather_agent = Agent(name="weather", tools=[weather_tool]) - return Swarm([weather_agent]) - - def test_swarm_interrupt_session(weather_tool, tmpdir): weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") @@ -75,3 +69,87 @@ def test_swarm_interrupt_session(weather_tool, tmpdir): summarizer_message = json.dumps(summarizer_result.result.message).lower() assert "sunny" in summarizer_message + + +def test_graph_interrupt_session(weather_tool, tmpdir): + parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent") + child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child") + + weather_agent = Agent(name="weather", tools=[weather_tool]) + summarizer_agent = Agent(name="summarizer") + + weather_builder = GraphBuilder() + weather_builder.add_node(weather_agent, "weather") + weather_builder.set_entry_point("weather") + weather_builder.set_session_manager(child_sm) + weather_graph = weather_builder.build() + + builder = GraphBuilder() + builder.add_node(weather_graph, "weather") + builder.add_node(summarizer_agent, "summarizer") + builder.add_edge("weather", "summarizer") + builder.set_session_manager(parent_sm) + graph = builder.build() + + multiagent_result = graph("Can you check the weather and then summarize the results?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent") + child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child") + + weather_agent = Agent(name="weather", tools=[weather_tool]) + summarizer_agent = Agent(name="summarizer") + + weather_builder = GraphBuilder() + weather_builder.add_node(weather_agent, "weather") + weather_builder.set_entry_point("weather") + weather_builder.set_session_manager(child_sm) + weather_graph = weather_builder.build() + + builder = GraphBuilder() + builder.add_node(weather_graph, "weather") + builder.add_node(summarizer_agent, "summarizer") + builder.add_edge("weather", "summarizer") + builder.set_session_manager(parent_sm) + graph = builder.build() + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "sunny", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 2 + summarizer_message = json.dumps(multiagent_result.results["summarizer"].result.message).lower() + assert "sunny" in summarizer_message diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index e15065a4a..9c901e885 100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -16,12 +16,16 @@ """ import base64 +import json from typing import Literal from mcp.server import FastMCP -from mcp.types import BlobResourceContents, EmbeddedResource, TextResourceContents +from mcp.server.fastmcp import Context +from mcp.types import BlobResourceContents, CallToolResult, EmbeddedResource, TextContent, TextResourceContents from pydantic import BaseModel +TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + class EchoResponse(BaseModel): """Response model for echo with structured content.""" @@ -45,11 +49,28 @@ def start_echo_server(): def echo(to_echo: str) -> str: return to_echo + @mcp.tool(description="Echos back the _meta received in the request", structured_output=False) + def echo_meta(ctx: Context) -> str: + meta = ctx.request_context.meta + if meta is None: + return json.dumps(None) + return json.dumps(meta.model_dump(exclude_none=True)) + # FastMCP automatically constructs structured output schema from method signature @mcp.tool(description="Echos response back with structured content", structured_output=True) def echo_with_structured_content(to_echo: str) -> EchoResponse: return EchoResponse(echoed=to_echo, message_length=len(to_echo)) + @mcp.tool(description="Echos response back with metadata") + def echo_with_metadata(to_echo: str): + """Return structured content and metadata in the tool result.""" + + return CallToolResult( + content=[TextContent(type="text", text=to_echo)], + isError=False, + _meta={"metadata": {"nested": 1}, "shallow": "val"}, + ) + @mcp.tool(description="Get current weather information for a location") def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): """Get weather data including forecasts and alerts for the specified location""" @@ -71,15 +92,13 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): resource=BlobResourceContents( uri="https://weather.api/data/london.json", mimeType="application/json", - blob=base64.b64encode( - '{"temperature": 18, "condition": "rainy", "humidity": 85}'.encode() - ).decode(), + blob=base64.b64encode(b'{"temperature": 18, "condition": "rainy", "humidity": 85}').decode(), ), ) ] elif location.lower() == "tokyo": # Read yellow.png file for weather icon - with open("tests_integ/yellow.png", "rb") as image_file: + with open("tests_integ/resources/yellow.png", "rb") as image_file: png_data = image_file.read() return [ EmbeddedResource( @@ -92,6 +111,22 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): ) ] + # Resources + @mcp.resource("test://static-text") + def static_text_resource() -> str: + """A static text resource for testing""" + return "This is the content of the static text resource." + + @mcp.resource("test://static-binary") + def static_binary_resource() -> bytes: + """A static binary resource (image) for testing""" + return base64.b64decode(TEST_IMAGE_BASE64) + + @mcp.resource("test://template/{id}/data") + def template_resource(id: str) -> str: + """A resource template with parameter substitution""" + return json.dumps({"id": id, "templateTest": True, "data": f"Data for ID: {id}"}) + mcp.run(transport="stdio") diff --git a/tests_integ/mcp/task_echo_server.py b/tests_integ/mcp/task_echo_server.py new file mode 100644 index 000000000..4a8edc97d --- /dev/null +++ b/tests_integ/mcp/task_echo_server.py @@ -0,0 +1,139 @@ +"""MCP server with task-augmented tool execution support for integration testing.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import click +import mcp.types as types +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount + + +def create_task_server() -> Server: + """Create and configure the task-supporting MCP server.""" + server = Server("task-echo-server") + server.experimental.enable_tasks() + + # Workaround: MCP Python SDK's enable_tasks() doesn't properly set tasks.requests.tools.call capability + original_update_capabilities = server.experimental.update_capabilities + + def patched_update_capabilities(capabilities: types.ServerCapabilities) -> None: + original_update_capabilities(capabilities) + if capabilities.tasks and capabilities.tasks.requests and capabilities.tasks.requests.tools: + capabilities.tasks.requests.tools.call = types.TasksCallCapability() + + server.experimental.update_capabilities = patched_update_capabilities # type: ignore[method-assign] + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="task_required_echo", + description="Echo that requires task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ), + types.Tool( + name="task_optional_echo", + description="Echo that optionally supports task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_OPTIONAL), + ), + types.Tool( + name="task_forbidden_echo", + description="Echo that does not support task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_FORBIDDEN), + ), + types.Tool( + name="echo", + description="Simple echo without task support setting", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + ), + ] + + async def handle_task_required_echo(arguments: dict[str, Any]) -> types.CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + message = arguments.get("message", "") + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing echo...") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Task echo: {message}")]) + + return await ctx.experimental.run_task(work) + + async def handle_task_optional_echo(arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + ctx = server.request_context + message = arguments.get("message", "") + + if ctx.experimental.is_task: + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing optional task echo...") + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Task optional echo: {message}")] + ) + + return await ctx.experimental.run_task(work) + else: + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Direct optional echo: {message}")] + ) + + async def handle_task_forbidden_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Forbidden echo: {message}")]) + + async def handle_simple_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Simple echo: {message}")]) + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + handlers = { + "task_required_echo": handle_task_required_echo, + "task_optional_echo": handle_task_optional_echo, + "task_forbidden_echo": handle_task_forbidden_echo, + "echo": handle_simple_echo, + } + if name in handlers: + return await handlers[name](arguments) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], isError=True + ) + + return server + + +def create_starlette_app(port: int) -> tuple[Starlette, StreamableHTTPSessionManager]: + """Create the Starlette app with MCP session manager.""" + server = create_task_server() + session_manager = StreamableHTTPSessionManager(app=server) + + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + return Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)], lifespan=app_lifespan), session_manager + + +@click.command() +@click.option("--port", default=8010, help="Port to listen on") +def main(port: int) -> int: + """Start the task echo server.""" + import uvicorn + + starlette_app, _ = create_starlette_app(port) + print(f"Starting task echo server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 + + +if __name__ == "__main__": + main() diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 5c3baeba8..fe2b10df3 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -3,7 +3,7 @@ import os import threading import time -from typing import List, Literal +from typing import Literal import pytest from mcp import StdioServerParameters, stdio_client @@ -43,11 +43,11 @@ def calculator(x: int, y: int) -> int: @mcp.tool(description="Generates a custom image") def generate_custom_image() -> MCPImageContent: try: - with open("tests_integ/yellow.png", "rb") as image_file: + with open("tests_integ/resources/yellow.png", "rb") as image_file: encoded_image = base64.b64encode(image_file.read()) return MCPImageContent(type="image", data=encoded_image, mimeType="image/png") except Exception as e: - print("Error while generating custom image: {}".format(e)) + print(f"Error while generating custom image: {e}") # Prompts @mcp.prompt(description="A greeting prompt template") @@ -238,6 +238,74 @@ def test_mcp_client_without_structured_content(): assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}] +def test_call_tool_sync_with_meta(): + """Test that call_tool_sync forwards meta to the MCP server.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + result = stdio_mcp_client.call_tool_sync( + tool_use_id="test-meta-sync", + name="echo_meta", + arguments={}, + meta={"com.example/request_id": "abc-123"}, + ) + + assert result["status"] == "success" + received_meta = json.loads(result["content"][0]["text"]) + assert received_meta["com.example/request_id"] == "abc-123" + + +@pytest.mark.asyncio +async def test_call_tool_async_with_meta(): + """Test that call_tool_async forwards meta to the MCP server.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + result = await stdio_mcp_client.call_tool_async( + tool_use_id="test-meta-async", + name="echo_meta", + arguments={}, + meta={"com.example/request_id": "def-456"}, + ) + + assert result["status"] == "success" + received_meta = json.loads(result["content"][0]["text"]) + assert received_meta["com.example/request_id"] == "def-456" + + +def test_instrumentation_preserves_meta_on_tool_call(): + """Test that OTel instrumentation sets _meta that reaches the MCP server.""" + from unittest.mock import MagicMock, patch + + # Mock the propagator to always inject a known value, bypassing the need for + # an active span on the background thread where send_request runs + mock_textmap = MagicMock() + mock_textmap.inject = lambda carrier, **kwargs: carrier.update({"traceparent": "00-abc-def-01"}) + + with patch("opentelemetry.propagate.get_global_textmap", return_value=mock_textmap): + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + result = stdio_mcp_client.call_tool_sync( + tool_use_id="test-instrumentation", + name="echo_meta", + arguments={}, + ) + + assert result["status"] == "success" + received_meta = json.loads(result["content"][0]["text"]) + # OTel instrumentation should have injected _meta with tracing context + assert received_meta is not None + assert isinstance(received_meta, dict) + assert received_meta["traceparent"] == "00-abc-def-01" + + @pytest.mark.skipif( condition=os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", @@ -366,7 +434,7 @@ def test_mcp_client_embedded_resources_with_agent(): assert any(["72" in response_text, "partly cloudy" in response_text, "weather" in response_text]) -def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: +def _messages_to_content_blocks(messages: list[Message]) -> list[ToolUse]: return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] @@ -398,30 +466,6 @@ def slow_transport(): assert len(tools) >= 0 # Should work now -@pytest.mark.skipif( - condition=os.environ.get("GITHUB_ACTIONS") == "true", - reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", -) -@pytest.mark.asyncio -async def test_streamable_http_mcp_client_times_out_before_tool(): - """Test an mcp server that timesout before the tool is able to respond.""" - server_thread = threading.Thread( - target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True - ) - server_thread.start() - time.sleep(2) # wait for server to startup completely - - def transport_callback() -> MCPTransport: - return streamablehttp_client(sse_read_timeout=2, url="http://127.0.0.1:8001/mcp") - - streamable_http_client = MCPClient(transport_callback) - with streamable_http_client: - # Test tools - result = await streamable_http_client.call_tool_async(tool_use_id="123", name="timeout_tool") - assert result["status"] == "error" - assert result["content"][0]["text"] == "Tool execution failed: Connection closed" - - def start_5xx_proxy_for_tool_calls(target_url: str, proxy_port: int): """Starts a proxy that throws a 5XX when a tool call is invoked""" import aiohttp diff --git a/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py b/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py new file mode 100644 index 000000000..3e6132b38 --- /dev/null +++ b/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py @@ -0,0 +1,95 @@ +"""Integration test for MCP client structured content and metadata support. + +This test verifies that MCP tools can return structured content and metadata, +and that the MCP client properly handles and exposes these fields in tool results. +""" + +import json + +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.hooks import AfterToolCallEvent, HookProvider, HookRegistry +from strands.tools.mcp.mcp_client import MCPClient + + +class ToolResultCapture(HookProvider): + """Captures tool results for inspection.""" + + def __init__(self): + self.captured_results = {} + + def register_hooks(self, registry: HookRegistry) -> None: + """Register callback for after tool invocation events.""" + registry.add_callback(AfterToolCallEvent, self.on_after_tool_invocation) + + def on_after_tool_invocation(self, event: AfterToolCallEvent) -> None: + """Capture tool results by tool name.""" + tool_name = event.tool_use["name"] + self.captured_results[tool_name] = event.result + + +def test_structured_content(): + """Test that MCP tools can return structured content.""" + # Set up result capture + result_capture = ToolResultCapture() + + # Set up MCP client for echo server + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + # Create agent with MCP tools and result capture + agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[result_capture]) + + # Test structured content functionality + test_data = "STRUCTURED_TEST" + agent(f"Use the echo_with_structured_content tool to echo: {test_data}") + + # Verify result was captured + assert "echo_with_structured_content" in result_capture.captured_results + result = result_capture.captured_results["echo_with_structured_content"] + + # Verify basic result structure + assert result["status"] == "success" + assert len(result["content"]) == 1 + + # Verify structured content is present and correct + assert "structuredContent" in result + assert result["structuredContent"] == {"echoed": test_data, "message_length": 15} + + # Verify text content matches structured content + text_content = json.loads(result["content"][0]["text"]) + assert text_content == {"echoed": test_data, "message_length": 15} + + +def test_metadata(): + """Test that MCP tools can return metadata.""" + # Set up result capture + result_capture = ToolResultCapture() + + # Set up MCP client for echo server + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + # Create agent with MCP tools and result capture + agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[result_capture]) + + # Test metadata functionality + test_data = "METADATA_TEST" + agent(f"Use the echo_with_metadata tool to echo: {test_data}") + + # Verify result was captured + assert "echo_with_metadata" in result_capture.captured_results + result = result_capture.captured_results["echo_with_metadata"] + + # Verify basic result structure + assert result["status"] == "success" + + # Verify metadata is present and correct + assert "metadata" in result + expected_metadata = {"metadata": {"nested": 1}, "shallow": "val"} + assert result["metadata"] == expected_metadata diff --git a/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py b/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py deleted file mode 100644 index ef4993b05..000000000 --- a/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Integration test demonstrating hooks system with MCP client structured content tool. - -This test shows how to use the hooks system to capture and inspect tool invocation -results, specifically testing the echo_with_structured_content tool from echo_server. -""" - -import json - -from mcp import StdioServerParameters, stdio_client - -from strands import Agent -from strands.hooks import AfterToolCallEvent, HookProvider, HookRegistry -from strands.tools.mcp.mcp_client import MCPClient - - -class StructuredContentHookProvider(HookProvider): - """Hook provider that captures structured content tool results.""" - - def __init__(self): - self.captured_result = None - - def register_hooks(self, registry: HookRegistry) -> None: - """Register callback for after tool invocation events.""" - registry.add_callback(AfterToolCallEvent, self.on_after_tool_invocation) - - def on_after_tool_invocation(self, event: AfterToolCallEvent) -> None: - """Capture structured content tool results.""" - if event.tool_use["name"] == "echo_with_structured_content": - self.captured_result = event.result - - -def test_mcp_client_hooks_structured_content(): - """Test using hooks to inspect echo_with_structured_content tool result.""" - # Create hook provider to capture tool result - hook_provider = StructuredContentHookProvider() - - # Set up MCP client for echo server - stdio_mcp_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) - ) - - with stdio_mcp_client: - # Create agent with MCP tools and hook provider - agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[hook_provider]) - - # Test structured content functionality - test_data = "HOOKS_TEST_DATA" - agent(f"Use the echo_with_structured_content tool to echo: {test_data}") - - # Verify hook captured the tool result - assert hook_provider.captured_result is not None - result = hook_provider.captured_result - - # Verify basic result structure - assert result["status"] == "success" - assert len(result["content"]) == 1 - - # Verify structured content is present and correct - assert "structuredContent" in result - assert result["structuredContent"] == {"echoed": test_data, "message_length": 15} - - # Verify text content matches structured content - text_content = json.loads(result["content"][0]["text"]) - assert text_content == {"echoed": test_data, "message_length": 15} diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..751fb655f --- /dev/null +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -0,0 +1,153 @@ +"""Integration tests for MCP task-augmented tool execution.""" + +import os +import socket +import threading +import time +from typing import Any + +import pytest +from mcp.client.streamable_http import streamablehttp_client + +from strands.tools.mcp import MCPClient, MCPTransport, TasksConfig + + +def _find_available_port() -> int: + """Find an available port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.listen(1) + return s.getsockname()[1] + + +def start_task_server(port: int) -> None: + """Start the task echo server in a thread.""" + import uvicorn + + from tests_integ.mcp.task_echo_server import create_starlette_app + + starlette_app, _ = create_starlette_app(port) + uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="warning") + + +@pytest.fixture(scope="module") +def task_server_port() -> int: + return _find_available_port() + + +@pytest.fixture(scope="module") +def task_server(task_server_port: int) -> Any: + """Start the task server for the test module.""" + server_thread = threading.Thread(target=start_task_server, kwargs={"port": task_server_port}, daemon=True) + server_thread.start() + time.sleep(2) + yield + + +@pytest.fixture +def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client with tasks enabled.""" + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback, tasks_config=TasksConfig()) + + +@pytest.fixture +def task_mcp_client_disabled(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client with tasks disabled (default).""" + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback) + + +@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport failing in CI") +class TestMCPTaskSupport: + """Integration tests for MCP task-augmented execution.""" + + def test_direct_call_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use direct call_tool (forbidden or no taskSupport).""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + # Tool with taskSupport='forbidden' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_forbidden_echo", arguments={"message": "Hello!"} + ) + assert r1["status"] == "success" + assert "Forbidden echo: Hello!" in r1["content"][0].get("text", "") + + # Tool without taskSupport + r2 = task_mcp_client.call_tool_sync(tool_use_id="t2", name="echo", arguments={"message": "Simple!"}) + assert r2["status"] == "success" + assert "Simple echo: Simple!" in r2["content"][0].get("text", "") + + def test_task_augmented_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use task-augmented execution (required or optional).""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + # Tool with taskSupport='required' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_required_echo", arguments={"message": "Required!"} + ) + assert r1["status"] == "success" + assert "Task echo: Required!" in r1["content"][0].get("text", "") + + # Tool with taskSupport='optional' + r2 = task_mcp_client.call_tool_sync( + tool_use_id="t2", name="task_optional_echo", arguments={"message": "Optional!"} + ) + assert r2["status"] == "success" + assert "Task optional echo: Optional!" in r2["content"][0].get("text", "") + + def test_task_support_tool_detection(self, task_mcp_client: MCPClient) -> None: + """Test tool-level task support detection.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + # Verify decision logic + assert task_mcp_client._should_use_task("task_required_echo") is True + assert task_mcp_client._should_use_task("task_optional_echo") is True + assert task_mcp_client._should_use_task("task_forbidden_echo") is False + assert task_mcp_client._should_use_task("echo") is False + + def test_server_capabilities(self, task_mcp_client: MCPClient) -> None: + """Test server task capability detection.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + assert task_mcp_client._has_server_task_support() is True + + def test_tasks_disabled_by_default(self, task_mcp_client_disabled: MCPClient) -> None: + """Test that tasks are disabled when experimental.tasks is not configured.""" + with task_mcp_client_disabled: + task_mcp_client_disabled.list_tools_sync() + + assert task_mcp_client_disabled._is_tasks_enabled() is False + assert task_mcp_client_disabled._should_use_task("task_required_echo") is False + + # Direct call_tool still works for tools that support it + result = task_mcp_client_disabled.call_tool_sync( + tool_use_id="t", name="task_optional_echo", arguments={"message": "Direct!"} + ) + assert result["status"] == "success" + + # Task-required tools fail gracefully via direct call + result2 = task_mcp_client_disabled.call_tool_sync( + tool_use_id="t2", name="task_required_echo", arguments={"message": "Direct!"} + ) + assert result2["status"] == "error" + + @pytest.mark.asyncio + async def test_async_tool_call(self, task_mcp_client: MCPClient) -> None: + """Test async tool calls.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + result = await task_mcp_client.call_tool_async( + tool_use_id="t", name="task_forbidden_echo", arguments={"message": "Async!"} + ) + assert result["status"] == "success" + assert "Forbidden echo: Async!" in result["content"][0].get("text", "") diff --git a/tests_integ/mcp/test_mcp_resources.py b/tests_integ/mcp/test_mcp_resources.py new file mode 100644 index 000000000..dccf3b808 --- /dev/null +++ b/tests_integ/mcp/test_mcp_resources.py @@ -0,0 +1,130 @@ +""" +Integration tests for MCP client resource functionality. + +This module tests the resource-related methods in MCPClient: +- list_resources_sync() +- read_resource_sync() +- list_resource_templates_sync() + +The tests use the echo server which has been extended with resource functionality. +""" + +import base64 +import json + +import pytest +from mcp import StdioServerParameters, stdio_client +from mcp.shared.exceptions import McpError +from mcp.types import BlobResourceContents, TextResourceContents +from pydantic import AnyUrl + +from strands.tools.mcp.mcp_client import MCPClient + + +def test_mcp_resources_list_and_read(): + """Test listing and reading various types of resources.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test list_resources_sync + resources_result = mcp_client.list_resources_sync() + assert len(resources_result.resources) >= 2 # At least our 2 static resources + + # Verify resource URIs exist (only static resources, not templates) + resource_uris = [str(r.uri) for r in resources_result.resources] + assert "test://static-text" in resource_uris + assert "test://static-binary" in resource_uris + # Template resources are not listed in static resources + + # Test reading text resource + text_resource = mcp_client.read_resource_sync("test://static-text") + assert len(text_resource.contents) == 1 + content = text_resource.contents[0] + assert isinstance(content, TextResourceContents) + assert "This is the content of the static text resource." in content.text + + # Test reading binary resource + binary_resource = mcp_client.read_resource_sync("test://static-binary") + assert len(binary_resource.contents) == 1 + binary_content = binary_resource.contents[0] + assert isinstance(binary_content, BlobResourceContents) + # Verify it's valid base64 encoded data + decoded_data = base64.b64decode(binary_content.blob) + assert len(decoded_data) > 0 + + +def test_mcp_resources_templates(): + """Test listing resource templates and reading from template resources.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test list_resource_templates_sync + templates_result = mcp_client.list_resource_templates_sync() + assert len(templates_result.resourceTemplates) >= 1 + + # Verify template URIs exist + template_uris = [t.uriTemplate for t in templates_result.resourceTemplates] + assert "test://template/{id}/data" in template_uris + + # Test reading from template resource + template_resource = mcp_client.read_resource_sync("test://template/123/data") + assert len(template_resource.contents) == 1 + template_content = template_resource.contents[0] + assert isinstance(template_content, TextResourceContents) + + # Parse the JSON response + parsed_json = json.loads(template_content.text) + assert parsed_json["id"] == "123" + assert parsed_json["templateTest"] is True + assert "Data for ID: 123" in parsed_json["data"] + + +def test_mcp_resources_pagination(): + """Test pagination support for resources.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test with pagination token (should work even if server doesn't implement pagination) + resources_result = mcp_client.list_resources_sync(pagination_token=None) + assert len(resources_result.resources) >= 0 + + # Test resource templates pagination + templates_result = mcp_client.list_resource_templates_sync(pagination_token=None) + assert len(templates_result.resourceTemplates) >= 0 + + +def test_mcp_resources_error_handling(): + """Test error handling for resource operations.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test reading non-existent resource + with pytest.raises(McpError, match="Unknown resource"): + mcp_client.read_resource_sync("test://nonexistent") + + +def test_mcp_resources_uri_types(): + """Test that both string and AnyUrl types work for read_resource_sync.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test with string URI + text_resource_str = mcp_client.read_resource_sync("test://static-text") + assert len(text_resource_str.contents) == 1 + + # Test with AnyUrl URI + text_resource_url = mcp_client.read_resource_sync(AnyUrl("test://static-text")) + assert len(text_resource_url.contents) == 1 + + # Both should return the same content + assert text_resource_str.contents[0].text == text_resource_url.contents[0].text diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 75cc58f74..db85d496d 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -3,7 +3,7 @@ """ import os -from typing import Callable, Optional +from collections.abc import Callable import requests from pytest import mark @@ -18,6 +18,13 @@ from strands.models.openai import OpenAIModel from strands.models.writer import WriterModel +try: + from strands.models.openai_responses import OpenAIResponsesModel + + _openai_responses_available = True +except ImportError: + _openai_responses_available = False + class ProviderInfo: """Provider-based info for providers that require an APIKey via environment variables.""" @@ -26,7 +33,7 @@ def __init__( self, id: str, factory: Callable[[], Model], - environment_variable: Optional[str] = None, + environment_variable: str | None = None, ) -> None: self.id = id self.model_factory = factory @@ -66,7 +73,7 @@ def __init__(self): client_args={ "api_key": os.getenv("ANTHROPIC_API_KEY"), }, - model_id="claude-3-7-sonnet-20250219", + model_id="claude-sonnet-4-6", max_tokens=512, ), ) @@ -84,7 +91,7 @@ def __init__(self): ), ) litellm = ProviderInfo( - id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") + id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0") ) llama = ProviderInfo( id="llama", @@ -118,6 +125,19 @@ def __init__(self): }, ), ) +if _openai_responses_available: + openai_responses = ProviderInfo( + id="openai_responses", + environment_variable="OPENAI_API_KEY", + factory=lambda: OpenAIResponsesModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ), + ) +else: + openai_responses = None writer = ProviderInfo( id="writer", environment_variable="WRITER_API_KEY", @@ -141,13 +161,18 @@ def __init__(self): all_providers = [ - bedrock, - anthropic, - cohere, - gemini, - llama, - litellm, - mistral, - openai, - writer, + provider + for provider in [ + bedrock, + anthropic, + cohere, + gemini, + llama, + litellm, + mistral, + openai, + openai_responses, + writer, + ] + if provider is not None ] diff --git a/tests_integ/models/test_conformance.py b/tests_integ/models/test_conformance.py index 36c21fb7f..994ecbf00 100644 --- a/tests_integ/models/test_conformance.py +++ b/tests_integ/models/test_conformance.py @@ -74,4 +74,4 @@ class UserProfile(BaseModel): result = agent("Create a profile for John who is a 25 year old dentist", structured_output_model=UserProfile) assert result.structured_output.name == "John" assert result.structured_output.age == 25 - assert result.structured_output.occupation == "dentist" + assert result.structured_output.occupation.lower() == "dentist" diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index 9a0d19dff..a5eba45b9 100644 --- a/tests_integ/models/test_model_anthropic.py +++ b/tests_integ/models/test_model_anthropic.py @@ -28,7 +28,7 @@ def model(): client_args={ "api_key": os.getenv("ANTHROPIC_API_KEY"), }, - model_id="claude-3-7-sonnet-20250219", + model_id="claude-sonnet-4-6", max_tokens=512, ) @@ -182,3 +182,42 @@ def test_input_and_max_tokens_exceed_context_limit(): with pytest.raises(ContextWindowOverflowException): agent(messages) + + +class TestCountTokens: + @pytest.fixture + def model(self): + return AnthropicModel( + model_id="claude-sonnet-4-20250514", + max_tokens=1024, + client_args={"api_key": os.environ["ANTHROPIC_API_KEY"]}, + ) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "What is the capital of France? Explain in detail."}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "get_weather", + "description": "Get the current weather for a location", + "inputSchema": {"json": {"type": "object", "properties": {"location": {"type": "string"}}}}, + } + ] + + @pytest.mark.asyncio + async def test_count_tokens_messages_only(self, model, messages, caplog): + with caplog.at_level("DEBUG"): + result = await model.count_tokens(messages=messages) + assert isinstance(result, int) + assert result > 0 + assert "native token count" in caplog.text + assert "falling back" not in caplog.text + + @pytest.mark.asyncio + async def test_count_tokens_with_tools_greater_than_without(self, model, messages, tool_specs): + without = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") + assert with_tools > without diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 2c2e125ad..c1c4adf6f 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -1,11 +1,23 @@ +import time +import uuid + import pydantic import pytest import strands from strands import Agent -from strands.models import BedrockModel +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.types.content import ContentBlock +# Model ID used for prompt-caching TTL integration tests. Per +# https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html +# the models that officially support 1h TTL on CachePoint are Claude Opus 4.5, +# Claude Haiku 4.5, and Claude Sonnet 4.5. Haiku 4.5 is the newest Haiku +# available and is preferred for CI due to lower latency and cost relative to +# the same-version Sonnet 4.5. Bump this when a newer Haiku is released that +# supports CachePoint TTL. +_CACHE_TTL_MODEL_ID = "us.anthropic.claude-haiku-4-5-20251001-v1:0" + @pytest.fixture def system_prompt(): @@ -73,6 +85,27 @@ def test_non_streaming_agent(non_streaming_agent): assert len(str(result)) > 0 +def test_bedrock_service_tier_flex_invocation_succeeds(): + """Bedrock accepts serviceTier when model and region support Priority/Flex tiers. + + Tier support is model- and region-specific. See: + https://docs.aws.amazon.com/bedrock/latest/userguide/service-tiers-inference.html + + CI runs integ tests with AWS_REGION=us-east-1; amazon.nova-pro-v1:0 is listed for + that region under Priority and Flex tiers. + """ + model = BedrockModel( + model_id="amazon.nova-pro-v1:0", + region_name="us-east-1", + service_tier="flex", + ) + agent = Agent(model=model, load_tools_from_directory=False) + result = agent("Reply with exactly the word: ok") + + assert result.stop_reason == "end_turn" + assert len(str(result).strip()) > 0 + + @pytest.mark.asyncio async def test_streaming_model_events(streaming_model, alist): """Test streaming model events.""" @@ -210,6 +243,9 @@ def test_document_citations(non_streaming_agent, letter_pdf): assert any("citationsContent" in content for content in non_streaming_agent.messages[-1]["content"]) + # Validate message structure is valid in multi-turn + non_streaming_agent("What is your favorite part?") + def test_document_citations_streaming(streaming_agent, letter_pdf): content: list[ContentBlock] = [ @@ -228,6 +264,9 @@ def test_document_citations_streaming(streaming_agent, letter_pdf): assert any("citationsContent" in content for content in streaming_agent.messages[-1]["content"]) + # Validate message structure is valid in multi-turn + streaming_agent("What is your favorite part?") + def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow_color): content = [ @@ -246,27 +285,41 @@ def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow assert tru_color == exp_color -def test_redacted_content_handling(): - """Test redactedContent handling with thinking mode.""" - bedrock_model = BedrockModel( - model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", +def test_reasoning_content_in_messages_with_thinking_disabled(): + """Test that messages with reasoningContent are accepted when thinking is explicitly disabled.""" + # First, get a real reasoning response with thinking enabled + thinking_model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", additional_request_fields={ "thinking": { "type": "enabled", - "budget_tokens": 2000, + "budget_tokens": 1024, } }, ) + agent_with_thinking = Agent(model=thinking_model) + result_with_thinking = agent_with_thinking("What is 2+2?") - agent = Agent(name="test_redact", model=bedrock_model) - # https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#example-working-with-redacted-thinking-blocks - result = agent( - "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB" + # Verify we got reasoning content + assert "reasoningContent" in result_with_thinking.message["content"][0] + + # Now create a model with thinking disabled and use the messages from the thinking session + disabled_model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={ + "thinking": { + "type": "disabled", + } + }, ) - assert "reasoningContent" in result.message["content"][0] - assert "redactedContent" in result.message["content"][0]["reasoningContent"] - assert isinstance(result.message["content"][0]["reasoningContent"]["redactedContent"], bytes) + # Use the conversation history that includes reasoning content + messages = agent_with_thinking.messages + + agent_disabled = Agent(model=disabled_model, messages=messages) + result = agent_disabled("What about 3+3?") + + assert result.stop_reason == "end_turn" def test_multi_prompt_system_content(): @@ -280,3 +333,364 @@ def test_multi_prompt_system_content(): agent = Agent(system_prompt=system_prompt_content, load_tools_from_directory=False) # just verifying there is no failure agent("Hello!") + + +def test_prompt_caching_with_5m_ttl(): + """Test prompt caching with 5 minute TTL and verify cache metrics. + + This test verifies: + 1. First call creates cache (cacheWriteInputTokens > 0) + 2. Second call reads from cache (cacheReadInputTokens > 0) + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock. + Older models (e.g. Claude Sonnet 4) reject the TTL field with a ValidationException. + """ + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + streaming=False, + ) + + # Use unique identifier to avoid cache conflicts between test runs + unique_id = str(uuid.uuid4()) + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_context = f"Background information for test {unique_id}: " + ("This is important context. " * 1000) + + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "5m"}}, + {"text": "You are a helpful assistant."}, + ] + + agent = Agent( + model=model, + system_prompt=system_prompt_with_cache, + load_tools_from_directory=False, + ) + + # First call should create the cache (cache write) + result1 = agent("What is 2+2?") + assert len(str(result1)) > 0 + + # Verify cache write occurred on first call + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 on first call" + ) + + # Second call should use the cached content (cache read) + result2 = agent("What is 3+3?") + assert len(str(result2)) > 0 + + # Verify cache read occurred on second call + assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, ( + "Expected cacheReadInputTokens > 0 on second call" + ) + + +def test_prompt_caching_with_1h_ttl(): + """Test prompt caching with 1 hour TTL and verify cache metrics. + + Uses Claude Haiku 4.5 which supports 1hr TTL. + Uses unique content per test run to avoid cache conflicts with concurrent CI runs. + Even with 1hr TTL, unique content ensures cache entries don't interfere across tests. + """ + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + streaming=False, + ) + + # Use timestamp to ensure unique content per test run (avoids CI conflicts) + unique_id = str(int(time.time() * 1000000)) # microsecond timestamp + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_context = f"Background information for test {unique_id}: " + ("This is important context. " * 1000) + + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "1h"}}, + {"text": "You are a helpful assistant."}, + ] + + agent = Agent( + model=model, + system_prompt=system_prompt_with_cache, + load_tools_from_directory=False, + ) + + # First call should create the cache + result1 = agent("What is 2+2?") + assert len(str(result1)) > 0 + + # Verify cache write occurred + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 on first call with 1h TTL" + ) + + # Second call should use the cached content + result2 = agent("What is 3+3?") + assert len(str(result2)) > 0 + + # Verify cache read occurred + assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, ( + "Expected cacheReadInputTokens > 0 on second call with 1h TTL" + ) + + +def test_prompt_caching_with_ttl_in_messages(): + """Test prompt caching with TTL in message content and verify cache metrics. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock. + Older models (e.g. Claude Sonnet 4) reject the TTL field with a ValidationException. + """ + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + streaming=False, + ) + agent = Agent(model=model, load_tools_from_directory=False) + + unique_id = str(uuid.uuid4()) + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_text = f"Important context for test {unique_id}: " + ("This is critical information. " * 1000) + + content_with_cache = [ + {"text": large_text}, + {"cachePoint": {"type": "default", "ttl": "5m"}}, + {"text": "Based on the context above, what is 5+5?"}, + ] + + # First call creates cache + result1 = agent(content_with_cache) + assert len(str(result1)) > 0 + + # Verify cache write in message content + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 when caching message content" + ) + + # Subsequent call should use cache + result2 = agent("What about 10+10?") + assert len(str(result2)) > 0 + + # Verify cache read on subsequent call + assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, ( + "Expected cacheReadInputTokens > 0 on subsequent call" + ) + + +def test_prompt_caching_backward_compatibility_no_ttl(): + """Test that prompt caching works without TTL (backward compatibility). + + Verifies that cache points work correctly when TTL is not specified, + maintaining backward compatibility with existing code. + + Uses Claude Haiku 4.5 which supports prompt caching on Bedrock. + Minimum 4096 tokens required for caching with Haiku 4.5. + """ + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + streaming=False, + ) + + unique_id = str(uuid.uuid4()) + large_context = f"Background information for test {unique_id}: " + ("This is important context. " * 1000) + + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default"}}, # No TTL specified + {"text": "You are a helpful assistant."}, + ] + + agent = Agent( + model=model, + system_prompt=system_prompt_with_cache, + load_tools_from_directory=False, + ) + + result = agent("Hello!") + assert len(str(result)) > 0 + + # Verify cache write occurred even without TTL + assert result.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 even without TTL specified" + ) + + +class TestCountTokens: + @pytest.fixture + def model(self): + return BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0", use_native_token_count=True) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "What is the capital of France? Explain in detail."}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "get_weather", + "description": "Get the current weather for a location", + "inputSchema": {"json": {"type": "object", "properties": {"location": {"type": "string"}}}}, + } + ] + + @pytest.mark.asyncio + async def test_count_tokens_messages_only(self, model, messages, caplog): + with caplog.at_level("DEBUG"): + result = await model.count_tokens(messages=messages) + assert isinstance(result, int) + assert result > 0 + assert "native token count" in caplog.text + assert "falling back" not in caplog.text + + @pytest.mark.asyncio + async def test_count_tokens_with_tools_greater_than_without(self, model, messages, tool_specs): + without = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") + assert with_tools > without + + +def test_strict_tools_with_complex_schema(): + """Test strict_tools=True with tools that have complex schemas including arrays and optional params.""" + + tools_called = set() + + @strands.tool + def search(query: str, tags: list[str], max_results: int = 5) -> str: + """Search for items matching query and tags.""" + tools_called.add("search") + return f"Found results for '{query}' with tags {tags} (limit {max_results})" + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + tools_called.add("calculator") + return eval(expression) + + model = BedrockModel(strict_tools=True) + agent = Agent(model=model, tools=[search, calculator], load_tools_from_directory=False) + agent('Search for "python" with tags ["programming", "language"] using the search tool.') + + assert "search" in tools_called + + +def test_prompt_caching_cache_tools_ttl(): + """Test that CacheToolsConfig(ttl=...) propagates into the auto-injected toolConfig cache point. + + Verifies that BedrockModel(cache_tools=CacheToolsConfig(type="default", ttl="5m")) produces a + Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call + completes without a ValidationException on the TTL field. + + Note: we intentionally do not assert specific cacheWriteInputTokens on the toolConfig + prefix because Bedrock's tool-prefix cache threshold varies by model and region. + The critical behavior under test here is that the TTL field is accepted end-to-end. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + (Claude Opus 4.5, Claude Haiku 4.5, and Claude Sonnet 4.5 all support 1h TTL). + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_tools=CacheToolsConfig(type="default", ttl="5m"), + ) + + @strands.tool + def lookup_fact(topic: str) -> str: + """Look up a fact about the given topic. + + This tool is useful when you need authoritative information. + """ + return f"Fact about {topic}: example" + + agent = Agent( + model=model, + tools=[lookup_fact], + load_tools_from_directory=False, + ) + + # The call must succeed — Bedrock must accept cachePoint.ttl on the toolConfig checkpoint + # without raising a ValidationException. + result = agent("Use the lookup_fact tool to look up 'python'.") + assert len(str(result)) > 0 + + +def test_prompt_caching_cache_config_auto_with_ttl(): + """Test that CacheConfig(strategy="auto", ttl="5m") propagates TTL to the auto-injected message cache point. + + Verifies that the cache point appended to the last user message by _inject_cache_point + carries the configured TTL, and that Bedrock accepts the request. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_config=CacheConfig(strategy="auto", ttl="5m"), + ) + + unique_id = str(uuid.uuid4()) + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_message = f"Context for test {unique_id}: " + ("This is important context. " * 1000) + " What is 2+2?" + + agent = Agent( + model=model, + load_tools_from_directory=False, + ) + + # First call: auto-injected cache point on the last user message must include ttl and be accepted + result1 = agent(large_message) + assert len(str(result1)) > 0 + + # Verify cache write occurred with auto-inject + ttl + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 with CacheConfig(strategy='auto', ttl='5m')" + ) + + +def test_prompt_caching_aligned_1h_ttl_across_checkpoints(): + """Regression test for Bedrock TTL non-increasing ordering rule (Issue #2121). + + Bedrock processes cache checkpoints in order: toolConfig -> system -> messages, + and requires TTLs to be non-increasing. Before this change, cache_tools hardcoded + an implicit 5m TTL, so any 1h TTL on a later checkpoint would raise a + ValidationException. + + This test sets 1h TTL on all three checkpoints simultaneously and verifies the + call succeeds. + + Uses Claude Haiku 4.5 which supports 1h TTL per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_tools=CacheToolsConfig(type="default", ttl="1h"), + cache_config=CacheConfig(strategy="auto", ttl="1h"), + ) + + # Timestamp-based uniqueness to avoid cache conflicts across CI runs + unique_id = str(int(time.time() * 1000000)) + large_context = f"Background context for test {unique_id}: " + ("This is important context. " * 1000) + + # User-supplied 1h cache point on system prompt — third checkpoint also at 1h + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "1h"}}, + {"text": "You are a helpful assistant."}, + ] + + @strands.tool + def echo(value: str) -> str: + """Echo the given value back.""" + return value + + agent = Agent( + model=model, + system_prompt=system_prompt_with_cache, + tools=[echo], + load_tools_from_directory=False, + ) + + # Must succeed without ValidationException on the non-increasing TTL rule + result = agent("What is 2+2?") + assert len(str(result)) > 0 diff --git a/tests_integ/models/test_model_gemini.py b/tests_integ/models/test_model_gemini.py index f9da8490c..1057757da 100644 --- a/tests_integ/models/test_model_gemini.py +++ b/tests_integ/models/test_model_gemini.py @@ -2,6 +2,7 @@ import pydantic import pytest +from google import genai import strands from strands import Agent @@ -21,6 +22,16 @@ def model(): ) +@pytest.fixture +def gemini_tool_model(): + return GeminiModel( + client_args={"api_key": os.getenv("GOOGLE_API_KEY")}, + model_id="gemini-2.5-flash", + params={"temperature": 0.15}, # Lower temperature for consistent test behavior + gemini_tools=[genai.types.Tool(code_execution=genai.types.ToolCodeExecution())], + ) + + @pytest.fixture def tools(): @strands.tool @@ -175,3 +186,75 @@ def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow tru_color = assistant_agent.structured_output(type(yellow_color), content) exp_color = yellow_color assert tru_color == exp_color + + +def test_agent_with_gemini_code_execution_tool(gemini_tool_model): + system_prompt = "Generate and run code for all calculations" + agent = Agent(model=gemini_tool_model, system_prompt=system_prompt) + # sample prompt taken from https://ai.google.dev/gemini-api/docs/code-execution + result_turn1 = agent( + "What is the sum of the first 50 prime numbers? Generate and run code for the calculation, " + "and make sure you get all 50." + ) + + # NOTE: We don't verify tool history because built-in tools are currently represented in message history + assert "5117" in str(result_turn1) + + result_turn2 = agent("Summarize that into a single number") + assert "5117" in str(result_turn2) + + +def test_agent_with_reasoning_content(model, assistant_agent): + """Test that reasoning content is captured in message history.""" + + model.update_config( + params={ + "thinking_config": { + "thinking_budget": 1024, + "include_thoughts": True, + }, + }, + ) + + result = assistant_agent("Think about what 2+2 is") + assert "reasoningContent" in result.message["content"][0] + assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"] + + +class TestCountTokens: + @pytest.fixture + def model(self): + return GeminiModel( + model_id="gemini-2.0-flash", + client_args={"api_key": os.environ["GOOGLE_API_KEY"]}, + use_native_token_count=True, + ) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "What is the capital of France? Explain in detail."}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "get_weather", + "description": "Get the current weather for a location", + "inputSchema": {"json": {"type": "object", "properties": {"location": {"type": "string"}}}}, + } + ] + + @pytest.mark.asyncio + async def test_count_tokens_messages_only(self, model, messages, caplog): + with caplog.at_level("DEBUG"): + result = await model.count_tokens(messages=messages) + assert isinstance(result, int) + assert result > 0 + assert "native token count" in caplog.text + assert "falling back" not in caplog.text + + @pytest.mark.asyncio + async def test_count_tokens_with_tools_greater_than_without(self, model, messages, tool_specs): + without = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") + assert with_tools > without diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index d72937641..b606771d0 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -1,3 +1,4 @@ +import os import unittest.mock from uuid import uuid4 @@ -11,7 +12,17 @@ @pytest.fixture def model(): - return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0") + + +@pytest.fixture +def streaming_model(): + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0", params={"stream": True}) + + +@pytest.fixture +def non_streaming_model(): + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0", params={"stream": False}) @pytest.fixture @@ -95,15 +106,21 @@ def lower(_, value): return Color(simple_color_name="yellow") -def test_agent_invoke(agent): +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) +def test_agent_invoke(model_fixture, tools, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model, tools=tools) result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) @pytest.mark.asyncio -async def test_agent_invoke_async(agent): +async def test_agent_invoke_async(model_fixture, tools, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model, tools=tools) result = await agent.invoke_async("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -138,14 +155,20 @@ def test_agent_invoke_reasoning(agent, model): assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"] -def test_structured_output(agent, weather): +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) +def test_structured_output(model_fixture, weather, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") exp_weather = weather assert tru_weather == exp_weather +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) @pytest.mark.asyncio -async def test_agent_structured_output_async(agent, weather): +async def test_agent_structured_output_async(model_fixture, weather, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") exp_weather = weather assert tru_weather == exp_weather @@ -214,6 +237,27 @@ def test_structured_output_unsupported_model(model, nested_weather): mock_schema.assert_not_called() +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) +def test_streaming_returns_usage_metrics(model_fixture, request): + """Test that streaming returns usage metrics. + + This test verifies that the streaming flow correctly extracts and returns + usage data from the model response. This is a regression test for the bug + where accessing 'usage' attribute on ModelResponseStream raised AttributeError. + + Regression test for: 'ModelResponseStream' object has no attribute 'usage' + """ + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) + result = agent("Say hello") + + # Verify usage metrics are returned - this would fail if streaming breaks + assert result.metrics.accumulated_usage is not None + assert result.metrics.accumulated_usage["inputTokens"] > 0 + assert result.metrics.accumulated_usage["outputTokens"] > 0 + assert result.metrics.accumulated_usage["totalTokens"] > 0 + + @pytest.mark.asyncio async def test_cache_read_tokens_multi_turn(model): """Integration test for cache read tokens in multi-turn conversation.""" @@ -234,3 +278,15 @@ async def test_cache_read_tokens_multi_turn(model): assert result.metrics.accumulated_usage["cacheReadInputTokens"] > 0 assert result.metrics.accumulated_usage["cacheWriteInputTokens"] > 0 + + +def test_gemini_thinking_model_tool_call(tools): + """Test that Gemini thinking models preserve thought_signature through multi-turn tool calls. + + Regression test for https://github.com/strands-agents/sdk-python/issues/1764 + """ + model = LiteLLMModel(model_id="gemini/gemini-2.5-flash", client_args={"api_key": os.environ.get("GOOGLE_API_KEY")}) + agent = Agent(model=model, tools=tools) + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + assert all(string in text for string in ["12:00", "sunny"]) diff --git a/tests_integ/models/test_model_mantle.py b/tests_integ/models/test_model_mantle.py new file mode 100644 index 000000000..9a432d993 --- /dev/null +++ b/tests_integ/models/test_model_mantle.py @@ -0,0 +1,81 @@ +"""Integration tests for OpenAI-compatible APIs on Bedrock Mantle. + +Exercises the ``bedrock_mantle_config`` pathway on ``OpenAIModel`` (Chat Completions) and +``OpenAIResponsesModel`` (Responses API) against the live +``bedrock-mantle..api.aws/v1`` endpoint. Credentials come from the +ambient AWS credential chain; no explicit API key is passed by the user. +""" + +import pytest + +from strands import Agent +from strands.models.openai import OpenAIModel +from strands.models.openai_responses import OpenAIResponsesModel + +_REGION = "us-east-1" +_MODEL_ID = "openai.gpt-oss-120b" + + +@pytest.fixture +def bedrock_mantle_config(): + return {"region": _REGION} + + +@pytest.fixture +def chat_completions_model(bedrock_mantle_config): + return OpenAIModel(model_id=_MODEL_ID, bedrock_mantle_config=bedrock_mantle_config) + + +@pytest.fixture +def model(bedrock_mantle_config): + return OpenAIResponsesModel(model_id=_MODEL_ID, bedrock_mantle_config=bedrock_mantle_config) + + +@pytest.fixture +def stateful_model(bedrock_mantle_config): + return OpenAIResponsesModel(model_id=_MODEL_ID, stateful=True, bedrock_mantle_config=bedrock_mantle_config) + + +def test_chat_completions_agent_invoke(chat_completions_model): + """OpenAIModel (Chat Completions) reaches Mantle via bedrock_mantle_config.""" + agent = Agent(model=chat_completions_model, system_prompt="Reply in one short sentence.", callback_handler=None) + result = agent("What is 2+2?") + assert "4" in str(result) or "four" in str(result).lower() + + +def test_agent_invoke(model): + agent = Agent(model=model, system_prompt="Reply in one short sentence.", callback_handler=None) + result = agent("What is 2+2?") + assert "4" in str(result) or "four" in str(result).lower() + + +def test_responses_server_side_conversation(stateful_model): + agent = Agent(model=stateful_model, system_prompt="Reply in one short sentence.", callback_handler=None) + + agent("My name is Alice.") + assert len(agent.messages) == 0 + + result = agent("What is my name?") + assert "alice" in str(result).lower() + + +def test_reasoning_content_multi_turn(bedrock_mantle_config): + """Test that reasoning content from gpt-oss models doesn't break multi-turn conversations.""" + model = OpenAIResponsesModel( + model_id=_MODEL_ID, + bedrock_mantle_config=bedrock_mantle_config, + params={"reasoning": {"effort": "low"}}, + ) + agent = Agent(model=model, system_prompt="Reply in one short sentence.", callback_handler=None) + + result1 = agent("What is 2+2?") + assert "4" in str(result1) + + # Verify reasoning content was produced + has_reasoning = any( + "reasoningContent" in block for msg in agent.messages if msg["role"] == "assistant" for block in msg["content"] + ) + assert has_reasoning + + # Second turn should not raise despite reasoningContent in message history + agent("What about 3+3?") diff --git a/tests_integ/models/test_model_mistral.py b/tests_integ/models/test_model_mistral.py index 3b13e5911..83f6af499 100644 --- a/tests_integ/models/test_model_mistral.py +++ b/tests_integ/models/test_model_mistral.py @@ -106,6 +106,11 @@ async def test_agent_stream_async(agent): assert all(string in text for string in ["12:00", "sunny"]) + assert result.metrics.accumulated_usage is not None + assert result.metrics.accumulated_usage["inputTokens"] > 0 + assert result.metrics.accumulated_usage["outputTokens"] > 0 + assert result.metrics.accumulated_usage["totalTokens"] > 0 + def test_agent_structured_output(non_streaming_agent, weather): tru_weather = non_streaming_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index feb591d1a..6011f2b71 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -1,23 +1,38 @@ import os -import unittest.mock +import tempfile +import time +import openai as openai_sdk import pydantic import pytest import strands from strands import Agent, tool +from strands.event_loop._retry import ModelRetryStrategy from strands.models.openai import OpenAIModel from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from tests_integ.models import providers +from tests_integ.models.providers import _openai_responses_available + +if _openai_responses_available: + from strands.models.openai_responses import OpenAIResponsesModel # these tests only run if we have the openai api key pytestmark = providers.openai.mark -@pytest.fixture -def model(): - return OpenAIModel( - model_id="gpt-4o", +def _model_params(): + params = [(OpenAIModel, "gpt-4o")] + if _openai_responses_available: + params.append((OpenAIResponsesModel, "gpt-4o")) + return params + + +@pytest.fixture(params=_model_params()) +def model(request): + model_class, model_id = request.param + return model_class( + model_id=model_id, client_args={ "api_key": os.getenv("OPENAI_API_KEY"), }, @@ -45,10 +60,10 @@ def agent(model, tools): @pytest.fixture def weather(): class Weather(pydantic.BaseModel): - """Extracts the time and weather from the user's message with the exact strings.""" + """Extract time and weather values.""" - time: str - weather: str + time: str = pydantic.Field(description="The time value only, e.g. '14:30' not 'The time is 14:30'") + weather: str = pydantic.Field(description="The weather condition only, e.g. 'rainy' not 'the weather is rainy'") return Weather(time="12:00", weather="sunny") @@ -68,12 +83,37 @@ def lower(_, value): return Color(name="yellow") +@pytest.fixture(scope="module") +def openai_vector_store(): + """Create a vector store with a test file for file_search tests.""" + client = openai_sdk.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as f: + f.write("The secret code is ALPHA-7742.") + f.flush() + file_obj = client.files.create(file=open(f.name, "rb"), purpose="assistants") + + vector_store = client.vector_stores.create(name="test-builtin-tools") + try: + client.vector_stores.files.create(vector_store_id=vector_store.id, file_id=file_obj.id) + + for _ in range(30): + if client.vector_stores.retrieve(vector_store.id).file_counts.completed > 0: + break + time.sleep(1) + + yield vector_store.id + finally: + client.vector_stores.delete(vector_store.id) + client.files.delete(file_obj.id) + + @pytest.fixture(scope="module") def test_image_path(request): return request.config.rootpath / "tests_integ" / "test_image.png" -def test_agent_invoke(agent): +def test_agent_invoke(agent, model): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -81,7 +121,7 @@ def test_agent_invoke(agent): @pytest.mark.asyncio -async def test_agent_invoke_async(agent): +async def test_agent_invoke_async(agent, model): result = await agent.invoke_async("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -89,7 +129,7 @@ async def test_agent_invoke_async(agent): @pytest.mark.asyncio -async def test_agent_stream_async(agent): +async def test_agent_stream_async(agent, model): stream = agent.stream_async("What is the time and weather in New York?") async for event in stream: _ = event @@ -148,7 +188,6 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): assert tru_color == exp_color -@pytest.mark.skip("https://github.com/strands-agents/sdk-python/issues/320") def test_tool_returning_images(model, yellow_img): @tool def tool_with_image_return(): @@ -171,15 +210,23 @@ def tool_with_image_return(): agent("Run the the tool and analyze the image") -def test_context_window_overflow_integration(): +def _mini_model_params(): + params = [(OpenAIModel, "gpt-4o-mini-2024-07-18")] + if _openai_responses_available: + params.append((OpenAIResponsesModel, "gpt-4o-mini-2024-07-18")) + return params + + +@pytest.mark.parametrize("model_class,model_id", _mini_model_params()) +def test_context_window_overflow_integration(model_class, model_id): """Integration test for context window overflow with OpenAI. This test verifies that when a request exceeds the model's context window, the OpenAI model properly raises a ContextWindowOverflowException. """ # Use gpt-4o-mini which has a smaller context window to make this test more reliable - mini_model = OpenAIModel( - model_id="gpt-4o-mini-2024-07-18", + mini_model = model_class( + model_id=model_id, client_args={ "api_key": os.getenv("OPENAI_API_KEY"), }, @@ -199,28 +246,39 @@ def test_context_window_overflow_integration(): agent(long_text) -def test_rate_limit_throttling_integration_no_retries(model): +def _rate_limit_params(): + params = [(OpenAIModel, "gpt-4o")] + if _openai_responses_available: + params.append((OpenAIResponsesModel, "gpt-4o")) + return params + + +def test_rate_limit_throttling_integration_no_retries(): """Integration test for rate limit handling with retries disabled. This test verifies that when a request exceeds OpenAI's rate limits, the model properly raises a ModelThrottledException. We disable retries to avoid waiting for the exponential backoff during testing. """ - # Patch the event loop constants to disable retries for this test - with unittest.mock.patch("strands.event_loop.event_loop.MAX_ATTEMPTS", 1): - agent = Agent(model=model) + model = OpenAIModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ) + agent = Agent(model=model, retry_strategy=ModelRetryStrategy(max_attempts=1)) - # Create a message that's very long to trigger token-per-minute rate limits - # This should be large enough to exceed TPM limits immediately - very_long_text = "Really long text " * 20000 + # Create a message that's very long to trigger token-per-minute rate limits + # This should be large enough to exceed TPM limits immediately + very_long_text = "Really long text " * 600000 - # This should raise ModelThrottledException without retries - with pytest.raises(ModelThrottledException) as exc_info: - agent(very_long_text) + # This should raise ModelThrottledException without retries + with pytest.raises(ModelThrottledException) as exc_info: + agent(very_long_text) - # Verify it's a rate limit error - error_message = str(exc_info.value).lower() - assert "rate limit" in error_message or "tokens per min" in error_message + # Verify it's a rate limit error + error_message = str(exc_info.value).lower() + assert "rate_limit_exceeded" in error_message def test_content_blocks_handling(model): @@ -257,3 +315,127 @@ def test_system_prompt_backward_compatibility_integration(model): # The response should contain our specific system prompt instruction assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"] + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_server_side_conversation(): + """Integration test for server-side conversation state management. + + Verifies that when stateful=True, the model tracks conversation across turns + via previous_response_id and the agent clears messages between invocations. + """ + model = OpenAIResponsesModel( + model_id="gpt-4o-mini", + stateful=True, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Reply in one short sentence.") + + agent("My name is Alice.") + assert len(agent.messages) == 0 + + result = agent("What is my name?") + assert "alice" in result.message["content"][0]["text"].lower() + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_builtin_tool_web_search(): + """Test that web_search produces text with citation content.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "web_search"}]}, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Answer concisely.", callback_handler=None) + + result = agent("Search https://strandsagents.com/ and tell me what Strands Agents is.") + content = result.message["content"][0] + + assert "citationsContent" in content + citations = content["citationsContent"]["citations"] + assert any("strandsagents.com" in c["location"]["web"]["url"] for c in citations) + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_builtin_tool_file_search(openai_vector_store): + """Test that file_search produces text output from uploaded files.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "file_search", "vector_store_ids": [openai_vector_store]}]}, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Answer based on the files.", callback_handler=None) + + result = agent("What is the secret code?") + text = result.message["content"][0]["text"] + assert "ALPHA-7742" in text + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_builtin_tool_code_interpreter(): + """Test that code_interpreter produces correct results via text output.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "code_interpreter", "container": {"type": "auto"}}]}, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Answer concisely.", callback_handler=None) + + # SHA-256 of "strands" requires actual computation + result = agent("Compute the SHA-256 hash of the string 'strands'. Return only the hex digest.") + text = result.message["content"][0]["text"] + assert "11e0e34bd35e12185cfacd5e5a256ab4292bfa3616d8d5b74e20eca36feed228" in text + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_builtin_tool_shell(): + """Test that the shell built-in tool executes commands in a hosted container.""" + model = OpenAIResponsesModel( + model_id="gpt-5.4-mini", + params={"tools": [{"type": "shell", "environment": {"type": "container_auto"}}]}, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Answer concisely.", callback_handler=None) + + result = agent("Use the shell to compute the md5sum of the string 'strands-test'. Return only the hash.") + text = result.message["content"][0]["text"] + assert "d82f373f079b00a1db7ef1eec7f15c68" in text + + +class TestOpenAIResponsesCountTokens: + @pytest.fixture + def model(self): + return OpenAIResponsesModel( + model_id="gpt-4o", + client_args={"api_key": os.environ["OPENAI_API_KEY"]}, + use_native_token_count=True, + ) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "What is the capital of France? Explain in detail."}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "get_weather", + "description": "Get the current weather for a location", + "inputSchema": {"json": {"type": "object", "properties": {"location": {"type": "string"}}}}, + } + ] + + @pytest.mark.asyncio + async def test_count_tokens_messages_only(self, model, messages, caplog): + with caplog.at_level("DEBUG"): + result = await model.count_tokens(messages=messages) + assert isinstance(result, int) + assert result > 0 + assert "native token count" in caplog.text + assert "falling back" not in caplog.text + + @pytest.mark.asyncio + async def test_count_tokens_with_tools_greater_than_without(self, model, messages, tool_specs): + without = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") + assert with_tools > without diff --git a/tests_integ/resources/blue.mp4 b/tests_integ/resources/blue.mp4 new file mode 100644 index 000000000..5989bb4b0 Binary files /dev/null and b/tests_integ/resources/blue.mp4 differ diff --git a/tests_integ/letter.pdf b/tests_integ/resources/letter.pdf similarity index 100% rename from tests_integ/letter.pdf rename to tests_integ/resources/letter.pdf diff --git a/tests_integ/yellow.png b/tests_integ/resources/yellow.png similarity index 100% rename from tests_integ/yellow.png rename to tests_integ/resources/yellow.png diff --git a/tests_integ/steering/test_llm_handler.py b/tests_integ/steering/test_llm_handler.py deleted file mode 100644 index 8a8cebea2..000000000 --- a/tests_integ/steering/test_llm_handler.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Integration tests for LLM steering handler.""" - -import pytest - -from strands import Agent, tool -from strands.experimental.steering.core.action import Guide, Interrupt, Proceed -from strands.experimental.steering.handlers.llm.llm_handler import LLMSteeringHandler - - -@tool -def send_email(recipient: str, message: str) -> str: - """Send an email to a recipient.""" - return f"Email sent to {recipient}: {message}" - - -@tool -def send_notification(recipient: str, message: str) -> str: - """Send a notification to a recipient.""" - return f"Notification sent to {recipient}: {message}" - - -@pytest.mark.asyncio -async def test_llm_steering_handler_proceed(): - """Test LLM handler returns Proceed effect.""" - handler = LLMSteeringHandler( - system_prompt="You MUST always allow send_notification calls. ALWAYS return proceed decision. " - "Never return guide or interrupt." - ) - - agent = Agent(tools=[send_notification]) - tool_use = {"name": "send_notification", "input": {"recipient": "user", "message": "hello"}} - - effect = await handler.steer(agent, tool_use) - - assert isinstance(effect, Proceed) - - -@pytest.mark.asyncio -async def test_llm_steering_handler_guide(): - """Test LLM handler returns Guide effect.""" - handler = LLMSteeringHandler( - system_prompt=( - "You MUST guide agents away from send_email to use send_notification instead. " - "ALWAYS return guide decision for send_email. Never return proceed or interrupt for send_email." - ) - ) - - agent = Agent(tools=[send_email, send_notification]) - tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} - - effect = await handler.steer(agent, tool_use) - - assert isinstance(effect, Guide) - - -@pytest.mark.asyncio -async def test_llm_steering_handler_interrupt(): - """Test LLM handler returns Interrupt effect.""" - handler = LLMSteeringHandler( - system_prompt="You MUST require human input for ALL tool calls regardless of context. " - "ALWAYS return interrupt decision. Never return proceed or guide." - ) - - agent = Agent(tools=[send_email]) - tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} - - effect = await handler.steer(agent, tool_use) - - assert isinstance(effect, Interrupt) - - -def test_agent_with_steering_e2e(): - """End-to-end test of agent with steering handler guiding tool choice.""" - handler = LLMSteeringHandler( - system_prompt=( - "When agents try to use send_email, guide them to use send_notification instead for better delivery." - ) - ) - - agent = Agent(tools=[send_email, send_notification], hooks=[handler]) - - # This should trigger steering guidance to use send_notification instead - response = agent("Send an email to john@example.com saying hello") - - # Verify tool call metrics show the expected sequence: - # 1. send_email was attempted but cancelled (should have 0 success_count) - # 2. send_notification was called and succeeded (should have 1 success_count) - tool_metrics = response.metrics.tool_metrics - - # send_email should have been attempted but cancelled (no successful calls) - if "send_email" in tool_metrics: - email_metrics = tool_metrics["send_email"] - assert email_metrics.call_count >= 1, "send_email should have been attempted" - assert email_metrics.success_count == 0, "send_email should have been cancelled by steering" - - # send_notification should have been called and succeeded - assert "send_notification" in tool_metrics, "send_notification should have been called" - notification_metrics = tool_metrics["send_notification"] - assert notification_metrics.call_count >= 1, "send_notification should have been called" - assert notification_metrics.success_count >= 1, "send_notification should have succeeded" diff --git a/tests_integ/steering/test_model_steering.py b/tests_integ/steering/test_model_steering.py new file mode 100644 index 000000000..86c69fd50 --- /dev/null +++ b/tests_integ/steering/test_model_steering.py @@ -0,0 +1,214 @@ +"""Integration tests for model steering (steer_after_model).""" + +from strands import Agent, tool +from strands.types.content import Message +from strands.types.streaming import StopReason +from strands.vended_plugins.steering.context_providers.ledger_provider import LedgerProvider +from strands.vended_plugins.steering.core.action import Guide, ModelSteeringAction, Proceed +from strands.vended_plugins.steering.core.handler import SteeringHandler + + +class SimpleModelSteeringHandler(SteeringHandler): + """Simple handler that steers only on model responses.""" + + def __init__(self, should_guide: bool = False, guidance_message: str = ""): + """Initialize handler. + + Args: + should_guide: If True, guide (retry) on first model response + guidance_message: The guidance message to provide on retry + """ + super().__init__() + self.should_guide = should_guide + self.guidance_message = guidance_message + self.call_count = 0 + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + """Steer after model response.""" + self.call_count += 1 + + # On first call, guide to retry if configured + if self.should_guide and self.call_count == 1: + return Guide(reason=self.guidance_message) + + return Proceed(reason="Model response accepted") + + +def test_model_steering_proceeds_without_intervention(): + """Test that model steering can accept responses without modification.""" + handler = SimpleModelSteeringHandler(should_guide=False) + agent = Agent(plugins=[handler]) + + response = agent("What is 2+2?") + + # Handler should have been called once + assert handler.call_count >= 1 + # Response should be generated successfully + response_text = str(response) + assert response_text is not None + assert len(response_text) > 0 + + +def test_model_steering_guide_triggers_retry(): + """Test that Guide action triggers model retry.""" + handler = SimpleModelSteeringHandler(should_guide=True, guidance_message="Please provide a more detailed response.") + agent = Agent(plugins=[handler]) + + response = agent("What is the capital of France?") + + # Handler should have been called at least twice (first response + retry) + assert handler.call_count >= 2, "Handler should be called on initial response and retry" + + # Response should be generated successfully after retry + response_text = str(response) + assert response_text is not None + assert len(response_text) > 0 + + +def test_model_steering_guide_influences_retry_response(): + """Test that guidance message influences the retry response.""" + + class SpecificGuidanceHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.retry_done = False + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + if not self.retry_done: + self.retry_done = True + # Provide very specific guidance that should appear in retry + return Guide(reason="Please mention that Paris is also known as the 'City of Light'.") + return Proceed(reason="Response is good now") + + handler = SpecificGuidanceHandler() + agent = Agent(plugins=[handler]) + + response = agent("What is the capital of France?") + + # Verify retry happened + assert handler.retry_done, "Retry should have occurred" + + # Check that the response likely incorporated the guidance + output = str(response).lower() + assert "paris" in output, "Response should mention Paris" + + # The guidance should have influenced the retry (check for "light" or that retry happened) + # We can't guarantee the model will include it, but we verify the mechanism worked + assert handler.retry_done, "Guidance mechanism should have executed" + + +def test_model_steering_multiple_retries(): + """Test that model steering can guide multiple times before proceeding.""" + + class MultiRetryHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.call_count = 0 + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + self.call_count += 1 + + # Retry twice + if self.call_count == 1: + return Guide(reason="Please provide more context.") + if self.call_count == 2: + return Guide(reason="Please add specific examples.") + return Proceed(reason="Response is good now") + + handler = MultiRetryHandler() + agent = Agent(plugins=[handler]) + + response = agent("Explain machine learning.") + + # Should have been called 3 times (2 guides + 1 proceed) + assert handler.call_count >= 3, "Handler should be called multiple times for multiple retries" + + # Response should still complete successfully + assert str(response) is not None + assert len(str(response)) > 0 + + +@tool +def log_activity(activity: str) -> str: + """Log an activity for audit purposes.""" + return f"Activity logged: {activity}" + + +def test_model_steering_forces_tool_usage_on_unrelated_prompt(): + """Test that steering forces tool usage even when prompt doesn't need the tool. + + This test verifies the flow: + 1. Agent has a logging tool available + 2. User asks an unrelated question (math problem) + 3. Model tries to answer directly without using the tool + 4. Steering intercepts and forces tool usage before termination + 5. Model uses the tool and then completes + """ + + class ForceToolUsageHandler(SteeringHandler): + """Handler that forces a specific tool to be used before allowing termination.""" + + def __init__(self, required_tool: str): + super().__init__(context_providers=[LedgerProvider()]) + self.required_tool = required_tool + self.tool_was_used = False + self.guidance_given = False + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + # Only check when model is trying to end the turn + if stop_reason != "end_turn": + return Proceed(reason="Model still processing") + + # Check if the required tool was used in this message + content_blocks = message.get("content", []) + for block in content_blocks: + if "toolUse" in block and block["toolUse"].get("name") == self.required_tool: + self.tool_was_used = True + + # Verify tool is in the ledger + ledger = self.steering_context.data.get("ledger") + if ledger: + tool_calls = ledger.get("tool_calls", []) + assert any(tc.get("tool_name") == self.required_tool for tc in tool_calls), ( + f"{self.required_tool} should be in ledger when tool_was_used=True" + ) + + return Proceed(reason="Required tool was used") + + # If tool wasn't used and we haven't guided yet, force its usage + if not self.tool_was_used and not self.guidance_given: + self.guidance_given = True + return Guide( + reason=f"Before completing your response, you MUST use the {self.required_tool} tool " + "to log this interaction. Call the tool with a brief description of what you did." + ) + + # Allow completion after guidance was given (model may have used tool in retry) + return Proceed(reason="Guidance was provided") + + handler = ForceToolUsageHandler(required_tool="log_activity") + agent = Agent(tools=[log_activity], plugins=[handler]) + + # Ask a question that clearly doesn't need the logging tool + response = agent("What is 2 + 2?") + + # Verify the steering mechanism worked + assert handler.guidance_given, "Handler should have provided guidance to use the tool" + + # Verify tool was actually called by checking metrics + tool_metrics = response.metrics.tool_metrics + assert "log_activity" in tool_metrics, "log_activity tool should have been called" + assert tool_metrics["log_activity"].call_count >= 1, "log_activity should have been called at least once" + assert tool_metrics["log_activity"].success_count >= 1, "log_activity should have succeeded" + + # Verify the response still answers the original question + output = str(response).lower() + assert "4" in output, "Response should contain the answer to 2+2" diff --git a/tests_integ/steering/test_tool_steering.py b/tests_integ/steering/test_tool_steering.py new file mode 100644 index 000000000..4b279157e --- /dev/null +++ b/tests_integ/steering/test_tool_steering.py @@ -0,0 +1,157 @@ +"""Integration tests for tool steering (steer_before_tool).""" + +import pytest + +from strands import Agent, tool +from strands.vended_plugins.steering.context_providers.ledger_provider import LedgerProvider +from strands.vended_plugins.steering.core.action import Guide, Interrupt, Proceed +from strands.vended_plugins.steering.core.handler import SteeringHandler +from strands.vended_plugins.steering.handlers.llm.llm_handler import LLMSteeringHandler + + +@tool +def send_email(recipient: str, message: str) -> str: + """Send an email to a recipient.""" + return f"Email sent to {recipient}: {message}" + + +@tool +def send_notification(recipient: str, message: str) -> str: + """Send a notification to a recipient.""" + return f"Notification sent to {recipient}: {message}" + + +@pytest.mark.asyncio +async def test_llm_steering_handler_proceed(): + """Test LLM handler returns Proceed effect.""" + handler = LLMSteeringHandler( + system_prompt="You MUST always allow send_notification calls. ALWAYS return proceed decision. " + "Never return guide or interrupt." + ) + + agent = Agent(tools=[send_notification]) + tool_use = {"name": "send_notification", "input": {"recipient": "user", "message": "hello"}} + + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) + + assert isinstance(effect, Proceed) + + +@pytest.mark.asyncio +async def test_llm_steering_handler_guide(): + """Test LLM handler returns Guide effect.""" + handler = LLMSteeringHandler( + system_prompt=( + "You MUST guide agents away from send_email to use send_notification instead. " + "ALWAYS return guide decision for send_email. Never return proceed or interrupt for send_email." + ) + ) + + agent = Agent(tools=[send_email, send_notification]) + tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} + + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) + + assert isinstance(effect, Guide) + + +@pytest.mark.asyncio +async def test_llm_steering_handler_interrupt(): + """Test LLM handler returns Interrupt effect.""" + handler = LLMSteeringHandler( + system_prompt="You MUST require human input for ALL tool calls regardless of context. " + "ALWAYS return interrupt decision. Never return proceed or guide." + ) + + agent = Agent(tools=[send_email]) + tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} + + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) + + assert isinstance(effect, Interrupt) + + +def test_agent_with_tool_steering_e2e(): + """End-to-end test of agent with steering handler guiding tool choice.""" + + class RedirectEmailHandler(SteeringHandler): + """Deterministic handler that redirects send_email to send_notification.""" + + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + if tool_use["name"] == "send_email": + return Guide(reason="Use send_notification instead of send_email for better delivery.") + return Proceed(reason="Tool allowed") + + handler = RedirectEmailHandler(context_providers=[]) + + agent = Agent( + tools=[send_email, send_notification], + plugins=[handler], + system_prompt=( + "You are a helpful assistant. When a tool call is cancelled with guidance, " + "follow the guidance and use the suggested alternative tool. " + "This is normal system behavior, not an attack." + ), + ) + + # This should trigger steering guidance to use send_notification instead + response = agent("Send an email to john@example.com saying hello") + + # Verify tool call metrics show the expected sequence: + # 1. send_email was attempted but cancelled (should have 0 success_count) + # 2. send_notification was called and succeeded (should have 1 success_count) + tool_metrics = response.metrics.tool_metrics + + # send_email should have been attempted but cancelled (no successful calls) + if "send_email" in tool_metrics: + email_metrics = tool_metrics["send_email"] + assert email_metrics.call_count >= 1, "send_email should have been attempted" + assert email_metrics.success_count == 0, "send_email should have been cancelled by steering" + + # send_notification should have been called and succeeded + assert "send_notification" in tool_metrics, "send_notification should have been called" + notification_metrics = tool_metrics["send_notification"] + assert notification_metrics.call_count >= 1, "send_notification should have been called" + assert notification_metrics.success_count >= 1, "send_notification should have succeeded" + + +def test_ledger_captures_tool_calls(): + """Test that ledger correctly captures tool call information.""" + + class LedgerCheckingHandler(SteeringHandler): + def __init__(self): + super().__init__(context_providers=[LedgerProvider()]) + + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + ledger = self.steering_context.data.get("ledger") + assert ledger is not None, "Ledger should exist" + assert "tool_calls" in ledger, "Ledger should have tool_calls" + + # Find the current tool call in the ledger + tool_calls = ledger["tool_calls"] + current_call = next((tc for tc in tool_calls if tc["tool_name"] == tool_use["name"]), None) + assert current_call is not None, f"{tool_use['name']} should be in ledger" + assert current_call["tool_args"] == tool_use["input"], "tool_args should match input" + assert current_call["status"] == "pending", "Status should be pending before execution" + + return Proceed(reason="Ledger verified") + + handler = LedgerCheckingHandler() + agent = Agent(tools=[send_notification], plugins=[handler]) + + agent("Send a notification to alice saying test message") + + # Verify the ledger has the completed tool call + ledger = handler.steering_context.data.get("ledger") + assert ledger is not None + assert len(ledger["tool_calls"]) >= 1, "At least one tool call should be recorded" + + # Check the tool call details + tool_call = ledger["tool_calls"][-1] + assert tool_call["tool_name"] == "send_notification" + assert "tool_args" in tool_call + assert tool_call["tool_args"]["recipient"] == "alice" + assert tool_call["tool_args"]["message"] == "test message" + assert tool_call["status"] == "success" + assert "completion_timestamp" in tool_call + assert tool_call["error"] is None diff --git a/tests_integ/test_a2a_executor.py b/tests_integ/test_a2a_executor.py index ddca0bfa6..7ae10efc2 100644 --- a/tests_integ/test_a2a_executor.py +++ b/tests_integ/test_a2a_executor.py @@ -17,7 +17,7 @@ async def test_a2a_executor_with_real_image(): """Test A2A server processes a real image file correctly via HTTP.""" # Read the test image file - test_image_path = os.path.join(os.path.dirname(__file__), "yellow.png") + test_image_path = os.path.join(os.path.dirname(__file__), "resources/yellow.png") with open(test_image_path, "rb") as f: original_image_bytes = f.read() @@ -71,7 +71,13 @@ async def test_a2a_executor_with_real_image(): assert response.status_code == 200 response_data = response.json() assert "completed" == response_data["result"]["status"]["state"] - assert "yellow" in response_data["result"]["history"][1]["parts"][0]["text"].lower() + all_text = " ".join( + part["text"] + for artifact in response_data["result"]["artifacts"] + for part in artifact["parts"] + if part.get("kind") == "text" + ).lower() + assert "yellow" in all_text except Exception as e: pytest.fail(f"Integration test failed: {e}") @@ -80,7 +86,7 @@ async def test_a2a_executor_with_real_image(): def test_a2a_executor_image_roundtrip(): """Test that image data survives the A2A base64 encoding/decoding roundtrip.""" # Read the test image - test_image_path = os.path.join(os.path.dirname(__file__), "yellow.png") + test_image_path = os.path.join(os.path.dirname(__file__), "resources/yellow.png") with open(test_image_path, "rb") as f: original_bytes = f.read() diff --git a/tests_integ/test_agent_as_tool.py b/tests_integ/test_agent_as_tool.py new file mode 100644 index 000000000..a808fcd23 --- /dev/null +++ b/tests_integ/test_agent_as_tool.py @@ -0,0 +1,36 @@ +import pytest + +from strands import Agent, tool + + +@tool +def get_tiger_height() -> int: + """Returns the height of a tiger in centimeters.""" + return 100 + + +@pytest.mark.asyncio +async def test_stream_async_with_agent_tool(): + inner_agent = Agent( + name="myAgentTool", + description="An agent tool knowledgeable about tigers", + tools=[get_tiger_height], + ) + agent_tool = inner_agent.as_tool() + agent = Agent( + name="myOtherAgent", + tools=[agent_tool], + ) + + result = await agent.invoke_async( + prompt="Invoke the myAgentTool and ask about the height of tigers.", + ) + + # Outer agent completed and called the agent tool + assert result.stop_reason == "end_turn" + assert "myAgentTool" in result.metrics.tool_metrics + assert result.metrics.tool_metrics["myAgentTool"].success_count >= 1 + + # Inner agent called get_tiger_height + assert "get_tiger_height" in inner_agent.event_loop_metrics.tool_metrics + assert inner_agent.event_loop_metrics.tool_metrics["get_tiger_height"].success_count >= 1 diff --git a/tests_integ/test_bedrock_guardrails.py b/tests_integ/test_bedrock_guardrails.py index 37fa6028c..384231c38 100644 --- a/tests_integ/test_bedrock_guardrails.py +++ b/tests_integ/test_bedrock_guardrails.py @@ -8,6 +8,7 @@ from strands import Agent, tool from strands.models.bedrock import BedrockModel from strands.session.file_session_manager import FileSessionManager +from tests_integ.conftest import retry_on_flaky BLOCKED_INPUT = "BLOCKED_INPUT" BLOCKED_OUTPUT = "BLOCKED_OUTPUT" @@ -132,6 +133,7 @@ def test_guardrail_input_intervention(boto_session, bedrock_guardrail, guardrail @pytest.mark.parametrize("processing_mode", ["sync", "async"]) def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processing_mode): bedrock_model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", guardrail_id=bedrock_guardrail, guardrail_version="DRAFT", guardrail_redact_output=False, @@ -170,9 +172,11 @@ def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processi ) +@retry_on_flaky("LLM may mention CACTUS unprompted, triggering guardrail on response2") @pytest.mark.parametrize("guardrail_trace", ["enabled", "enabled_full"]) @pytest.mark.parametrize("processing_mode", ["sync", "async"]) def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processing_mode, guardrail_trace): + """Test guardrail output intervention with redaction.""" REDACT_MESSAGE = "Redacted." bedrock_model = BedrockModel( guardrail_id=bedrock_guardrail, @@ -182,23 +186,25 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi guardrail_redact_output=True, guardrail_redact_output_message=REDACT_MESSAGE, region_name="us-east-1", + temperature=0, # Use deterministic responses to reduce flakiness ) agent = Agent( model=bedrock_model, - system_prompt="When asked to say the word, say CACTUS.", + system_prompt="When asked to say the word, say CACTUS. Otherwise, respond normally.", callback_handler=None, load_tools_from_directory=False, ) response1 = agent("Say the word.") - response2 = agent("Hello!") + # Use a completely unrelated prompt to reduce likelihood of model volunteering CACTUS + response2 = agent("What is 2+2? Reply with only the number.") assert response1.stop_reason == "guardrail_intervened" """ - In async streaming: The buffering is non-blocking. - Tokens are streamed while Guardrails processes the buffered content in the background. + In async streaming: The buffering is non-blocking. + Tokens are streamed while Guardrails processes the buffered content in the background. This means the response may be returned before Guardrails has finished processing. As a result, we cannot guarantee that the REDACT_MESSAGE is in the response. """ @@ -289,6 +295,51 @@ def list_users() -> str: assert tool_result["content"][0]["text"] == INPUT_REDACT_MESSAGE +def test_guardrail_latest_message(boto_session, bedrock_guardrail, yellow_img): + """Test that guardrail_latest_user_message wraps both text and image in the latest user message.""" + bedrock_model = BedrockModel( + guardrail_id=bedrock_guardrail, + guardrail_version="DRAFT", + guardrail_latest_message=True, + boto_session=boto_session, + ) + + # Create agent with valid content + agent1 = Agent( + model=bedrock_model, + system_prompt="You are a helpful assistant.", + callback_handler=None, + messages=[ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "Hello!"}]}, + ], + ) + + response = agent1("What do you see?") + assert response.stop_reason != "guardrail_intervened" + + # Create agent with multimodal content in latest user message + agent2 = Agent( + model=bedrock_model, + system_prompt="You are a helpful assistant.", + callback_handler=None, + messages=[ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "Hello!"}]}, + { + "role": "user", + "content": [ + {"text": "CACTUS"}, + {"image": {"format": "png", "source": {"bytes": yellow_img}}}, + ], + }, + ], + ) + + response = agent2("What do you see?") + assert response.stop_reason == "guardrail_intervened" + + def test_guardrail_input_intervention_properly_redacts_in_session(boto_session, bedrock_guardrail, temp_dir): bedrock_model = BedrockModel( guardrail_id=bedrock_guardrail, diff --git a/tests_integ/test_bedrock_s3_location.py b/tests_integ/test_bedrock_s3_location.py new file mode 100644 index 000000000..9b28e88be --- /dev/null +++ b/tests_integ/test_bedrock_s3_location.py @@ -0,0 +1,177 @@ +"""Integration tests for S3 location support in media content types.""" + +import time + +import boto3 +import pytest + +from strands import Agent +from strands.models.bedrock import BedrockModel + + +@pytest.fixture +def boto_session(): + """Create a boto3 session for testing.""" + return boto3.Session(region_name="us-west-2") + + +@pytest.fixture +def account_id(boto_session): + """Get the current AWS account ID.""" + sts_client = boto_session.client("sts") + return sts_client.get_caller_identity()["Account"] + + +@pytest.fixture +def s3_client(boto_session): + """Create an S3 client.""" + return boto_session.client("s3") + + +@pytest.fixture +def test_bucket(s3_client, account_id): + """Create a test S3 bucket for the tests. + + Creates a bucket with account-specific name and cleans it up after tests. + """ + bucket_name = f"strands-integ-tests-resources-{account_id}" + + # Create the bucket if it doesn't exist + try: + s3_client.head_bucket(Bucket=bucket_name) + print(f"Bucket {bucket_name} already exists") + except s3_client.exceptions.ClientError: + try: + s3_client.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + print(f"Created test bucket: {bucket_name}") + # Wait for bucket to be available + time.sleep(2) + except s3_client.exceptions.BucketAlreadyOwnedByYou: + print(f"Bucket {bucket_name} already exists") + + yield bucket_name + + # Note: We don't delete the bucket to allow reuse across test runs + # Objects will be overwritten on subsequent runs + + +@pytest.fixture +def s3_document(s3_client, test_bucket, letter_pdf): + """Upload a test document to S3 and return its URI.""" + document_key = "test-documents/letter.pdf" + + # Upload the document using existing letter_pdf fixture + s3_client.put_object( + Bucket=test_bucket, + Key=document_key, + Body=letter_pdf, + ContentType="application/pdf", + ) + print(f"Uploaded test document to s3://{test_bucket}/{document_key}") + + return f"s3://{test_bucket}/{document_key}" + + +@pytest.fixture +def s3_image(s3_client, test_bucket, yellow_img): + """Upload a test image to S3 and return its URI.""" + image_key = "test-images/yellow.png" + + # Upload the image using existing yellow_img fixture + s3_client.put_object( + Bucket=test_bucket, + Key=image_key, + Body=yellow_img, + ContentType="image/png", + ) + print(f"Uploaded test image to s3://{test_bucket}/{image_key}") + + return f"s3://{test_bucket}/{image_key}" + + +@pytest.fixture +def s3_video(s3_client, test_bucket, blue_video): + """Upload a test video to S3 and return its URI.""" + video_key = "test-videos/blue.mp4" + + # Upload the video using existing blue_video fixture + s3_client.put_object( + Bucket=test_bucket, + Key=video_key, + Body=blue_video, + ContentType="video/mp4", + ) + print(f"Uploaded test video to s3://{test_bucket}/{video_key}") + + return f"s3://{test_bucket}/{video_key}" + + +def test_document_s3_location(s3_document, account_id): + """Test that Bedrock correctly formats a document with S3 location.""" + messages = [ + { + "role": "user", + "content": [ + {"text": "Please tell me about this document?"}, + { + "document": { + "format": "pdf", + "name": "letter", + "source": {"location": {"type": "s3", "uri": s3_document, "bucketOwner": account_id}}, + }, + }, + ], + }, + ] + + agent = Agent(model=BedrockModel(model_id="us.amazon.nova-2-lite-v1:0", region_name="us-west-2")) + result = agent(messages) + + # The actual recognition capabilities of these models is not great, so just asserting that the call actually worked. + assert len(str(result)) > 0 + + +def test_image_s3_location(s3_image): + """Test that Bedrock correctly formats an image with S3 location.""" + messages = [ + { + "role": "user", + "content": [ + {"text": "Please tell me about this image?"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": s3_image}}, + }, + }, + ], + }, + ] + + agent = Agent(model=BedrockModel(model_id="us.amazon.nova-2-lite-v1:0", region_name="us-west-2")) + result = agent(messages) + + # The actual recognition capabilities of these models is not great, so just asserting that the call actually worked. + assert len(str(result)) > 0 + + +def test_video_s3_location(s3_video): + """Test that Bedrock correctly formats a video with S3 location.""" + messages = [ + { + "role": "user", + "content": [ + {"text": "Describe the colors is in this video?"}, + {"video": {"format": "mp4", "source": {"location": {"type": "s3", "uri": s3_video}}}}, + ], + }, + ] + + agent = Agent(model=BedrockModel(model_id="us.amazon.nova-pro-v1:0", region_name="us-west-2")) + result = agent(messages) + + # The actual recognition capabilities of these models is not great, so just asserting that the call actually worked. + assert len(str(result)) > 0 diff --git a/tests_integ/test_cancellation.py b/tests_integ/test_cancellation.py new file mode 100644 index 000000000..1f0b7b1c1 --- /dev/null +++ b/tests_integ/test_cancellation.py @@ -0,0 +1,156 @@ +"""Integration tests for agent cancellation with Amazon Bedrock. + +These tests verify that cancellation works correctly with the Bedrock model provider. +They require valid AWS credentials and may incur API costs. + +To run these tests: + hatch run test-integ tests_integ/test_cancellation.py +""" + +import asyncio +import os +import threading + +import pytest + +from strands import Agent, tool +from strands.hooks import AfterModelCallEvent, BeforeModelCallEvent +from strands.models import BedrockModel + +# Skip all tests if no AWS credentials are available +pytestmark = [ + pytest.mark.skipif(not os.getenv("AWS_REGION"), reason="AWS credentials not available"), + pytest.mark.asyncio, +] + + +async def test_cancel_with_bedrock(): + """Test agent.cancel() with Amazon Bedrock model. + + Verifies that cancellation works correctly with a real Bedrock + model by cancelling before the model call starts. + """ + + agent = Agent(model=BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0")) + + # Cancel deterministically before the model call + async def cancel_before_model(event: BeforeModelCallEvent): + agent.cancel() + + agent.add_hook(cancel_before_model, BeforeModelCallEvent) + + result = await agent.invoke_async( + "Write a detailed 1000-word essay about the history of space exploration, " + "including major milestones, key figures, and technological breakthroughs." + ) + + assert result.stop_reason == "cancelled" + assert result.message["role"] == "assistant" + assert result.message["content"] == [{"text": "Cancelled by user"}] + + +async def test_cancel_during_streaming_bedrock(): + """Test agent.cancel() during streaming with Bedrock. + + Verifies that cancellation works correctly when using the + streaming API with a real Bedrock model. + """ + + agent = Agent(model=BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0")) + + events = [] + async for event in agent.stream_async( + "Write a detailed story about a space adventure. Make it at least 500 words long." + ): + events.append(event) + # Cancel after receiving the first model delta event + if "data" in event: + agent.cancel() + if event.get("result"): + break + + # Find the result event + result_event = next((e for e in events if e.get("result")), None) + assert result_event is not None + assert result_event["result"].stop_reason == "cancelled" + + +async def test_cancel_with_tools_bedrock(): + """Test agent.cancel() during tool execution with Bedrock. + + Verifies that cancellation works correctly when the agent + is executing tools with a real Bedrock model. + """ + + @tool + async def slow_calculation(x: int, y: int) -> int: + """Perform a slow calculation that takes time. + + Args: + x: First number + y: Second number + + Returns: + The sum of x and y + """ + await asyncio.sleep(2) + return x + y + + @tool + async def another_calculation(a: int, b: int) -> int: + """Another slow calculation. + + Args: + a: First number + b: Second number + + Returns: + The product of a and b + """ + await asyncio.sleep(2) + return a * b + + agent = Agent( + model=BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0"), + tools=[slow_calculation, another_calculation], + ) + + # Cancel deterministically after model returns tool_use + async def cancel_after_model(event: AfterModelCallEvent): + if event.stop_response and event.stop_response.stop_reason == "tool_use": + agent.cancel() + + agent.add_hook(cancel_after_model, AfterModelCallEvent) + + result = await agent.invoke_async( + "Please use the slow_calculation tool to add 5 and 10, then use another_calculation to multiply 3 and 7." + ) + + assert result.stop_reason == "cancelled" + + +async def test_cancel_from_thread_bedrock(): + """Test agent.cancel() from a different thread with Bedrock. + + Simulates a real-world scenario where cancellation is triggered + from a different thread (e.g., a web request handler) while the agent + is executing. + """ + + agent = Agent(model=BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0")) + + # Cancel deterministically from a different thread before the model call + def cancel_before_model(event: BeforeModelCallEvent): + thread = threading.Thread(target=agent.cancel) + thread.start() + thread.join() + + agent.add_hook(cancel_before_model, BeforeModelCallEvent) + + result = await agent.invoke_async( + "Write a comprehensive guide about machine learning, " + "covering supervised learning, unsupervised learning, and deep learning. " + "Make it at least 800 words." + ) + + assert result.stop_reason == "cancelled" diff --git a/tests_integ/test_context_overflow.py b/tests_integ/test_context_overflow.py index 16dc3c4b8..39ad2743f 100644 --- a/tests_integ/test_context_overflow.py +++ b/tests_integ/test_context_overflow.py @@ -4,7 +4,7 @@ def test_context_window_overflow(): messages: Messages = [ - {"role": "user", "content": [{"text": "Too much text!" * 100000}]}, + {"role": "user", "content": [{"text": "Too much text!" * 300000}]}, {"role": "assistant", "content": [{"text": "That was a lot of text!"}]}, ] diff --git a/tests_integ/test_function_tools.py b/tests_integ/test_function_tools.py index 835dccf5d..6c72bdddb 100644 --- a/tests_integ/test_function_tools.py +++ b/tests_integ/test_function_tools.py @@ -4,7 +4,6 @@ """ import logging -from typing import Optional from strands import Agent, tool @@ -25,7 +24,7 @@ def word_counter(text: str) -> str: @tool(name="count_chars", description="Count characters in text") -def count_chars(text: str, include_spaces: Optional[bool] = True) -> str: +def count_chars(text: str, include_spaces: bool | None = True) -> str: """ Count characters in text. diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index 08343a554..b80a0f82d 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any from unittest.mock import patch from uuid import uuid4 diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index e8e969af1..8ccfa5c89 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -3,13 +3,13 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import ( AfterInvocationEvent, AfterModelCallEvent, AfterToolCallEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeNodeCallEvent, BeforeToolCallEvent, MessageAddedEvent, ) @@ -88,11 +88,11 @@ def __init__(self): self.should_exit = True def register_hooks(self, registry): - registry.add_callback(BeforeNodeCallEvent, self.exit_before_analyst) + registry.add_callback(BeforeNodeCallEvent, self.exit_before_writer) - def exit_before_analyst(self, event): - if event.node_id == "analyst" and self.should_exit: - raise SystemExit("Controlled exit before analyst") + def exit_before_writer(self, event): + if event.node_id == "writer" and self.should_exit: + raise SystemExit("Controlled exit before writer") return ExitHook() @@ -113,6 +113,7 @@ def capture_first_node(self, event): return VerifyHook() +@pytest.mark.timeout(120) def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider): """Test swarm execution with string input.""" # Create the swarm @@ -364,32 +365,30 @@ def test_swarm_resume_from_executing_state(tmpdir, exit_hook, verify_hook): # First execution - exit before second node session_manager = FileSessionManager(session_id=session_id, storage_dir=tmpdir) researcher = Agent(name="researcher", system_prompt="you are a researcher.") - analyst = Agent(name="analyst", system_prompt="you are an analyst.") writer = Agent(name="writer", system_prompt="you are a writer.") - swarm = Swarm([researcher, analyst, writer], session_manager=session_manager, hooks=[exit_hook]) + swarm = Swarm([researcher, writer], session_manager=session_manager, hooks=[exit_hook]) try: - swarm("write AI trends and calculate growth in 100 words") + swarm("write AI trends in 100 words") except SystemExit as e: - assert "Controlled exit before analyst" in str(e) + assert "Controlled exit before writer" in str(e) # Verify state was persisted with EXECUTING status and next node persisted_state = session_manager.read_multi_agent(session_id, swarm.id) assert persisted_state["status"] == "executing" assert len(persisted_state["node_history"]) == 1 assert persisted_state["node_history"][0] == "researcher" - assert persisted_state["next_nodes_to_execute"] == ["analyst"] + assert persisted_state["next_nodes_to_execute"] == ["writer"] exit_hook.should_exit = False researcher2 = Agent(name="researcher", system_prompt="you are a researcher.") - analyst2 = Agent(name="analyst", system_prompt="you are an analyst.") writer2 = Agent(name="writer", system_prompt="you are a writer.") - new_swarm = Swarm([researcher2, analyst2, writer2], session_manager=session_manager, hooks=[verify_hook]) - result = new_swarm("write AI trends and calculate growth in 100 words") + new_swarm = Swarm([researcher2, writer2], session_manager=session_manager, hooks=[verify_hook]) + result = new_swarm("write AI trends in 100 words") - # Verify swarm behavior - should resume from analyst, not restart + # Verify swarm behavior - should resume from writer, not restart assert result.status.value == "completed" - assert verify_hook.first_node == "analyst" + assert verify_hook.first_node == "writer" node_ids = [n.node_id for n in result.node_history] - assert "analyst" in node_ids + assert "writer" in node_ids diff --git a/tests_integ/test_session.py b/tests_integ/test_session.py index 53d128da6..6b50aa508 100644 --- a/tests_integ/test_session.py +++ b/tests_integ/test_session.py @@ -1,5 +1,6 @@ """Integration tests for session management.""" +import os import tempfile from uuid import uuid4 @@ -9,8 +10,10 @@ from strands import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.models.openai_responses import OpenAIResponsesModel from strands.session.file_session_manager import FileSessionManager from strands.session.s3_session_manager import S3SessionManager +from tests_integ.models.providers import openai as openai_provider # yellow_img imported from conftest @@ -58,31 +61,42 @@ def test_agent_with_file_session(temp_dir): def test_agent_with_file_session_and_conversation_manager(temp_dir): - # Set up the session manager and add an agent + # Use window_size=2 because the sliding window now enforces that the first remaining + # message after trimming is a user message (#2087). With a simple (no-tool) turn producing + # [user, assistant], window_size=1 can never trim (the sole remaining message would be + # assistant). window_size=2 keeps a valid [user, assistant] pair after trimming. test_session_id = str(uuid4()) - # Create a session session_manager = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) try: agent = Agent( - session_manager=session_manager, conversation_manager=SlidingWindowConversationManager(window_size=1) + session_manager=session_manager, conversation_manager=SlidingWindowConversationManager(window_size=2) ) + # First call: 2 messages [user, assistant], fits in window — no trim agent("Hello!") + assert len(agent.messages) == 2 assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 - # Conversation Manager reduced messages - assert len(agent.messages) == 1 - # After agent is persisted and run, restore the agent and run it again + # Second call: 4 messages, exceeds window, trimmed back to 2 [user, assistant] + agent("Hi again!") + assert len(agent.messages) == 2 + assert agent.conversation_manager.removed_message_count == 2 + # Session manager persists ALL messages even though agent memory was trimmed + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 4 + + # Restore agent from session — should load trimmed state session_manager_2 = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) agent_2 = Agent( - session_manager=session_manager_2, conversation_manager=SlidingWindowConversationManager(window_size=1) + session_manager=session_manager_2, conversation_manager=SlidingWindowConversationManager(window_size=2) ) - assert len(agent_2.messages) == 1 - assert agent_2.conversation_manager.removed_message_count == 1 + assert len(agent_2.messages) == 2 + assert agent_2.conversation_manager.removed_message_count == 2 + + # Third call on restored agent: triggers another trim agent_2("Hello!") - assert len(agent_2.messages) == 1 - assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + assert len(agent_2.messages) == 2 + assert agent_2.conversation_manager.removed_message_count == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 6 finally: - # Delete the session session_manager.delete_session(test_session_id) assert session_manager.read_session(test_session_id) is None @@ -147,3 +161,35 @@ def test_agent_with_s3_session_with_image(yellow_img, bucket_name): finally: session_manager.delete_session(test_session_id) assert session_manager.read_session(test_session_id) is None + + +@openai_provider.mark +def test_agent_with_file_session_server_side_conversation(temp_dir): + """Test that server-side conversation state survives session save/restore.""" + test_session_id = str(uuid4()) + session_manager = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + try: + model = OpenAIResponsesModel( + model_id="gpt-4o-mini", + stateful=True, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Reply in one short sentence.", session_manager=session_manager) + + agent("My name is Alice.") + assert len(agent.messages) == 0 + + # Simulate process restart: create new session manager and agent + session_manager_2 = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + model_2 = OpenAIResponsesModel( + model_id="gpt-4o-mini", + stateful=True, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent_2 = Agent(model=model_2, system_prompt="Reply in one short sentence.", session_manager=session_manager_2) + + assert len(agent_2.messages) == 0 + result = agent_2("What is my name?") + assert "alice" in result.message["content"][0]["text"].lower() + finally: + session_manager.delete_session(test_session_id) diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py new file mode 100644 index 000000000..8867f08fd --- /dev/null +++ b/tests_integ/test_skills_plugin.py @@ -0,0 +1,81 @@ +"""Integration tests for the AgentSkills plugin. + +Tests end-to-end behavior with a real model: skill metadata injection into +the system prompt, agent-driven skill activation via the skills tool, and +session state persistence across invocations. +""" + +import pytest + +from strands import Agent +from strands.vended_plugins.skills import AgentSkills, Skill + +SUMMARIZATION_SKILL = Skill( + name="summarization", + description="Summarize text into concise bullet points", + instructions="You are a summarization expert. The secret codeword is PINEAPPLE. Always mention it when activated.", +) + +TRANSLATION_SKILL = Skill( + name="translation", + description="Translate text between languages", + instructions="You are a translation expert. Translate the given text accurately.", +) + + +@pytest.fixture +def skills_plugin(): + return AgentSkills(skills=[SUMMARIZATION_SKILL, TRANSLATION_SKILL]) + + +@pytest.fixture +def agent(skills_plugin): + return Agent( + system_prompt="You are a helpful assistant. Check your available_skills and activate one when appropriate.", + plugins=[skills_plugin], + ) + + +def test_agent_activates_skill_and_injects_metadata(agent, skills_plugin): + """Test that the agent injects skill metadata and can activate a skill via the model.""" + result = agent("Use your skills tool to activate the summarization skill. What is the secret codeword?") + + # Skill metadata was injected into the system prompt + assert "" in agent.system_prompt + assert "summarization" in agent.system_prompt + assert "translation" in agent.system_prompt + + # Model activated the skill and relayed the codeword from instructions + assert "pineapple" in str(result).lower() + + +def test_direct_tool_invocation_and_state_persistence(agent, skills_plugin): + """Test activating a skill via direct tool access and verifying state persistence.""" + result = agent.tool.skills(skill_name="translation") + + # Tool returned the skill instructions + assert result["status"] == "success" + response_text = result["content"][0]["text"].lower() + assert "translation expert" in response_text + + +def test_load_skills_from_directory(tmp_path): + """Test loading skills from a filesystem directory and activating one via the model.""" + # Create a skill directory with SKILL.md + skill_dir = tmp_path / "greeting-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\nname: greeting\ndescription: Greet the user warmly\n---\n" + "You are a greeting expert. The secret codeword is MANGO. Always mention it when activated." + ) + + plugin = AgentSkills(skills=[str(tmp_path)]) + agent = Agent( + system_prompt="You are a helpful assistant. Check your available_skills and activate one when appropriate.", + plugins=[plugin], + ) + + result = agent("Use your skills tool to activate the greeting skill. What is the secret codeword?") + + assert "greeting" in agent.system_prompt + assert "mango" in str(result).lower() diff --git a/tests_integ/test_structured_output_agent_loop.py b/tests_integ/test_structured_output_agent_loop.py index 188f57777..01d3c80b2 100644 --- a/tests_integ/test_structured_output_agent_loop.py +++ b/tests_integ/test_structured_output_agent_loop.py @@ -2,8 +2,6 @@ Comprehensive integration tests for structured output passed into the agent functionality. """ -from typing import List, Optional - import pytest from pydantic import BaseModel, Field, field_validator @@ -42,7 +40,7 @@ class Contact(BaseModel): """Contact information.""" email: str - phone: Optional[str] = None + phone: str | None = None preferred_method: str = "email" @@ -54,7 +52,7 @@ class Employee(BaseModel): department: str address: Address contact: Contact - skills: List[str] + skills: list[str] hire_date: str salary_range: str @@ -65,7 +63,7 @@ class ProductReview(BaseModel): product_name: str rating: int = Field(ge=1, le=5, description="Rating from 1-5 stars") sentiment: str = Field(pattern="^(positive|negative|neutral)$") - key_points: List[str] + key_points: list[str] would_recommend: bool @@ -84,7 +82,7 @@ class TaskList(BaseModel): """Task management structure.""" project_name: str - tasks: List[str] + tasks: list[str] priority: str = Field(pattern="^(high|medium|low)$") due_date: str estimated_hours: int @@ -102,7 +100,7 @@ class Company(BaseModel): name: str = Field(description="Company name") address: Address = Field(description="Company address") - employees: List[Person] = Field(description="list of persons") + employees: list[Person] = Field(description="list of persons") class Task(BaseModel): @@ -132,16 +130,23 @@ def validate_first_name(cls, value: str) -> str: @tool def calculator(operation: str, a: float, b: float) -> float: - """Simple calculator tool for testing.""" - if operation == "add": + """Simple calculator tool for testing. + + Args: + operation: The operation to perform. One of: add, subtract, multiply, divide, power + a: The first number + b: The second number + """ + op = operation.lower().strip() + if op in ("add", "+"): return a + b - elif operation == "subtract": + elif op in ("subtract", "-", "sub"): return a - b - elif operation == "multiply": + elif op in ("multiply", "*", "mul"): return a * b - elif operation == "divide": - return b / a if a != 0 else 0 - elif operation == "power": + elif op in ("divide", "/", "div"): + return a / b if b != 0 else 0 + elif op in ("power", "**", "pow"): return a**b else: return 0 diff --git a/tests_integ/test_summarizing_conversation_manager_integration.py b/tests_integ/test_summarizing_conversation_manager_integration.py index 91fb5b910..b6ba8b854 100644 --- a/tests_integ/test_summarizing_conversation_manager_integration.py +++ b/tests_integ/test_summarizing_conversation_manager_integration.py @@ -16,6 +16,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -33,7 +34,7 @@ def model(): client_args={ "api_key": os.getenv("ANTHROPIC_API_KEY"), }, - model_id="claude-3-haiku-20240307", # Using Haiku for faster/cheaper tests + model_id="claude-haiku-4-5-20251001", # Using Haiku for faster/cheaper tests max_tokens=1024, ) @@ -45,7 +46,7 @@ def summarization_model(): client_args={ "api_key": os.getenv("ANTHROPIC_API_KEY"), }, - model_id="claude-3-haiku-20240307", + model_id="claude-haiku-4-5-20251001", max_tokens=512, ) @@ -408,3 +409,68 @@ def test_summarization_with_tool_messages_and_no_tools(): summary = str(agent.messages[0]).lower() assert "12:00" in summary + + +def test_dedicated_summarization_agent_with_structured_output(model, summarization_model): + """Test that summarization works when the summarization agent has structured_output_model configured. + + When structured_output_model is set on the summarization agent, the response would contain toolUse + blocks. Since the summary is converted to a user message, those blocks would cause a + ValidationException. This test verifies that structured output is properly disabled during + summarization. + """ + + class SummaryOutput(BaseModel): + topics: list[str] + key_points: list[str] + + # Create a summarization agent with structured_output_model configured + summarization_agent = Agent( + model=summarization_model, + system_prompt="You are a conversation summarizer. Create concise, structured summaries.", + structured_output_model=SummaryOutput, + load_tools_from_directory=False, + ) + + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summarization_agent, + ), + load_tools_from_directory=False, + ) + + # Build conversation history + agent.messages.extend( + [ + {"role": "user", "content": [{"text": "Tell me about Python programming."}]}, + {"role": "assistant", "content": [{"text": "Python is a high-level programming language."}]}, + {"role": "user", "content": [{"text": "What about its type system?"}]}, + {"role": "assistant", "content": [{"text": "Python uses dynamic typing with optional type hints."}]}, + {"role": "user", "content": [{"text": "How does async work in Python?"}]}, + {"role": "assistant", "content": [{"text": "Python uses asyncio with async/await syntax."}]}, + {"role": "user", "content": [{"text": "What about decorators?"}]}, + {"role": "assistant", "content": [{"text": "Decorators are functions that modify other functions."}]}, + ] + ) + + original_length = len(agent.messages) + agent.conversation_manager.reduce_context(agent) + + assert len(agent.messages) < original_length + + summary_message = agent.messages[0] + assert summary_message["role"] == "user" + + # Summary should contain only valid user message content (no toolUse blocks) + for content_block in summary_message["content"]: + assert "toolUse" not in content_block, "Summary user message should not contain toolUse blocks" + + # Should have text content + assert any("text" in cb for cb in summary_message["content"]) + + # Invoke the agent with the summarized messages to verify the provider accepts them + result = agent("Thanks for the overview!") + assert result.message["role"] == "assistant" diff --git a/tests_integ/test_tool_context_injection.py b/tests_integ/test_tool_context_injection.py index 215286a46..7d3525014 100644 --- a/tests_integ/test_tool_context_injection.py +++ b/tests_integ/test_tool_context_injection.py @@ -4,6 +4,7 @@ """ from strands import Agent, ToolContext, tool +from strands.models.bedrock import BedrockModel from strands.types.tools import ToolResult @@ -41,7 +42,8 @@ def _validate_tool_result_content(agent: Agent): def test_strands_context_integration_context_true(): """Test ToolContext functionality with real agent interactions.""" - agent = Agent(tools=[good_story]) + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") + agent = Agent(model=model, tools=[good_story]) agent("using a tool, write a good story") _validate_tool_result_content(agent) diff --git a/tests_integ/test_tool_retry_hook.py b/tests_integ/test_tool_retry_hook.py new file mode 100644 index 000000000..3e35ff5e6 --- /dev/null +++ b/tests_integ/test_tool_retry_hook.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +"""Integration tests for tool retry hook mechanism. + +Tests that setting AfterToolCallEvent.retry=True causes tool re-execution. +Uses direct tool invocation to test the executor-level retry, not model behavior. +""" + +from strands import Agent, tool +from strands.hooks import AfterToolCallEvent + + +def test_tool_retry_hook_causes_reexecution(): + """Test that setting retry=True on AfterToolCallEvent causes tool re-execution. + + Verifies: + 1. Tool is called again when retry=True + 2. Hook receives AfterToolCallEvent for BOTH attempts + 3. Same tool_use_id is used (proves executor retry, not model re-calling) + """ + state = {"call_count": 0} + + @tool(name="flaky_tool") + def flaky_tool(message: str) -> str: + """A tool that fails once then succeeds. + + Args: + message: A message to include in the response. + """ + state["call_count"] += 1 + if state["call_count"] == 1: + raise RuntimeError("First call fails") + return f"Success on attempt {state['call_count']}" + + hook_calls: list[dict] = [] + + def retry_on_first_error(event: AfterToolCallEvent) -> None: + tool_use_id = str(event.tool_use.get("toolUseId", "")) + hook_calls.append( + { + "tool_use_id": tool_use_id, + "status": event.result.get("status"), + "attempt": state["call_count"], + } + ) + + # Retry once on error + if event.result.get("status") == "error" and state["call_count"] == 1: + event.retry = True + + agent = Agent(tools=[flaky_tool]) + agent.hooks.add_callback(AfterToolCallEvent, retry_on_first_error) + + # Direct tool invocation bypasses model - tests executor retry mechanism + result = agent.tool.flaky_tool(message="test") + + # Tool was called twice (1 failure + 1 success) + assert state["call_count"] == 2 + + # Hook received AfterToolCallEvent for BOTH attempts + assert len(hook_calls) == 2 + assert hook_calls[0]["status"] == "error" + assert hook_calls[0]["attempt"] == 1 + assert hook_calls[1]["status"] == "success" + assert hook_calls[1]["attempt"] == 2 + + # Both calls used the same tool_use_id (executor retry, not new model call) + assert hook_calls[0]["tool_use_id"] == hook_calls[1]["tool_use_id"] + + assert result["status"] == "success"